import random import numpy as np from numba import njit, prange import torch from torch import Tensor import torch.nn.functional as F from torch.utils.data import DataLoader import fastremap import fill_voids from skimage import morphology from skimage.segmentation import find_boundaries from scipy.ndimage import mean, find_objects from monai.data.dataset import Dataset from monai.transforms import * # type: ignore from monai.inferers.utils import sliding_window_inference from monai.metrics.cumulative_average import CumulativeAverage import matplotlib.pyplot as plt import matplotlib.colors as mcolors import os import glob import copy import tifffile as tiff from pprint import pformat from tabulate import tabulate from typing import Any, Dict, Literal, Optional, Tuple, List, Union from tqdm import tqdm import wandb from config import Config from core.models import * from core.losses import * from core.optimizers import * from core.schedulers import * from core.utils import ( compute_batch_segmentation_tp_fp_fn, compute_f1_score, compute_average_precision_score ) from core.logger import get_logger logger = get_logger() class CellSegmentator: def __init__(self, config: Config) -> None: self.__set_seed(config.dataset_config.common.seed) self.__parse_config(config) self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu") self._scaler = ( torch.amp.GradScaler(self._device.type) # type: ignore if self._dataset_setup.is_training and self._dataset_setup.common.use_amp else None ) self._train_dataloader: Optional[DataLoader] = None self._valid_dataloader: Optional[DataLoader] = None self._test_dataloader: Optional[DataLoader] = None self._predict_dataloader: Optional[DataLoader] = None def create_dataloaders( self, train_transforms: Optional[Compose] = None, valid_transforms: Optional[Compose] = None, test_transforms: Optional[Compose] = None, predict_transforms: Optional[Compose] = None ) -> None: """ Creates train, validation, test, and prediction dataloaders based on dataset configuration and provided transforms. Args: train_transforms (Optional[Compose]): Transformations for training data. valid_transforms (Optional[Compose]): Transformations for validation data. test_transforms (Optional[Compose]): Transformations for testing data. predict_transforms (Optional[Compose]): Transformations for prediction data. Raises: ValueError: If required transforms are missing. RuntimeError: If critical dataset config values are missing. """ if self._dataset_setup.is_training and train_transforms is None: raise ValueError("Training mode requires 'train_transforms' to be provided.") elif not self._dataset_setup.is_training and test_transforms is None and predict_transforms is None: raise ValueError("In inference mode, at least one of 'test_transforms' or 'predict_transforms' must be provided.") if self._dataset_setup.is_training: # Training mode: handle either pre-split datasets or splitting on the fly if self._dataset_setup.training.is_split: # Validate presence of validation transforms if validation directory and size are set if ( self._dataset_setup.training.pre_split.valid_dir and self._dataset_setup.training.valid_size and valid_transforms is None ): raise ValueError("Validation transforms must be provided when using pre-split validation data.") # Use explicitly split directories train_dir = self._dataset_setup.training.pre_split.train_dir valid_dir = self._dataset_setup.training.pre_split.valid_dir test_dir = self._dataset_setup.training.pre_split.test_dir train_offset = self._dataset_setup.training.train_offset valid_offset = self._dataset_setup.training.valid_offset test_offset = self._dataset_setup.training.test_offset shuffle = False else: # Same validation for split mode with full data directory if ( self._dataset_setup.training.split.all_data_dir and self._dataset_setup.training.valid_size and valid_transforms is None ): raise ValueError("Validation transforms must be provided when splitting dataset.") # Automatically split dataset from one directory train_dir = valid_dir = test_dir = self._dataset_setup.training.split.all_data_dir number_of_images = len(os.listdir(os.path.join(train_dir, 'images'))) if number_of_images == 0: raise FileNotFoundError(f"No images found in '{train_dir}/images'") # Calculate train/valid sizes train_size = ( self._dataset_setup.training.train_size if isinstance(self._dataset_setup.training.train_size, int) else int(number_of_images * self._dataset_setup.training.train_size) ) valid_size = ( self._dataset_setup.training.valid_size if isinstance(self._dataset_setup.training.valid_size, int) else int(number_of_images * self._dataset_setup.training.valid_size) ) train_offset = self._dataset_setup.training.train_offset valid_offset = self._dataset_setup.training.valid_offset + train_size test_offset = self._dataset_setup.training.test_offset + train_size + valid_size shuffle = True # Train dataloader train_dataset = self.__get_dataset( images_dir=os.path.join(train_dir, 'images'), masks_dir=os.path.join(train_dir, 'masks'), transforms=train_transforms, # type: ignore size=self._dataset_setup.training.train_size, offset=train_offset, shuffle=shuffle ) self._train_dataloader = DataLoader(train_dataset, batch_size=self._dataset_setup.training.batch_size, shuffle=True) logger.info(f"Loaded training dataset with {len(train_dataset)} samples.") # Validation dataloader if valid_transforms is not None: if not valid_dir or not self._dataset_setup.training.valid_size: raise RuntimeError("Validation directory or size is not properly configured.") valid_dataset = self.__get_dataset( images_dir=os.path.join(valid_dir, 'images'), masks_dir=os.path.join(valid_dir, 'masks'), transforms=valid_transforms, size=self._dataset_setup.training.valid_size, offset=valid_offset, shuffle=shuffle ) self._valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded validation dataset with {len(valid_dataset)} samples.") # Test dataloader if test_transforms is not None: if not test_dir or not self._dataset_setup.training.test_size: raise RuntimeError("Test directory or size is not properly configured.") test_dataset = self.__get_dataset( images_dir=os.path.join(test_dir, 'images'), masks_dir=os.path.join(test_dir, 'masks'), transforms=test_transforms, size=self._dataset_setup.training.test_size, offset=test_offset, shuffle=shuffle ) self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded test dataset with {len(test_dataset)} samples.") # Prediction dataloader if predict_transforms is not None: if not test_dir or not self._dataset_setup.training.test_size: raise RuntimeError("Prediction directory or size is not properly configured.") predict_dataset = self.__get_dataset( images_dir=os.path.join(test_dir, 'images'), masks_dir=None, transforms=predict_transforms, size=self._dataset_setup.training.test_size, offset=test_offset, shuffle=shuffle ) self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.") else: # Inference mode (no training) test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images') test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks') if test_transforms is not None: test_dataset = self.__get_dataset( images_dir=test_images, masks_dir=test_masks, transforms=test_transforms, size=self._dataset_setup.testing.test_size, offset=self._dataset_setup.testing.test_offset, shuffle=self._dataset_setup.testing.shuffle ) self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded test dataset with {len(test_dataset)} samples.") if predict_transforms is not None: predict_dataset = self.__get_dataset( images_dir=test_images, masks_dir=None, transforms=predict_transforms, size=self._dataset_setup.testing.test_size, offset=self._dataset_setup.testing.test_offset, shuffle=self._dataset_setup.testing.shuffle ) self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.") def print_data_info( self, loader_type: Literal["train", "valid", "test", "predict"], index: Optional[int] = None ) -> None: """ Prints statistics for a single sample from the specified dataloader. Args: loader_type: One of "train", "valid", "test", "predict". index: The sample index; if None, a random index is selected. """ # Retrieve the dataloader attribute, e.g., self._train_dataloader loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None) if loader is None: logger.error(f"Dataloader '{loader_type}' is not initialized.") return dataset = loader.dataset total = len(dataset) # type: ignore if total == 0: logger.error(f"Dataset for '{loader_type}' is empty.") return # Choose index idx = index if index is not None else random.randint(0, total - 1) if not (0 <= idx < total): logger.error(f"Index {idx} is out of range [0, {total}).") return # Fetch the sample and apply transforms sample = dataset[idx] # Expecting a dict with {'image': ..., 'mask': ...} or {'image': ...} img = sample["image"] # Convert tensor to numpy if needed if isinstance(img, torch.Tensor): img = img.cpu().numpy() img = np.asarray(img) # Compute image statistics img_min, img_max = img.min(), img.max() img_mean, img_std = float(img.mean()), float(img.std()) img_shape = img.shape # Prepare log lines lines = [ "=" * 40, f"Dataloader: '{loader_type}', sample index: {idx} / {total - 1}", f"Image — shape: {img_shape}, min: {img_min:.4f}, max: {img_max:.4f}, mean: {img_mean:.4f}, std: {img_std:.4f}" ] # For 'predict', no mask is available if loader_type != "predict": mask = sample.get("mask", None) if mask is not None: # Convert tensor to numpy if needed if isinstance(mask, torch.Tensor): mask = mask.cpu().numpy() mask = np.asarray(mask) m_min, m_max = mask.min(), mask.max() m_mean, m_std = float(mask.mean()), float(mask.std()) m_shape = mask.shape lines.append( f"Mask — shape: {m_shape}, min: {m_min:.4f}, " f"max: {m_max:.4f}, mean: {m_mean:.4f}, std: {m_std:.4f}" ) else: lines.append("Mask — not available for this sample.") lines.append("=" * 40) # Output via logger for l in lines: logger.info(l) def train(self) -> None: """ Train the model over multiple epochs, including validation and test. """ # Ensure training is enabled in dataset setup if not self._dataset_setup.is_training: raise RuntimeError("Dataset setup indicates training is disabled.") # Determine device name for logging if self._device.type == "cpu": device_name = "cpu" else: idx = self._device.index if hasattr(self._device, 'index') else torch.cuda.current_device() device_name = torch.cuda.get_device_name(idx) logger.info(f"\n{'=' * 50}\n") logger.info(f"Training starts on device: {device_name}") logger.info(f"\n{'=' * 50}") best_f1_score = 0.0 best_weights = None for epoch in range(1, self._dataset_setup.training.num_epochs + 1): train_metrics = self.__run_epoch("train", epoch) self.__print_with_logging(train_metrics, epoch) # Step the scheduler after training if self._scheduler is not None: self._scheduler.step() # Periodic validation or tuning if epoch % self._dataset_setup.training.val_freq == 0: if self._valid_dataloader is not None: valid_metrics = self.__run_epoch("valid", epoch) self.__print_with_logging(valid_metrics, epoch) # Update best model on improved F1 f1 = valid_metrics.get("valid_f1_score", 0.0) if f1 > best_f1_score: best_f1_score = f1 # Deep copy weights to avoid reference issues best_weights = copy.deepcopy(self._model.state_dict()) logger.info(f"Updated best model weights with F1 score: {f1:.4f}") # Restore best model weights if available if best_weights is not None: self._model.load_state_dict(best_weights) if self._test_dataloader is not None: test_metrics = self.__run_epoch("test") self.__print_with_logging(test_metrics, 0) def evaluate(self) -> None: """ Run a full test epoch and display/log the resulting metrics. """ test_metrics = self.__run_epoch("test") self.__print_with_logging(test_metrics, 0) def predict(self) -> None: """ Run inference on the predict set and save the resulting instance masks. """ # Ensure the predict DataLoader has been set if self._predict_dataloader is None: raise RuntimeError("DataLoader for mode 'predict' is not set.") batch_counter = 0 for batch in tqdm(self._predict_dataloader, desc="Predicting"): # Move input images to the configured device (CPU/GPU) inputs = batch["img"].to(self._device) # Use automatic mixed precision if enabled in dataset setup with torch.amp.autocast( # type: ignore self._device.type, enabled=self._dataset_setup.common.use_amp ): # Disable gradient computation for inference with torch.no_grad(): # Run the model’s forward pass in ‘predict’ mode raw_output = self.__run_inference(inputs, mode="predict") # Convert logits/probabilities to discrete instance masks # ground_truth is not passed here; only predictions are needed preds, _ = self.__post_process_predictions(raw_output) # Save out the predicted masks, using batch_counter to index files self.__save_prediction_masks(batch, preds, batch_counter) # Increment counter by batch size for unique file naming batch_counter += inputs.shape[0] def run(self) -> None: """ Orchestrate the full workflow: - If training is enabled in the dataset setup, start training. - Otherwise, if a test DataLoader is provided, run evaluation. - Else if a prediction DataLoader is provided, run inference/prediction. - If neither loader is available in non‐training mode, raise an error. """ # 1) TRAINING PATH if self._dataset_setup.is_training: # Launch the full training loop (with validation, scheduler steps, etc.) self.train() return # 2) NON-TRAINING PATH (TEST or PREDICT) # Prefer test if available if self._test_dataloader is not None: # Run a single evaluation epoch on the test set and log metrics self.evaluate() return # If no test loader, fall back to prediction if available if self._predict_dataloader is not None: # Run inference on the predict set and save outputs self.predict() return # 3) ERROR: no appropriate loader found raise RuntimeError( "Neither test nor predict DataLoader is set for non‐training mode." ) def load_from_checkpoint(self, checkpoint_path: str) -> None: """ Loads model weights from a specified checkpoint into the current model. Args: checkpoint_path (str): Path to the checkpoint file containing the model weights. """ # Load the checkpoint onto the correct device (CPU or GPU) checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=True) # Load the state dict into the model, allowing for missing keys self._model.load_state_dict(checkpoint['state_dict'], strict=False) def save_checkpoint(self, checkpoint_path: str) -> None: """ Saves the current model weights to a checkpoint file. Args: checkpoint_path (str): Path where the checkpoint file will be saved. """ # Create a checkpoint dictionary containing the model’s state_dict checkpoint = { 'state_dict': self._model.state_dict() } # Write the checkpoint to disk torch.save(checkpoint, checkpoint_path) def __parse_config(self, config: Config) -> None: """ Parses the given configuration object to initialize model, criterion, optimizer, scheduler, and dataset setup. Args: config (Config): Configuration object with model, optimizer, scheduler, criterion, and dataset setup information. """ model = config.model criterion = config.criterion optimizer = config.optimizer scheduler = config.scheduler logger.info("========== Parsed Configuration ==========") logger.info("Model Config:\n%s", pformat(model.dump(), indent=2)) if criterion: logger.info("Criterion Config:\n%s", pformat(criterion.dump(), indent=2)) if optimizer: logger.info("Optimizer Config:\n%s", pformat(optimizer.dump(), indent=2)) if scheduler: logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2)) logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2)) logger.info("==========================================") # Initialize model using the model registry self._model = ModelRegistry.get_model_class(model.name)(model.params) # Loads model weights from a specified checkpoint if config.dataset_config.is_training: if config.dataset_config.training.pretrained_weights: self.load_from_checkpoint(config.dataset_config.training.pretrained_weights) # Initialize loss criterion if specified self._criterion = ( CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params) if criterion is not None else None ) if hasattr(self._criterion, "num_classes"): nc_model = self._model.num_classes nc_crit = getattr(self._criterion, "num_classes") if nc_model != nc_crit: raise ValueError( f"Number of classes mismatch: model.num_classes={nc_model} " f"but criterion.num_classes={nc_crit}" ) # Initialize optimizer if specified self._optimizer = ( OptimizerRegistry.get_optimizer_class(optimizer.name)( model_params=self._model.parameters(), optim_params=optimizer.params ) if optimizer is not None else None ) # Initialize scheduler only if both scheduler and optimizer are defined self._scheduler = ( SchedulerRegistry.get_scheduler_class(scheduler.name)( optimizer=self._optimizer.optim, params=scheduler.params ) if scheduler is not None and self._optimizer is not None and self._optimizer.optim is not None else None ) logger.info("========== Model Components Initialization ==========") logger.info("├─ Model: " + (f"{model.name}" if self._model else "Not specified")) logger.info("├─ Criterion: " + (f"{criterion.name}" if self._criterion else "Not specified")) # type: ignore logger.info("├─ Optimizer: " + (f"{optimizer.name}" if self._optimizer else "Not specified")) # type: ignore logger.info("└─ Scheduler: " + (f"{scheduler.name}" if self._scheduler else "Not specified")) # type: ignore logger.info("=====================================================") # Save dataset config self._dataset_setup = config.dataset_config common = config.dataset_config.common logger.info("========== Dataset Setup ==========") logger.info("[COMMON]") logger.info(f"├─ Seed: {common.seed}") logger.info(f"├─ Device: {common.device}") logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}") logger.info(f"└─ Predictions output dir: {common.predictions_dir}") if config.dataset_config.is_training: training = config.dataset_config.training logger.info("[MODE] Training") logger.info(f"├─ Batch size: {training.batch_size}") logger.info(f"├─ Epochs: {training.num_epochs}") logger.info(f"├─ Validation frequency: {training.val_freq}") logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}") if training.is_split: logger.info(f"├─ Using pre-split directories:") logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}") logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}") logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}") else: logger.info(f"├─ Using unified dataset with splits:") logger.info(f"│ ├─ All data dir: {training.split.all_data_dir}") logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}") logger.info(f"└─ Dataset split:") logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}") logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}") logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}") else: testing = config.dataset_config.testing logger.info("[MODE] Inference") logger.info(f"├─ Test dir: {testing.test_dir}") logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})") logger.info(f"├─ Shuffle: {'yes' if testing.shuffle else 'no'}") logger.info(f"├─ Use ensemble: {'yes' if testing.use_ensemble else 'no'}") logger.info(f"└─ Pretrained weights:") logger.info(f" ├─ Single model: {testing.pretrained_weights}") logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}") logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}") logger.info("===================================") def __set_seed(self, seed: Optional[int]) -> None: """ Sets the random seed for reproducibility across Python, NumPy, and PyTorch. Args: seed (Optional[int]): Seed value. If None, no seeding is performed. """ if seed is not None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False logger.info(f"Random seed set to {seed}") else: logger.info("Seed not set (None provided)") def __get_dataset( self, images_dir: str, masks_dir: Optional[str], transforms: Compose, size: Union[int, float], offset: int, shuffle: bool ) -> Dataset: """ Loads and returns a dataset object from image and optional mask directories. Args: images_dir (str): Path to directory or glob pattern for input images. masks_dir (Optional[str]): Path to directory or glob pattern for masks. transforms (Compose): Transformations to apply to each image or pair. size (Union[int, float]): Either an integer or a fraction of the dataset. offset (int): Number of images to skip from the start. shuffle (bool): Whether to shuffle the dataset before slicing. Returns: Dataset: A dataset containing image and optional mask paths. Raises: FileNotFoundError: If no images are found. ValueError: If masks are provided but do not match image count. ValueError: If dataset is too small for requested size or offset. """ # Collect sorted list of image paths images = sorted(glob.glob(images_dir)) if not images: raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'") if masks_dir is not None: # Collect and validate sorted list of mask paths masks = sorted(glob.glob(masks_dir)) if len(images) != len(masks): raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})") # Convert float size (fraction) to absolute count size = size if isinstance(size, int) else int(size * len(images)) if size <= 0: raise ValueError(f"Size must be positive, got: {size}") if len(images) < size: raise ValueError(f"Not enough images ({len(images)}) for requested size ({size})") if len(images) < size + offset: raise ValueError(f"Offset ({offset}) + size ({size}) exceeds dataset length ({len(images)})") # Shuffle image-mask pairs if requested if shuffle: if masks_dir is not None: combined = list(zip(images, masks)) # type: ignore random.shuffle(combined) images, masks = zip(*combined) else: random.shuffle(images) # Apply offset and limit by size images = images[offset: offset + size] if masks_dir is not None: masks = masks[offset: offset + size] # type: ignore # Prepare data structure for Dataset class if masks_dir is not None: data = [ {"image": image, "mask": mask} for image, mask in zip(images, masks) # type: ignore ] else: data = [{"image": image} for image in images] return Dataset(data, transforms) def __print_with_logging(self, results: Dict[str, float], step: int) -> None: """ Print metrics in a tabular format and log to W&B. Args: results (Dict[str, float]): results dictionary. step (int): epoch index. """ table = tabulate( tabular_data=results.items(), headers=["Metric", "Value"], floatfmt=".4f", tablefmt="fancy_grid" ) print(table, "\n") wandb.log(results, step=step) def __run_epoch(self, mode: Literal["train", "valid", "test"], epoch: Optional[int] = None ) -> Dict[str, float]: """ Execute one epoch of training, validation, or testing. Args: mode (str): One of 'train', 'valid', or 'test'. epoch (int, optional): Current epoch number for logging. Returns: Dict[str, float]: Loss metrics and F1 score for valid/test. """ # Ensure required components are available if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None): raise RuntimeError("Optimizer and loss function must be initialized for train/valid.") # Set model mode and choose the appropriate data loader if mode == "train": self._model.train() loader = self._train_dataloader else: self._model.eval() loader = self._valid_dataloader if mode == "valid" else self._test_dataloader # Check that the data loader is provided if loader is None: raise RuntimeError(f"DataLoader for mode '{mode}' is not set.") all_tp, all_fp, all_fn = [], [], [] # Prepare tqdm description with epoch info if available if epoch is not None: desc = f"Epoch {epoch}/{self._dataset_setup.training.num_epochs} [{mode}]" else: desc = f"Epoch ({mode})" # Iterate over batches batch_counter = 0 for batch in tqdm(loader, desc=desc): inputs = batch["img"].to(self._device) targets = batch["label"].to(self._device) # Zero gradients for training if self._optimizer is not None: self._optimizer.zero_grad() # Mixed precision context if enabled with torch.amp.autocast( # type: ignore self._device.type, enabled=self._dataset_setup.common.use_amp ): # Only compute gradients in training phase with torch.set_grad_enabled(mode == "train"): # Forward pass raw_output = self.__run_inference(inputs, mode) if self._criterion is not None: # Convert label masks to flow representations (one-hot) flow_targets = self.__compute_flows_from_masks(targets) # Compute loss for this batch batch_loss = self._criterion(raw_output, flow_targets) # type: ignore # Post-process and compute F1 during validation and testing if mode in ("valid", "test"): preds, labels_post = self.__post_process_predictions( raw_output, targets ) # Collecting statistics on the batch tp, fp, fn = self.__compute_stats( predicted_masks=preds, ground_truth_masks=labels_post, # type: ignore iou_threshold=0.5 ) all_tp.append(tp) all_fp.append(fp) all_fn.append(fn) if mode == "test": self.__save_prediction_masks( batch, preds, batch_counter ) # Backpropagation and optimizer step in training if mode == "train": if self._dataset_setup.common.use_amp and self._scaler is not None: self._scaler.scale(batch_loss).backward() # type: ignore self._scaler.unscale_(self._optimizer.optim) # type: ignore self._scaler.step(self._optimizer.optim) # type: ignore self._scaler.update() else: batch_loss.backward() # type: ignore if self._optimizer is not None: self._optimizer.step() batch_counter += inputs.shape[0] if self._criterion is not None: # Collect loss metrics epoch_metrics = {f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()} # Reset internal loss metrics accumulator self._criterion.reset_metrics() else: epoch_metrics = {} # Include F1 and mAP for validation and testing if mode in ("valid", "test"): # Concatenating by batch: shape (num_batches*B, C) tp_array = np.vstack(all_tp) fp_array = np.vstack(all_fp) fn_array = np.vstack(all_fn) epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( # type: ignore tp_array, fp_array, fn_array, reduction="micro" ) epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( # type: ignore tp_array, fp_array, fn_array, reduction="macro" ) return epoch_metrics def __run_inference( self, inputs: torch.Tensor, mode: Literal["train", "valid", "test", "predict"] = "train" ) -> torch.Tensor: """ Perform model inference for different stages. Args: inputs (torch.Tensor): Input tensor of shape (B, C, H, W). stage (Literal[...]): One of "train", "valid", "test", "predict". Returns: torch.Tensor: Model outputs tensor. """ if mode != "train": # Use sliding window inference for non-training phases outputs = sliding_window_inference( inputs, roi_size=512, sw_batch_size=4, predictor=self._model, padding_mode="constant", mode="gaussian", overlap=0.5, ) else: # Direct forward pass during training outputs = self._model(inputs) return outputs # type: ignore def __post_process_predictions( self, raw_outputs: torch.Tensor, ground_truth: Optional[torch.Tensor] = None ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """ Post-process raw network outputs to extract instance segmentation masks. Args: raw_outputs (torch.Tensor): Raw model outputs of shape (B, С, H, W). ground_truth (torch.Tensor): Ground truth masks of shape (B, С, H, W). Returns: Tuple[np.ndarray, Optional[np.ndarray]]: - instance_masks: Instance-wise masks array of shape (B, С, H, W). - labels_np: Converted ground truth of shape (B, С, H, W) or None if ground_truth was not provided. """ # Move outputs to CPU and convert to numpy outputs_np = raw_outputs.cpu().numpy() # Split channels: gradient flows then class logits gradflow = outputs_np[:, :2 * self._model.num_classes] logits = outputs_np[:, -self._model.num_classes :] # Apply sigmoid to logits to get probabilities probabilities = self.__sigmoid(logits) batch_size, _, height, width = probabilities.shape # Prepare container for instance masks instance_masks = np.zeros((batch_size, self._model.num_classes, height, width), dtype=np.uint16) for idx in range(batch_size): instance_masks[idx] = self.__segment_instances( probability_map=probabilities[idx], flow=gradflow[idx], prob_threshold=0.0, flow_threshold=0.4, min_object_size=15 ) # Convert ground truth to numpy labels_np = ground_truth.cpu().numpy() if ground_truth is not None else None return instance_masks, labels_np def __compute_stats( self, predicted_masks: np.ndarray, ground_truth_masks: np.ndarray, iou_threshold: float = 0.5 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Compute batch-wise true positives, false positives, and false negatives for instance segmentation, using a configurable IoU threshold. Args: predicted_masks (np.ndarray): Predicted instance masks of shape (B, C, H, W). ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W). iou_threshold (float): Intersection-over-Union threshold for matching predictions to ground truths (default: 0.5). Returns: Tuple[np.ndarray, np.ndarray, np.ndarray]: - tp: True positives per batch and class, shape (B, C) - fp: False positives per batch and class, shape (B, C) - fn: False negatives per batch and class, shape (B, C) """ stats = compute_batch_segmentation_tp_fp_fn( batch_ground_truth=ground_truth_masks, batch_prediction=predicted_masks, iou_threshold=iou_threshold, remove_boundary_objects=True ) tp = stats["tp"] fp = stats["fp"] fn = stats["fn"] return tp, fp, fn def __compute_f1_metric( self, true_positives: np.ndarray, false_positives: np.ndarray, false_negatives: np.ndarray, reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro" ) -> Union[float, np.ndarray]: """ Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes. Args: true_positives: array of TP counts per sample and class. false_positives: array of FP counts per sample and class. false_negatives: array of FN counts per sample and class. reduction: - 'none': return F1 for each sample, class → shape (batch_size, num_classes) - 'micro': global F1 over all samples & classes - 'imagewise': F1 per sample (summing over classes), then average over samples - 'macro': average class-wise F1 (classes summed over batch) - 'weighted': class-wise F1 weighted by support (TP+FN) Returns: float for reductions 'micro', 'imagewise', 'macro', 'weighted'; or np.ndarray of shape (batch_size, num_classes) if reduction='none'. """ batch_size, num_classes = true_positives.shape # 1) No reduction: compute F1 for each (sample, class) if reduction == "none": f1_matrix = np.zeros((batch_size, num_classes), dtype=float) for i in range(batch_size): for c in range(num_classes): tp_val = int(true_positives[i, c]) fp_val = int(false_positives[i, c]) fn_val = int(false_negatives[i, c]) _, _, f1_val = compute_f1_score(tp_val, fp_val, fn_val) f1_matrix[i, c] = f1_val return f1_matrix # 2) Micro: sum all TP/FP/FN and compute a single F1 if reduction == "micro": tp_total = int(true_positives.sum()) fp_total = int(false_positives.sum()) fn_total = int(false_negatives.sum()) _, _, f1_global = compute_f1_score(tp_total, fp_total, fn_total) return f1_global # 3) Imagewise: compute per-sample F1 (sum over classes), then average if reduction == "imagewise": f1_per_image = np.zeros(batch_size, dtype=float) for i in range(batch_size): tp_i = int(true_positives[i].sum()) fp_i = int(false_positives[i].sum()) fn_i = int(false_negatives[i].sum()) _, _, f1_i = compute_f1_score(tp_i, fp_i, fn_i) f1_per_image[i] = f1_i return float(f1_per_image.mean()) # For macro/weighted, first aggregate per class across the batch tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,) fp_per_class = false_positives.sum(axis=0).astype(int) fn_per_class = false_negatives.sum(axis=0).astype(int) # 4) Macro: average F1 across classes equally if reduction == "macro": f1_per_class = np.zeros(num_classes, dtype=float) for c in range(num_classes): _, _, f1_c = compute_f1_score( tp_per_class[c], fp_per_class[c], fn_per_class[c] ) f1_per_class[c] = f1_c return float(f1_per_class.mean()) # 5) Weighted: class-wise F1 weighted by support = TP + FN if reduction == "weighted": f1_per_class = np.zeros(num_classes, dtype=float) support = np.zeros(num_classes, dtype=float) for c in range(num_classes): tp_c = tp_per_class[c] fp_c = fp_per_class[c] fn_c = fn_per_class[c] _, _, f1_c = compute_f1_score(tp_c, fp_c, fn_c) f1_per_class[c] = f1_c support[c] = tp_c + fn_c total_support = support.sum() if total_support == 0: # fallback to unweighted macro if no positives return float(f1_per_class.mean()) weights = support / total_support return float((f1_per_class * weights).sum()) raise ValueError(f"Unknown reduction mode: {reduction}") def __compute_average_precision_metric( self, true_positives: np.ndarray, false_positives: np.ndarray, false_negatives: np.ndarray, reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro" ) -> Union[float, np.ndarray]: """ Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes. AP is defined here as: AP = TP / (TP + FP + FN) Args: true_positives: array of true positives per sample and class. false_positives: array of false positives per sample and class. false_negatives: array of false negatives per sample and class. reduction: - 'none': return AP for each sample and class → shape (batch_size, num_classes) - 'micro': global AP over all samples & classes - 'imagewise': AP per sample (summing stats over classes), then average over batch - 'macro': average class-wise AP (each class summed over batch) - 'weighted': class-wise AP weighted by support (TP+FN) Returns: float for reductions 'micro', 'imagewise', 'macro', 'weighted'; or np.ndarray of shape (batch_size, num_classes) if reduction='none'. """ batch_size, num_classes = true_positives.shape # 1) No reduction: AP per (sample, class) if reduction == "none": ap_matrix = np.zeros((batch_size, num_classes), dtype=float) for i in range(batch_size): for c in range(num_classes): tp_val = int(true_positives[i, c]) fp_val = int(false_positives[i, c]) fn_val = int(false_negatives[i, c]) ap_val = compute_average_precision_score(tp_val, fp_val, fn_val) ap_matrix[i, c] = ap_val return ap_matrix # 2) Micro: sum all TP/FP/FN and compute one AP if reduction == "micro": tp_total = int(true_positives.sum()) fp_total = int(false_positives.sum()) fn_total = int(false_negatives.sum()) return compute_average_precision_score(tp_total, fp_total, fn_total) # 3) Imagewise: compute per-sample AP (sum over classes), then mean if reduction == "imagewise": ap_per_image = np.zeros(batch_size, dtype=float) for i in range(batch_size): tp_i = int(true_positives[i].sum()) fp_i = int(false_positives[i].sum()) fn_i = int(false_negatives[i].sum()) ap_per_image[i] = compute_average_precision_score(tp_i, fp_i, fn_i) return float(ap_per_image.mean()) # For macro and weighted: first aggregate per class across batch tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,) fp_per_class = false_positives.sum(axis=0).astype(int) fn_per_class = false_negatives.sum(axis=0).astype(int) # 4) Macro: average AP across classes equally if reduction == "macro": ap_per_class = np.zeros(num_classes, dtype=float) for c in range(num_classes): ap_per_class[c] = compute_average_precision_score( tp_per_class[c], fp_per_class[c], fn_per_class[c] ) return float(ap_per_class.mean()) # 5) Weighted: class-wise AP weighted by support = TP + FN if reduction == "weighted": ap_per_class = np.zeros(num_classes, dtype=float) support = np.zeros(num_classes, dtype=float) for c in range(num_classes): tp_c = tp_per_class[c] fp_c = fp_per_class[c] fn_c = fn_per_class[c] ap_per_class[c] = compute_average_precision_score(tp_c, fp_c, fn_c) support[c] = tp_c + fn_c total_support = support.sum() if total_support == 0: # fallback to unweighted macro if no positive instances return float(ap_per_class.mean()) weights = support / total_support return float((ap_per_class * weights).sum()) raise ValueError(f"Unknown reduction mode: {reduction}") @staticmethod def __sigmoid(z: np.ndarray) -> np.ndarray: """ Numerically stable sigmoid activation for numpy arrays. Args: z (np.ndarray): Input array. Returns: np.ndarray: Sigmoid of the input. """ return 1 / (1 + np.exp(-z)) def __save_prediction_masks( self, sample: Dict[str, Any], predicted_mask: Union[np.ndarray, torch.Tensor], start_index: int = 0, ) -> None: """ Save multi-channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders. Args: sample (Dict[str, Any]): Batch sample from MONAI LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict'). predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W). start_index (int): Starting index for naming when metadata is missing. """ # Determine base paths base_output_dir = self._dataset_setup.common.predictions_dir masks_dir = base_output_dir plots_dir = os.path.join(base_output_dir, "plots") os.makedirs(masks_dir, exist_ok=True) os.makedirs(plots_dir, exist_ok=True) # Extract image (C, H, W) or batch of images (B, C, H, W), and metadata image_obj = sample.get("image") # Expected shape: (C, H, W) or (B, C, H, W) mask_obj = sample.get("mask") # Expected shape: (C, H, W) or (B, C, H, W) image_meta = sample.get("image_meta_dict") # Convert tensors to numpy def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: if isinstance(x, torch.Tensor): return x.detach().cpu().numpy() return x image_array = to_numpy(image_obj) if image_obj is not None else None mask_array = to_numpy(mask_obj) if mask_obj is not None else None pred_array = to_numpy(predicted_mask) # Handle batch dimension: (B, C, H, W) if pred_array.ndim == 4: for idx in range(pred_array.shape[0]): batch_sample = dict(sample) if image_array is not None and image_array.ndim == 4: batch_sample["image"] = image_array[idx] if isinstance(image_meta, list): batch_sample["image_meta_dict"] = image_meta[idx] if mask_array is not None and mask_array.ndim == 4: batch_sample["mask"] = mask_array[idx] self.__save_prediction_masks( batch_sample, pred_array[idx], start_index=start_index+idx ) return # Determine base filename if image_meta and "filename_or_obj" in image_meta: base_name = os.path.splitext(os.path.basename(image_meta["filename_or_obj"]))[0] else: # Use provided start_index when metadata missing base_name = f"prediction_{start_index:04d}" # Now pred_array shape is (C, H, W) num_channels = pred_array.shape[0] for channel_idx in range(num_channels): channel_mask = pred_array[channel_idx] # File names mask_filename = f"{base_name}_ch{channel_idx:02d}.tif" plot_filename = f"{base_name}_ch{channel_idx:02d}.png" mask_path = os.path.join(masks_dir, mask_filename) plot_path = os.path.join(plots_dir, plot_filename) # Save mask TIFF (16-bit) tiff.imwrite(mask_path, channel_mask.astype(np.uint16), compression="zlib") # Extract corresponding true mask channel if exists true_mask = None if mask_array is not None and mask_array.ndim == 3: true_mask = mask_array[channel_idx] # Generate and save visualization self.__plot_mask( file_path=plot_path, image_data=image_array, # type: ignore predicted_mask=channel_mask, true_mask=true_mask, ) def __plot_mask( self, file_path: str, image_data: np.ndarray, predicted_mask: np.ndarray, true_mask: Optional[np.ndarray] = None, ) -> None: """ Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided. """ img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data if true_mask is None: fig, axs = plt.subplots(1, 3, figsize=(15,5)) plt.subplots_adjust(wspace=0.02, hspace=0) self.__plot_panels(axs, img, predicted_mask, 'red', ('Original Image','Predicted Mask','Predicted Contours')) else: fig, axs = plt.subplots(2,3,figsize=(15,10)) plt.subplots_adjust(wspace=0.02, hspace=0.1) # row 0: predicted self.__plot_panels(axs[0], img, predicted_mask, 'red', ('Original Image','Predicted Mask','Predicted Contours')) # row 1: true self.__plot_panels(axs[1], img, true_mask, 'blue', ('Original Image','True Mask','True Contours')) fig.savefig(file_path, bbox_inches='tight', dpi=300) plt.close(fig) def __plot_panels( self, axes, img: np.ndarray, mask: np.ndarray, contour_color: str, titles: Tuple[str, ...] ): """ Plot a row of three panels: original image, mask, and mask boundaries on image. Args: axes: list/array of three Axis objects. img: Image array (H, W or H, W, C). mask: Label mask (H, W). contour_color: Color for boundaries. titles: (title_img, title_mask, title_contours). """ # Panel 1: Original image ax0, ax1, ax2 = axes ax0.imshow(img, cmap='gray' if img.ndim == 2 else None) ax0.set_title(titles[0]); ax0.axis('off') # Compute boundaries once boundaries = find_boundaries(mask, mode='thick') # Panel 2: Mask with black boundaries cmap = plt.get_cmap("gist_ncar") cmap = mcolors.ListedColormap([ cmap(i/len(np.unique(mask))) for i in range(len(np.unique(mask))) ]) ax1.imshow(mask, cmap=cmap) ax1.contour(boundaries, colors='black', linewidths=0.5) ax1.set_title(titles[1]) ax1.axis('off') # Panel 3: Original image with black contour overlay ax2.imshow(img) # Draw boundaries as contour lines ax2.contour(boundaries, colors=contour_color, linewidths=0.5) ax2.set_title(titles[2]) ax2.axis('off') def __compute_flows_from_masks( self, true_masks: Tensor ) -> np.ndarray: """ Convert segmentation masks to flow fields for training. Args: true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks. Returns: numpy array of concatenated [flow_vectors, renumbered_true_masks] per image. renumbered_true_masks is labels, flow_vectors[0] is Y flow, flow_vectors[1] is X flow. """ # Move to CPU numpy _true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16) batch_size = _true_masks.shape[0] # Ensure each label has a channel dimension if _true_masks.ndim == 3: # shape (batch, H, W) -> (batch, 1, H, W) _true_masks = _true_masks[:, np.newaxis, :, :] batch_size, *_ = _true_masks.shape # Renumber labels to ensure uniqueness renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0] for i in range(batch_size)]) # Compute vector flows per image flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i]) for i in range(batch_size)]) return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32) def __compute_flow_from_mask( self, mask: np.ndarray ) -> np.ndarray: """ Compute normalized flow vectors from a labeled mask. Args: mask: 3D array of instance-labeled mask of shape (C, H, W). Returns: flow: Array of shape (2 * C, H, W). """ if mask.max() == 0 or np.count_nonzero(mask) <= 1: # No flow to compute logger.warning("Empty mask!") C, H, W = mask.shape return np.zeros((2*C, H, W), dtype=np.float32) # Delegate to GPU or CPU routine if self._device.type == "cuda" or self._device.type == "mps": return self.__mask_to_flow_gpu(mask) else: return self.__mask_to_flow_cpu(mask) def __mask_to_flow_gpu(self, mask: np.ndarray) -> np.ndarray: """Convert masks to flows using diffusion from center pixel. Center of masks where diffusion starts is defined by pixel closest to median within the mask. Args: masks (3D array): Labelled masks of shape (C, H, W). Returns: np.ndarray: A 3D array where for each channel the flows for each pixel are represented along the X and Y axes. """ channels, height, width = mask.shape flows = np.zeros((2*channels, height, width), np.float32) for channel in range(channels): padded_height, padded_width = height + 2, width + 2 # Pad the mask with a 1-pixel border masks_padded = torch.from_numpy(mask.astype(np.int64)).to(self._device) masks_padded = F.pad(masks_padded, (1, 1, 1, 1)) # Get coordinates of all non-zero pixels in the padded mask y, x = torch.nonzero(masks_padded, as_tuple=True) y = y.int(); x = x.int() # ensure integer type # Generate 8-connected neighbors (including center) via broadcasted offsets offsets = torch.tensor([ [ 0, 0], # center [-1, 0], # up [ 1, 0], # down [ 0, -1], # left [ 0, 1], # right [-1, -1], # up-left [-1, 1], # up-right [ 1, -1], # down-left [ 1, 1], # down-right ], dtype=torch.int32, device=self._device) # (9, 2) # coords: (N, 2) coords = torch.stack((y, x), dim=1) # neighbors: (9, N, 2) neighbors = offsets[:, None, :] + coords[None, :, :] # transpose into (2, 9, N) for the GPU kernel neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index # Build connectivity mask: True where neighbor label == center label center_labels = masks_padded[y, x][None, :] # (1, N) neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N) is_neighbor = neighbor_labels == center_labels # (9, N) # Compute object slices and pack into array for get_centers slices = find_objects(mask) slices_arr = np.array([ [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop] for i, sl in enumerate(slices) if sl is not None ], dtype=int) # Compute centers (pixel indices) and extents via the provided helper centers, ext = self.__get_mask_centers_and_extents(mask, slices_arr) # Move centers to GPU and shift by +1 for padding meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding # Determine number of diffusion iterations n_iter = 2 * ext.max() # Run the GPU diffusion kernel mu = self.__propagate_centers_gpu( neighbor_indices=neighbors, center_indices=meds_p.T, valid_neighbor_mask=is_neighbor, output_shape=(padded_height, padded_width), num_iterations=n_iter ) # Cast to float64 and normalize flow vectors mu = mu.astype(np.float64) mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60 # Remove the padding and write into final output flow_output = np.zeros((2, height, width), dtype=np.float32) ys_np = y.cpu().numpy() - 1 xs_np = x.cpu().numpy() - 1 flow_output[:, ys_np, xs_np] = mu flows[2*channel: 2*channel + 2] = flow_output return flows @staticmethod @njit(nogil=True) def __get_mask_centers_and_extents( label_map: np.ndarray, slices_arr: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """ Compute the centroids and extents of labeled regions in a 2D mask array. Args: label_map (np.ndarray): 2D array where each connected region has a unique integer label (1…K). slices_arr (np.ndarray): Array of shape (K, 5), where each row is (label_id, row_start, row_stop, col_start, col_stop). Returns: centers (np.ndarray): Integer array of shape (K, 2) with (row, col) center for each label. extents (np.ndarray): Integer array of shape (K,) giving the sum of height and width + 2 for each region. """ num_regions = slices_arr.shape[0] centers = np.zeros((num_regions, 2), dtype=np.int32) extents = np.zeros(num_regions, dtype=np.int32) for idx in prange(num_regions): # Unpack slice info label_id = slices_arr[idx, 0] row_start = slices_arr[idx, 1] row_stop = slices_arr[idx, 2] col_start = slices_arr[idx, 3] col_stop = slices_arr[idx, 4] # Extract binary submask for this label submask = (label_map[row_start:row_stop, col_start:col_stop] == label_id) # Get local coordinates of all pixels in the region ys, xs = np.nonzero(submask) # Compute the floating-point centroid within the submask y_mean = ys.mean() x_mean = xs.mean() # Find the pixel closest to the centroid by minimizing squared distance dist_sq = (ys - y_mean) ** 2 + (xs - x_mean) ** 2 closest_idx = dist_sq.argmin() # Convert to global coordinates center_row = ys[closest_idx] + row_start center_col = xs[closest_idx] + col_start centers[idx, 0] = center_row centers[idx, 1] = center_col # Compute extent as height + width + 2 (to include one-pixel border) height = row_stop - row_start width = col_stop - col_start extents[idx] = height + width + 2 return centers, extents def __propagate_centers_gpu( self, neighbor_indices: torch.Tensor, center_indices: torch.Tensor, valid_neighbor_mask: torch.Tensor, output_shape: Tuple[int, int], num_iterations: int = 200 ) -> np.ndarray: """ Propagates center points across a mask using GPU-based diffusion. Args: neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel. center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers. valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid. output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W). num_iterations (int, optional): Number of diffusion iterations. Defaults to 200. Returns: np.ndarray: Array of shape (2, N) with the computed flows. """ # Determine total number of elements and choose dtype accordingly total_elems = torch.prod(torch.tensor(output_shape)) if total_elems > 4e7 or self._device.type == "mps": diffusion_tensor = torch.zeros(output_shape, dtype=torch.float, device=self._device) else: diffusion_tensor = torch.zeros(output_shape, dtype=torch.double, device=self._device) # Unpack center row and column indices center_rows, center_cols = center_indices # Unpack neighbor row and column indices for 9 neighbors per pixel # Order: [0: center, 1: up, 2: down, 3: left, 4: right, # 5: up-left, 6: up-right, 7: down-left, 8: down-right] neigh_rows, neigh_cols = neighbor_indices # each of shape (9, N) # Perform diffusion iterations for _ in range(num_iterations): # Add source at each mask center diffusion_tensor[center_rows, center_cols] += 1 # Sample neighbor values for each pixel neighbor_vals = diffusion_tensor[neigh_rows, neigh_cols] # shape (9, N) # Zero out invalid neighbors neighbor_vals *= valid_neighbor_mask # Update the first neighbor (index 0) with the average of valid neighbor values diffusion_tensor[neigh_rows[0], neigh_cols[0]] = neighbor_vals.mean(dim=0) # Compute spatial gradients for 2D flow: dy and dx # Using neighbor indices: up = 1, down = 2, left = 3, right = 4 grad_samples = diffusion_tensor[ neigh_rows[[2, 1, 4, 3]], # indices [down, up, right, left] neigh_cols[[2, 1, 4, 3]] ] # shape (4, N) dy = grad_samples[0] - grad_samples[1] dx = grad_samples[2] - grad_samples[3] # Stack and convert to numpy flow field with shape (2, N) flow_field = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=0) return flow_field def __mask_to_flow_cpu(self, mask: np.ndarray) -> np.ndarray: """ Convert labeled masks to flow vectors by simulating diffusion from mask centers. Each mask's center is chosen as the pixel closest to its geometric centroid. A diffusion process is run on a padded local patch, and flows are derived as gradients (dy, dx) of the resulting density map. Args: masks (np.ndarray): 3D integer array of labels `(C x H x W)`, where 0 = background and positive integers = mask IDs. Returns: flow_field (np.ndarray): Array of shape `(2*C, H, W)` containing flow components [dy, dx] normalized per pixel. """ channels, height, width = mask.shape flows = np.zeros((2*channels, height, width), np.float32) for channel in range(channels): # Initialize flow_field with two channels: dy and dx flow_field = np.zeros((2, height, width), dtype=np.float64) # Find bounding box for each labeled mask mask_slices = find_objects(mask) # centers: List[Tuple[int, int]] = [] # Iterate over mask labels in parallel for label_idx in prange(len(mask_slices)): slc = mask_slices[label_idx] if slc is None: continue # Extract row and column slice for this mask row_slice, col_slice = slc # Add 1-pixel border around the patch patch_height = (row_slice.stop - row_slice.start) + 2 patch_width = (col_slice.stop - col_slice.start) + 2 # Get local coordinates of mask pixels within the patch local_rows, local_cols = np.nonzero( mask[row_slice, col_slice] == (label_idx + 1) ) # Shift coords by +1 for the border padding local_rows = local_rows.astype(np.int32) + 1 local_cols = local_cols.astype(np.int32) + 1 # Compute centroid and find nearest pixel as diffusion seed centroid_row = local_rows.mean() centroid_col = local_cols.mean() distances = (local_cols - centroid_col) ** 2 + (local_rows - centroid_row) ** 2 seed_index = distances.argmin() center_row = int(local_rows[seed_index]) center_col = int(local_cols[seed_index]) # Determine number of iterations total_iter = 2 * (patch_height + patch_width) # Initialize flat diffusion map for the local patch diffusion_map = np.zeros(patch_height * patch_width, dtype=np.float64) # Run diffusion from the seed center diffusion_map = self.__diffuse_from_center( diffusion_map, local_rows, local_cols, center_row, center_col, patch_width, total_iter ) # Compute flow as finite differences (gradient) on the diffusion map dy = ( diffusion_map[(local_rows + 1) * patch_width + local_cols] - diffusion_map[(local_rows - 1) * patch_width + local_cols] ) dx = ( diffusion_map[local_rows * patch_width + (local_cols + 1)] - diffusion_map[local_rows * patch_width + (local_cols - 1)] ) # Write flows back into the global flow_field array flow_field[0, row_slice.start + local_rows - 1, col_slice.start + local_cols - 1] = dy flow_field[1, row_slice.start + local_rows - 1, col_slice.start + local_cols - 1] = dx # Store center location in original image coordinates # centers.append( # (row_slice.start + center_row - 1, # col_slice.start + center_col - 1) # ) # Normalize each vector [dy,dx] by its magnitude magnitudes = np.sqrt((flow_field**2).sum(axis=0)) + 1e-60 flow_field /= magnitudes flows[2*channel: 2*channel + 2] = flow_field return flows @staticmethod @njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True) def __diffuse_from_center( diffusion_map: np.ndarray, row_coords: np.ndarray, col_coords: np.ndarray, center_row: int, center_col: int, patch_width: int, num_iterations: int ) -> np.ndarray: """ Perform diffusion of particles from a seed pixel across a local mask patch. At each iteration, one particle is added at the seed, and each mask pixel's value is updated to the average of itself and its 8-connected neighbors. Args: diffusion_map (np.ndarray): Flat array of length patch_height * patch_width. row_coords (np.ndarray): 1D array of row indices for mask pixels (local coords). col_coords (np.ndarray): 1D array of column indices for mask pixels (local coords). center_row (int): Row index of the seed point in local patch coords. center_col (int): Column index of the seed point in local patch coords. patch_width (int): Width (number of columns) in the local patch. num_iterations (int): Number of diffusion iterations to perform. Returns: np.ndarray: Updated diffusion_map after performing diffusion. """ # Compute linear indices for each mask pixel and its neighbors base_idx = row_coords * patch_width + col_coords up = (row_coords - 1) * patch_width + col_coords down = (row_coords + 1) * patch_width + col_coords left = row_coords * patch_width + (col_coords - 1) right = row_coords * patch_width + (col_coords + 1) up_left = (row_coords - 1) * patch_width + (col_coords - 1) up_right = (row_coords - 1) * patch_width + (col_coords + 1) down_left = (row_coords + 1) * patch_width + (col_coords - 1) down_right = (row_coords + 1) * patch_width + (col_coords + 1) for _ in range(num_iterations): # Inject one particle at the seed location diffusion_map[center_row * patch_width + center_col] += 1.0 # Update each mask pixel as the average over itself and neighbors diffusion_map[base_idx] = ( diffusion_map[base_idx] + diffusion_map[up] + diffusion_map[down] + diffusion_map[left] + diffusion_map[right] + diffusion_map[up_left] + diffusion_map[up_right] + diffusion_map[down_left] + diffusion_map[down_right] ) * (1.0 / 9.0) return diffusion_map def __segment_instances( self, probability_map: np.ndarray, flow: np.ndarray, prob_threshold: float = 0.0, flow_threshold: float = 0.4, num_iters: int = 200, min_object_size: int = 15 ) -> np.ndarray: """ Generate instance segmentation masks from probability and flow fields. Args: probability_map: 3D array (channels, height, width) of cell probabilities. flow: 3D array (2*channels, height, width) of forward flow vectors. prob_threshold: threshold to binarize probability_map. (Default 0.0) flow_threshold: threshold for filtering bad flow masks. (Default 0.4) num_iters: number of iterations for flow-following. (Default 200) min_object_size: minimum area to keep small instances. (Default 15) Returns: 3D array of uint16 instance labels for each channel. """ # Create a binary mask of likely cell locations probability_mask = probability_map > prob_threshold # If no cells exceed the threshold, return an empty mask if not np.any(probability_mask): logger.warning("No cell pixels found.") return np.zeros_like(probability_map, dtype=np.uint16) # Prepare output array for instance labels labeled_instances = np.zeros_like(probability_map, dtype=np.uint16) # Process each channel independently for channel_index in range(probability_mask.shape[0]): # Extract flow vectors for this channel (two components per channel) channel_flow_vectors = flow[2 * channel_index : 2 * channel_index + 2] # Extract binary mask for this channel channel_mask = probability_mask[channel_index] nonzero_coords = np.stack(np.nonzero(channel_mask)) # Follow the flow vectors to generate coordinate mappings flow_coordinates = self.__follow_flows( flow_field=channel_flow_vectors * channel_mask / 5.0, initial_coords=nonzero_coords, num_iters=num_iters ) # If flow following fails, leave this channel empty if flow_coordinates is None: labeled_instances[channel_index] = np.zeros( probability_map.shape[1:], dtype=np.uint16 ) continue if not torch.is_tensor(flow_coordinates): flow_coordinates = torch.from_numpy( flow_coordinates).to(self._device, dtype=torch.int32) else: flow_coordinates = flow_coordinates.int() # Obtain preliminary instance masks by clustering the coordinates channel_instances_mask = self.__get_mask( pixel_positions=flow_coordinates, valid_indices=nonzero_coords, original_shape=probability_map.shape[1:] ) # Filter out bad flow-derived instances if requested if channel_instances_mask.max() > 0 and flow_threshold > 0: channel_instances_mask = self.__remove_inconsistent_flow_masks( mask=channel_instances_mask, flow_network=channel_flow_vectors, error_threshold=flow_threshold ) # Remove small objects or holes below the minimum size if min_object_size > 0: # channel_instances_mask = morphology.remove_small_holes( # channel_instances_mask, area_threshold=min_object_size # ) # channel_instances_mask = morphology.remove_small_objects( # channel_instances_mask, min_size=min_object_size # ) channel_instances_mask = self.__fill_holes_and_prune_small_masks( channel_instances_mask, minimum_size=min_object_size ) labeled_instances[channel_index] = channel_instances_mask else: # No valid instances found, leave the channel empty labeled_instances[channel_index] = np.zeros( probability_map.shape[1:], dtype=np.uint16 ) return labeled_instances def __follow_flows( self, flow_field: np.ndarray, initial_coords: np.ndarray, num_iters: int = 200 ) -> Union[np.ndarray, torch.Tensor]: """ Trace pixel positions through a flow field via iterative interpolation. Args: flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors. initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions. num_iters (int): Number of integration steps. Returns: np.ndarray or torch.Tensor: Final (y, x) positions of each point. """ dims = 2 # Extract spatial dimensions height, width = flow_field.shape[1:] # Choose GPU/MPS path if available if self._device.type in ("cuda", "mps"): # Prepare point tensor: shape [1, 1, num_points, 2] pts = torch.zeros((1, 1, initial_coords.shape[1], dims), dtype=torch.float32, device=self._device) # Prepare flow volume: shape [1, 2, height, width] flow_vol = torch.zeros((1, dims, height, width), dtype=torch.float32, device=self._device) # Load initial positions and flow into tensors (flip order for grid_sample) # dim 0 = x # dim 1 = y for i in range(dims): pts[0, 0, :, dims - i - 1] = ( torch.from_numpy(initial_coords[i]) .to(self._device, torch.float32) ) flow_vol[0, dims - i - 1] = ( torch.from_numpy(flow_field[i]) .to(self._device, torch.float32) ) # Prepare normalization factors for x and y (max index) max_indices = torch.tensor([width - 1, height - 1], dtype=torch.float32, device=self._device) # Reshape for broadcasting to point tensor dims max_idx_pt = max_indices.view(1, 1, 1, dims) # Reshape for broadcasting to flow volume dims max_idx_flow = max_indices.view(1, dims, 1, 1) # Normalize flow values to [-1, 1] range flow_vol = (flow_vol * 2) / max_idx_flow # Normalize points to [-1, 1] pts = (pts / max_idx_pt) * 2 - 1 # Iterate: sample flow and update points for _ in range(num_iters): sampled = torch.nn.functional.grid_sample( flow_vol, pts, align_corners=False ) # Update each coordinate and clamp to valid range for i in range(dims): pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0) # Denormalize back to original pixel coordinates pts = (pts + 1) * 0.5 * max_idx_pt # Swap channels back to (y, x) and flatten final_pts = pts[..., [1, 0]].squeeze() # Convert from (num_points, 2) to (2, num_points) return final_pts.T if final_pts.ndim > 1 else final_pts.unsqueeze(0).T # CPU fallback using numpy and scipy current_pos = initial_coords.copy().astype(np.float32) temp_delta = np.zeros_like(current_pos, dtype=np.float32) for _ in range(num_iters): # Interpolate flow at current positions self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta) # Update positions and clamp to image bounds current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1) current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1) return current_pos @staticmethod @njit([ "(int16[:,:,:], float32[:], float32[:], float32[:,:])", "(float32[:,:,:], float32[:], float32[:], float32[:,:])" ], cache=True) def __map_coordinates( image_data: np.ndarray, y_coords: np.ndarray, x_coords: np.ndarray, output: np.ndarray ) -> None: """ Perform in-place bilinear interpolation on an image volume. Args: image_data (np.ndarray): Input volume with shape (C, H, W). y_coords (np.ndarray): Array of new y positions (num_points). x_coords (np.ndarray): Array of new x positions (num_points). output (np.ndarray): Output array of shape (C, num_points) to fill. Returns: None. Results written directly into `output`. """ channels, height, width = image_data.shape # Compute integer (floor) and fractional parts for coords y_floor = y_coords.astype(np.int32) x_floor = x_coords.astype(np.int32) y_frac = y_coords - y_floor x_frac = x_coords - x_floor # Loop over each sample point for idx in range(y_floor.shape[0]): # Clamp base indices to valid range y0 = min(max(y_floor[idx], 0), height - 1) x0 = min(max(x_floor[idx], 0), width - 1) y1 = min(y0 + 1, height - 1) x1 = min(x0 + 1, width - 1) wy = y_frac[idx] wx = x_frac[idx] # Interpolate per channel for c in range(channels): v00 = np.float32(image_data[c, y0, x0]) v10 = np.float32(image_data[c, y0, x1]) v01 = np.float32(image_data[c, y1, x0]) v11 = np.float32(image_data[c, y1, x1]) # Bilinear interpolation formula output[c, idx] = ( v00 * (1 - wy) * (1 - wx) + v10 * (1 - wy) * wx + v01 * wy * (1 - wx) + v11 * wy * wx ) def __get_mask( self, pixel_positions: torch.Tensor, valid_indices: np.ndarray, original_shape: Tuple[int, ...], pad_radius: int = 20, max_size_fraction: float = 0.4 ) -> np.ndarray: """ Generate labeled masks by clustering pixel trajectories via histogram peaks and region growing. This function executes the following steps: 1. Pads and clamps pixel final positions to avoid border effects. 2. Builds a dense histogram of pixel counts over spatial bins. 3. Identifies local maxima in the histogram as seed points. 4. Extracts local patches around each seed and grows regions by iteratively adding neighbors that exceed an intensity threshold. 5. Maps grown patches back to original image indices. 6. Removes any masks that exceed a maximum size fraction of the image. Args: pixel_positions (torch.Tensor): Tensor of shape `[2, N_pixels]`, dtype=int, containing final pixel coordinates after dynamics for each dimension. valid_indices (np.ndarray): Integer array of shape `[2, N_pixels]` giving indices of pixels in the original image grid. original_shape (tuple of ints): Spatial dimensions of the original image, e.g. (H, W). pad_radius (int): Number of zero-padding pixels added on each side of the histogram. Defaults to 20. max_size_fraction (float): If any mask has a pixel count > max_size_fraction * total_pixels, it will be removed. Defaults to 0.4. Returns: np.ndarray: Integer mask array of shape `original_shape` with labels 0 (background) and 1..M. Raises: ValueError: If input dimensions are inconsistent or pixel_positions shape is invalid. """ # Validate inputs ndim = len(original_shape) if pixel_positions.ndim != 2 or pixel_positions.size(0) != ndim: msg = f"pixel_positions must be shape [{ndim}, N], got {tuple(pixel_positions.shape)}" logger.error(msg) raise ValueError(msg) if pad_radius < 0: msg = f"pad_radius must be non-negative, got {pad_radius}" logger.error(msg) raise ValueError(msg) # Step 1: Pad and clamp pixel positions padded_positions = pixel_positions.clone().to(torch.int64) + pad_radius for dim in range(ndim): max_val = original_shape[dim] + pad_radius - 1 padded_positions[dim] = torch.clamp(padded_positions[dim], min=0, max=max_val) # Build histogram dimensions hist_shape = tuple(s + 2 * pad_radius for s in original_shape) # Step 2: Create sparse tensor and densify to get per-pixel counts try: counts_sparse = torch.sparse_coo_tensor( padded_positions, torch.ones(padded_positions.shape[1], dtype=torch.int32, device=pixel_positions.device), size=hist_shape ) histogram = counts_sparse.to_dense() except Exception as e: logger.error("Failed to build dense histogram: %s", e) raise # Step 3: Find peaks via 5x5 max-pooling k = 5 pooled = F.max_pool2d( histogram.unsqueeze(0), kernel_size=k, stride=1, padding=k // 2 ).squeeze() # Seeds are positions where histogram equals local max and count > threshold seed_positions = torch.nonzero((histogram - pooled == 0) & (histogram > 10)) if seed_positions.numel() == 0: logger.warning("No seeds found: returning empty mask") return np.zeros(original_shape, dtype=np.uint16) # Sort seeds by ascending count to process small peaks first seed_counts = histogram[tuple(seed_positions.T)] order = torch.argsort(seed_counts) seed_positions = seed_positions[order] del pooled, counts_sparse # Step 4: Extract local patches and perform region growing num_seeds = seed_positions.shape[0] # Tensor to hold local patches patches = torch.zeros((num_seeds, 11, 11), device=pixel_positions.device) for idx in range(num_seeds): coords = seed_positions[idx] slices = tuple(slice(c - 5, c + 6) for c in coords) patches[idx] = histogram[slices] del histogram # Initialize seed mask (center pixel of each patch) seed_masks = torch.zeros_like(patches, device=pixel_positions.device) seed_masks[:, 5, 5] = 1 # Iterative dilation and thresholding for _ in range(5): seed_masks = F.max_pool2d( seed_masks, kernel_size=3, stride=1, padding=1 ) seed_masks = seed_masks & (patches > 2) # Compute final mask coordinates final_coords = [] for idx in range(num_seeds): coords_local = torch.nonzero(seed_masks[idx]) # Shift back to global positions coords_global = coords_local + seed_positions[idx] - 5 final_coords.append(tuple(coords_global.T)) # Step 5: Paint masks into padded volume dtype = torch.int32 if num_seeds < 2**16 else torch.int64 mask_padded = torch.zeros(hist_shape, dtype=dtype, device=pixel_positions.device) for label_idx, coords in enumerate(final_coords, start=1): mask_padded[coords] = label_idx # Extract only the padded positions that correspond to original pixels mask_values = mask_padded[tuple(padded_positions)] mask_values = mask_values.cpu().numpy() # Step 6: Map to original image and remove oversized masks mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32) mask_final[valid_indices] = mask_values # Prune masks that are too large labels, counts = fastremap.unique(mask_final, return_counts=True) total_pixels = np.prod(original_shape) oversized = labels[counts > (total_pixels * max_size_fraction)] if oversized.size > 0: mask_final = fastremap.mask(mask_final, oversized) fastremap.renumber(mask_final, in_place=True) return mask_final def __remove_inconsistent_flow_masks( self, mask: np.ndarray, flow_network: np.ndarray, error_threshold: float = 0.4 ) -> np.ndarray: """ Remove labeled masks that have inconsistent optical flows compared to network-predicted flows. This performs a quality control step by computing flows from the provided masks and comparing them to the flows predicted by the network. Masks with a mean squared flow error above `error_threshold` are discarded (set to 0). Args: mask (np.ndarray): Integer mask array with shape [H, W]. Values: 0 = no mask; 1,2,... = mask labels. flow_network (np.ndarray): Float array of network-predicted flows with shape [2, H, W]. error_threshold (float): Maximum allowed mean squared flow error per mask label. Defaults to 0.4. Returns: np.ndarray: The input mask with inconsistent masks removed (labels set to 0). Raises: MemoryError: If the mask size exceeds available GPU memory. """ # If mask is very large and running on CUDA, check memory num_pixels = mask.size if ( num_pixels > 10000 * 10000 and self._device.type == 'cuda' ): # Clear unused GPU cache torch.cuda.empty_cache() # Determine PyTorch version major, minor = map(int, torch.__version__.split('.')[:2]) # Determine current CUDA device index device_index = ( self._device.index if hasattr(self._device, 'index') else torch.cuda.current_device() ) # Get free and total memory if major == 1 and minor < 10: total_mem = torch.cuda.get_device_properties(device_index).total_memory used_mem = torch.cuda.memory_allocated(device_index) free_mem = total_mem - used_mem else: free_mem, total_mem = torch.cuda.mem_get_info(device_index) # Estimate required memory for mask-based flow computation # Assume float32 per pixel required_bytes = num_pixels * np.dtype(np.float32).itemsize if required_bytes > free_mem: logger.error( 'Image too large for GPU memory in flow QC step (required: %d B, available: %d B)', required_bytes, free_mem ) raise MemoryError('Insufficient GPU memory for flow QC computation') # Compute flow errors per mask label flow_errors, _ = self.__compute_flow_error(mask, flow_network) # Identify labels with error above threshold bad_labels = np.nonzero(flow_errors > error_threshold)[0] + 1 # Remove bad masks by setting their label to 0 mask[np.isin(mask, bad_labels)] = 0 return mask def __compute_flow_error( self, mask: np.ndarray, flow_network: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """ Compute mean squared error between network-predicted flows and flows derived from masks. Args: mask (np.ndarray): Integer masks, shape must match flow_network spatial dims. flow_network (np.ndarray): Network predicted flows of shape [axis, ...]. Returns: Tuple[np.ndarray, np.ndarray]: - flow_errors: 1D array (length = max label) of mean squared error per label. - computed_flows: Array of flows derived from the mask, same shape as flow_network. Raises: ValueError: If the spatial dimensions of `mask_array` and `flow_network` do not match. """ # Ensure mask and flow shapes match if flow_network.shape[1:] != mask.shape: logger.error( 'Shape mismatch: network flow shape %s vs mask shape %s', flow_network.shape[1:], mask.shape ) raise ValueError('Network flow and mask shapes must match') # Compute flows from mask labels (user-provided function) computed_flows = self.__compute_flow_from_mask(mask[None, ...]) # Prepare array for errors (one value per mask label) num_labels = int(mask.max()) flow_errors = np.zeros(num_labels, dtype=float) # Accumulate mean squared error over each flow axis for axis_index in range(computed_flows.shape[0]): # MSE per label: mean((computed - predicted/5)^2) flow_errors += mean( (computed_flows[axis_index] - flow_network[axis_index] / 5.0) ** 2, mask, index=np.arange(1, num_labels + 1) ) return flow_errors, computed_flows def __fill_holes_and_prune_small_masks( self, masks: np.ndarray, minimum_size: int = 15 ) -> np.ndarray: """ Fill holes in labeled masks and remove masks smaller than a given size. This function performs two steps: 1. Fills internal holes in each labeled mask using `fill_voids.fill`. 2. Discards any mask whose pixel count is below `minimum_size`. Args: masks (np.ndarray): Integer mask array of dimension 2 or 3 (shape [H, W] or [D, H, W]). Values: 0 = background; 1,2,... = mask labels. minimum_size (int): Minimum number of pixels required to keep a mask. Masks smaller than this will be removed. Set to -1 to skip size-based pruning. Defaults to 15. Returns: np.ndarray: Processed mask array with holes filled and small masks removed. Raises: ValueError: If `masks` is not a 2D or 3D integer array. """ # Validate input dimensions if masks.ndim not in (2, 3): msg = f"Expected 2D or 3D mask array, got {masks.ndim}D." logger.error(msg) raise ValueError(msg) # Optionally remove masks smaller than minimum_size if minimum_size >= 0: # Compute label counts (skipping background at index 0) labels, counts = fastremap.unique(masks, return_counts=True) # Identify labels to remove: those with count < minimum_size small_labels = labels[counts < minimum_size] if small_labels.size > 0: masks = fastremap.mask(masks, small_labels) fastremap.renumber(masks, in_place=True) # Find bounding boxes for each mask label object_slices = find_objects(masks) new_label = 1 output_masks = np.zeros_like(masks, dtype=masks.dtype) # Loop over each original slice, fill holes, and assign new labels for original_label, slc in enumerate(object_slices, start=1): if slc is None: continue # Extract sub-volume or sub-image region = masks[slc] == original_label if not np.any(region): continue # Fill internal holes filled_region = fill_voids.fill(region) # Write back into output mask with sequential labels output_masks[slc][filled_region] = new_label new_label += 1 # Final pruning of small masks after filling (optional) if minimum_size >= 0: labels, counts = fastremap.unique(output_masks, return_counts=True) small_labels = labels[counts < minimum_size] if small_labels.size > 0: output_masks = fastremap.mask(output_masks, small_labels) fastremap.renumber(output_masks, in_place=True) return output_masks