""" This code is adapted from the following codes: [1] https://github.com/MouseLand/cellpose/blob/main/cellpose/utils.py [2] https://github.com/MouseLand/cellpose/blob/main/cellpose/dynamics.py [3] https://github.com/MouseLand/cellpose/blob/main/cellpose/metrics.py [4] https://github.com/Lee-Gihun/MEDIAR/tree/main/core [5] https://github.com/Lee-Gihun/MEDIAR/tree/main/core/MEDIAR """ import random import numpy as np from numba import njit, prange import torch from torch import Tensor import torch.nn.functional as F from torch.utils.data import DataLoader import fastremap import fill_voids from skimage import morphology from skimage.segmentation import find_boundaries from scipy.special import expit from scipy.ndimage import mean, find_objects from monai.data.dataset import Dataset from monai.transforms import * # type: ignore from monai.inferers.utils import sliding_window_inference from monai.metrics.cumulative_average import CumulativeAverage import matplotlib.pyplot as plt import matplotlib.colors as mcolors import os import glob import csv import copy import time import tifffile as tiff from pprint import pformat from tabulate import tabulate from typing import Any, Dict, Literal, Optional, Tuple, List, Union from tqdm import tqdm import wandb from config import Config from core.models import * from core.losses import * from core.optimizers import * from core.schedulers import * from core.utils import ( compute_batch_segmentation_tp_fp_fn, compute_f1_score, compute_average_precision_score ) from core.logger import get_logger logger = get_logger() class CellSegmentator: def __init__(self, config: Config) -> None: self._device: torch.device = torch.device(config.dataset_config.common.device or "cpu") self.__set_seed(config.dataset_config.common.seed) self.__parse_config(config) self._scaler = ( torch.amp.GradScaler(self._device.type) # type: ignore if self._dataset_setup.is_training and self._dataset_setup.common.use_amp else None ) self._train_dataloader: Optional[DataLoader] = None self._valid_dataloader: Optional[DataLoader] = None self._test_dataloader: Optional[DataLoader] = None self._predict_dataloader: Optional[DataLoader] = None self._best_weights = None def create_dataloaders( self, train_transforms: Optional[Compose] = None, valid_transforms: Optional[Compose] = None, test_transforms: Optional[Compose] = None, predict_transforms: Optional[Compose] = None ) -> None: """ Creates train, validation, test, and prediction dataloaders based on dataset configuration and provided transforms. Args: train_transforms (Optional[Compose]): Transformations for training data. valid_transforms (Optional[Compose]): Transformations for validation data. test_transforms (Optional[Compose]): Transformations for testing data. predict_transforms (Optional[Compose]): Transformations for prediction data. Raises: ValueError: If required transforms are missing. RuntimeError: If critical dataset config values are missing. """ if self._dataset_setup.is_training and train_transforms is None: raise ValueError("Training mode requires 'train_transforms' to be provided.") elif not self._dataset_setup.is_training and test_transforms is None and predict_transforms is None: raise ValueError("In inference mode, at least one of 'test_transforms' or 'predict_transforms' must be provided.") if self._dataset_setup.is_training: # Training mode: handle either pre-split datasets or splitting on the fly if self._dataset_setup.training.is_split: # Validate presence of validation transforms if validation directory and size are set if ( self._dataset_setup.training.pre_split.valid_dir and self._dataset_setup.training.valid_size and valid_transforms is None ): raise ValueError("Validation transforms must be provided when using pre-split validation data.") # Use explicitly split directories train_dir = self._dataset_setup.training.pre_split.train_dir valid_dir = self._dataset_setup.training.pre_split.valid_dir test_dir = self._dataset_setup.training.pre_split.test_dir train_offset = self._dataset_setup.training.train_offset valid_offset = self._dataset_setup.training.valid_offset test_offset = self._dataset_setup.training.test_offset shuffle = False else: # Same validation for split mode with full data directory if ( self._dataset_setup.training.split.all_data_dir and self._dataset_setup.training.valid_size and valid_transforms is None ): raise ValueError("Validation transforms must be provided when splitting dataset.") # Automatically split dataset from one directory train_dir = valid_dir = test_dir = self._dataset_setup.training.split.all_data_dir number_of_images = len(os.listdir(os.path.join(train_dir, 'images'))) if number_of_images == 0: raise FileNotFoundError(f"No images found in '{train_dir}/images'") # Calculate train/valid sizes train_size = ( self._dataset_setup.training.train_size if isinstance(self._dataset_setup.training.train_size, int) else int(number_of_images * self._dataset_setup.training.train_size) ) valid_size = ( self._dataset_setup.training.valid_size if isinstance(self._dataset_setup.training.valid_size, int) else int(number_of_images * self._dataset_setup.training.valid_size) ) train_offset = self._dataset_setup.training.train_offset valid_offset = self._dataset_setup.training.valid_offset + train_size test_offset = self._dataset_setup.training.test_offset + train_size + valid_size shuffle = True # Train dataloader train_dataset = self.__get_dataset( images_dir=os.path.join(train_dir, 'images'), masks_dir=os.path.join(train_dir, 'masks', self._dataset_setup.common.masks_subdir), transforms=train_transforms, # type: ignore size=self._dataset_setup.training.train_size, offset=train_offset, shuffle=shuffle ) self._train_dataloader = DataLoader(train_dataset, batch_size=self._dataset_setup.training.batch_size, shuffle=True) logger.info(f"Loaded training dataset with {len(train_dataset)} samples.") # Validation dataloader if valid_transforms is not None: if not valid_dir or not self._dataset_setup.training.valid_size: raise RuntimeError("Validation directory or size is not properly configured.") valid_dataset = self.__get_dataset( images_dir=os.path.join(valid_dir, 'images'), masks_dir=os.path.join(valid_dir, 'masks', self._dataset_setup.common.masks_subdir), transforms=valid_transforms, size=self._dataset_setup.training.valid_size, offset=valid_offset, shuffle=shuffle ) self._valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded validation dataset with {len(valid_dataset)} samples.") # Test dataloader if test_transforms is not None: if not test_dir or not self._dataset_setup.training.test_size: raise RuntimeError("Test directory or size is not properly configured.") test_dataset = self.__get_dataset( images_dir=os.path.join(test_dir, 'images'), masks_dir=os.path.join(test_dir, 'masks', self._dataset_setup.common.masks_subdir), transforms=test_transforms, size=self._dataset_setup.training.test_size, offset=test_offset, shuffle=shuffle ) self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded test dataset with {len(test_dataset)} samples.") # Prediction dataloader if predict_transforms is not None: if not test_dir or not self._dataset_setup.training.test_size: raise RuntimeError("Prediction directory or size is not properly configured.") predict_dataset = self.__get_dataset( images_dir=os.path.join(test_dir, 'images'), masks_dir=None, transforms=predict_transforms, size=self._dataset_setup.training.test_size, offset=test_offset, shuffle=shuffle ) self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.") else: # Inference mode (no training) test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images') test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks', self._dataset_setup.common.masks_subdir) if test_transforms is not None: test_dataset = self.__get_dataset( images_dir=test_images, masks_dir=test_masks, transforms=test_transforms, size=self._dataset_setup.testing.test_size, offset=self._dataset_setup.testing.test_offset, shuffle=self._dataset_setup.testing.shuffle ) self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded test dataset with {len(test_dataset)} samples.") if predict_transforms is not None: predict_dataset = self.__get_dataset( images_dir=test_images, masks_dir=None, transforms=predict_transforms, size=self._dataset_setup.testing.test_size, offset=self._dataset_setup.testing.test_offset, shuffle=self._dataset_setup.testing.shuffle ) self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.") def print_data_info( self, loader_type: Literal["train", "valid", "test", "predict"], index: Optional[int] = None ) -> None: """ Prints statistics for a single sample from the specified dataloader. Args: loader_type: One of "train", "valid", "test", "predict". index: The sample index; if None, a random index is selected. """ # Retrieve the dataloader attribute, e.g., self._train_dataloader loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None) if loader is None: logger.error(f"Dataloader '{loader_type}' is not initialized.") return dataset = loader.dataset total = len(dataset) # type: ignore if total == 0: logger.error(f"Dataset for '{loader_type}' is empty.") return # Choose index idx = index if index is not None else random.randint(0, total - 1) if not (0 <= idx < total): logger.error(f"Index {idx} is out of range [0, {total}).") return # Fetch the sample and apply transforms sample = dataset[idx] # Expecting a dict with {'image': ..., 'mask': ...} or {'image': ...} img = sample["image"] # Convert tensor to numpy if needed if isinstance(img, torch.Tensor): img = img.cpu().numpy() img = np.asarray(img) # Compute image statistics img_min, img_max = img.min(), img.max() img_mean, img_std = float(img.mean()), float(img.std()) img_shape = img.shape # Prepare log lines lines = [ "=" * 40, f"Dataloader: '{loader_type}', sample index: {idx} / {total - 1}", f"Image — shape: {img_shape}, min: {img_min:.4f}, max: {img_max:.4f}, mean: {img_mean:.4f}, std: {img_std:.4f}" ] # For 'predict', no mask is available if loader_type != "predict": mask = sample.get("mask", None) if mask is not None: # Convert tensor to numpy if needed if isinstance(mask, torch.Tensor): mask = mask.cpu().numpy() mask = np.asarray(mask) m_min, m_max = mask.min(), mask.max() m_mean, m_std = float(mask.mean()), float(mask.std()) m_shape = mask.shape lines.append( f"Mask — shape: {m_shape}, min: {m_min:.4f}, " f"max: {m_max:.4f}, mean: {m_mean:.4f}, std: {m_std:.4f}" ) else: lines.append("Mask — not available for this sample.") lines.append("=" * 40) # Output via logger for l in lines: logger.info(l) def train(self, save_results: bool = True, only_masks: bool = False) -> None: """ Train the model over multiple epochs, including validation and test. Args: save_results (bool): If True, the predicted masks and test metrics will be saved. only_masks (bool): If True and save_results is True, only raw predicted masks are saved, without visualization overlays. """ # Ensure training is enabled in dataset setup if not self._dataset_setup.is_training: raise RuntimeError("Dataset setup indicates training is disabled.") # Determine device name for logging if self._device.type == "cpu": device_name = "cpu" else: idx = self._device.index if hasattr(self._device, 'index') else torch.cuda.current_device() device_name = torch.cuda.get_device_name(idx) logger.info(f"\n{'=' * 50}\n") logger.info(f"Training starts on device: {device_name}") logger.info(f"\n{'=' * 50}") best_f1_score = 0.0 for epoch in range(1, self._dataset_setup.training.num_epochs + 1): train_metrics = self.__run_epoch("train", epoch) self.__print_with_logging(train_metrics, epoch) # Step the scheduler after training if self._scheduler is not None: self._scheduler.step() # Periodic validation or tuning if epoch % self._dataset_setup.training.val_freq == 0: if self._valid_dataloader is not None: valid_metrics = self.__run_epoch("valid", epoch) self.__print_with_logging(valid_metrics, epoch) # Update best model on improved F1 f1 = valid_metrics.get("valid_f1_score", 0.0) if f1 > best_f1_score: best_f1_score = f1 # Deep copy weights to avoid reference issues self._best_weights = copy.deepcopy(self._model.state_dict()) logger.info(f"Updated best model weights with F1 score: {f1:.4f}") # Restore best model weights if available if self._best_weights is not None: self._model.load_state_dict(self._best_weights) if self._test_dataloader is not None: test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks) self.__print_with_logging(test_metrics, 0) save_path = self._dataset_setup.common.predictions_dir os.makedirs(save_path, exist_ok=True) self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv')) def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None: """ Run a full test epoch and display/log the resulting metrics. Args: save_results (bool): If True, the predicted masks and test metrics will be saved. only_masks (bool): If True and save_results is True, only raw predicted masks are saved, without visualization overlays. """ test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks) self.__print_with_logging(test_metrics, 0) save_path = self._dataset_setup.common.predictions_dir os.makedirs(save_path, exist_ok=True) self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv')) def predict(self, only_masks: bool = False) -> None: """ Run inference on the predict set and save the resulting instance masks. Args: only_masks (bool): If True, only raw predicted masks are saved, without visualization overlays. """ # Ensure the predict DataLoader has been set if self._predict_dataloader is None: raise RuntimeError("DataLoader for mode 'predict' is not set.") batch_counter = 0 for batch in tqdm(self._predict_dataloader, desc="Predicting"): # Move input images to the configured device (CPU/GPU) inputs = batch["image"].to(self._device) # Use automatic mixed precision if enabled in dataset setup with torch.amp.autocast( # type: ignore self._device.type, enabled=self._dataset_setup.common.use_amp ): # Disable gradient computation for inference with torch.no_grad(): # Run the model’s forward pass in ‘predict’ mode raw_output = self.__run_inference(inputs, mode="predict") # Convert logits/probabilities to discrete instance masks # ground_truth is not passed here; only predictions are needed preds, _ = self.__post_process_predictions(raw_output) # Save out the predicted masks, using batch_counter to index files self.__save_prediction_masks( sample=batch, predicted_mask=preds, start_index=batch_counter, only_masks=only_masks ) # Increment counter by batch size for unique file naming batch_counter += inputs.shape[0] def run(self, save_results: bool = True, only_masks: bool = False) -> None: """ Orchestrate the full workflow and report execution time: - If training is enabled in the dataset setup, start training. - Otherwise, if a test DataLoader is provided, run evaluation. - Else if a prediction DataLoader is provided, run inference/prediction. - If neither loader is available in non‐training mode, raise an error. Args: save_results (bool): If True, the predicted masks and test metrics will be saved. only_masks (bool): If True and save_results is True, only raw predicted masks are saved, without visualization overlays. """ start_time = time.time() logger.info( f"Masks saving: {'enabled' if save_results else 'disabled'}; " f"Additional visualizations: " f"{'enabled' if save_results and not only_masks else 'disabled'}" ) # 1) TRAINING PATH if self._dataset_setup.is_training: # Launch the full training loop (with validation, scheduler steps, etc.) self.train(save_results=save_results, only_masks=only_masks) else: # 2) NON-TRAINING PATH (TEST or PREDICT) if self._test_dataloader is not None: # Run a single evaluation epoch on the test set and log metrics self.evaluate(save_results=save_results, only_masks=only_masks) elif self._predict_dataloader is not None: # Run inference on the predict set and save outputs self.predict(only_masks=only_masks) else: # 3) ERROR: no appropriate loader found raise RuntimeError( "Neither test nor predict DataLoader is set for non‐training mode." ) elapsed = time.time() - start_time if elapsed < 60: logger.info(f"Total execution time: {elapsed:.2f} seconds") elif elapsed < 3600: minutes = int(elapsed // 60) seconds = elapsed % 60 logger.info(f"Total execution time: {minutes} min {seconds:.2f} sec") else: hours = int(elapsed // 3600) minutes = int((elapsed % 3600) // 60) seconds = elapsed % 60 logger.info(f"Total execution time: {hours} h {minutes} min {seconds:.2f} sec") def load_from_checkpoint(self, checkpoint_path: str) -> None: """ Loads model weights from a specified checkpoint into the current model, but only for parameters whose shapes match. Parameters with mismatched shapes (e.g., classification heads with different output sizes) remain at their initialized values. Args: checkpoint_path (str): Path to the checkpoint file containing the model weights. """ # Load the checkpoint (state_dict) from file onto CPU checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) # Extract nested state_dict if present state_dict = checkpoint.get("state_dict", checkpoint) # Get the current model's parameter dictionary model_dict = self._model.state_dict() # Filter pretrained parameters to those matching in name and shape pretrained_dict = { k: v for k, v in state_dict.items() if k in model_dict and v.size() == model_dict[k].size() } # Log how many parameters are loaded, skipped, or missing skipped = [k for k in state_dict if k not in pretrained_dict] missing = [k for k in model_dict if k not in pretrained_dict] logger.info( f"Loaded {len(pretrained_dict)} parameters;" f" skipped {len(skipped)} params from checkpoint;" f" {len(missing)} params remain uninitialized in model." ) # Update the model's state_dict and load it model_dict.update(pretrained_dict) self._model.load_state_dict(model_dict) def save_checkpoint(self, checkpoint_path: str) -> None: """ Saves the current model weights to a checkpoint file. Args: checkpoint_path (str): Path where the checkpoint file will be saved. """ # Write the checkpoint to disk os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) torch.save(( self._model.state_dict() if self._best_weights is None else self._best_weights), checkpoint_path ) def __parse_config(self, config: Config) -> None: """ Parses the given configuration object to initialize model, criterion, optimizer, scheduler, and dataset setup. Args: config (Config): Configuration object with model, optimizer, scheduler, criterion, and dataset setup information. """ model = config.model criterion = config.criterion optimizer = config.optimizer scheduler = config.scheduler logger.info("========== Parsed Configuration ==========") logger.info("Model Config:\n%s", pformat(model.dump(), indent=2)) if criterion: logger.info("Criterion Config:\n%s", pformat(criterion.dump(), indent=2)) if optimizer: logger.info("Optimizer Config:\n%s", pformat(optimizer.dump(), indent=2)) if scheduler: logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2)) logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2)) logger.info("Wandb Config:\n%s", pformat(config.wandb_config.model_dump(), indent=2)) logger.info("==========================================") # Initialize model using the model registry self._model = ModelRegistry.get_model_class(model.name)(model.params) # Loads model weights from a specified checkpoint pretrained_weights = config.dataset_config.common.pretrained_weights if pretrained_weights: self.load_from_checkpoint(pretrained_weights) logger.info(f"Loaded pre-trained weights from: {pretrained_weights}") self._model = self._model.to(self._device) # Initialize loss criterion if specified self._criterion = ( CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params) if criterion is not None else None ) if hasattr(self._criterion, "num_classes"): nc_model = self._model.num_classes nc_crit = getattr(self._criterion, "num_classes") if nc_model != nc_crit: raise ValueError( f"Number of classes mismatch: model.num_classes={nc_model} " f"but criterion.num_classes={nc_crit}" ) # Initialize optimizer if specified self._optimizer = ( OptimizerRegistry.get_optimizer_class(optimizer.name)( model_params=self._model.parameters(), optim_params=optimizer.params ) if optimizer is not None else None ) # Initialize scheduler only if both scheduler and optimizer are defined self._scheduler = ( SchedulerRegistry.get_scheduler_class(scheduler.name)( optimizer=self._optimizer.optim, params=scheduler.params ) if scheduler is not None and self._optimizer is not None and self._optimizer.optim is not None else None ) logger.info("========== Model Components Initialization ==========") logger.info("├─ Model: " + (f"{model.name}" if self._model else "Not specified")) logger.info("├─ Criterion: " + (f"{criterion.name}" if self._criterion else "Not specified")) # type: ignore logger.info("├─ Optimizer: " + (f"{optimizer.name}" if self._optimizer else "Not specified")) # type: ignore logger.info("└─ Scheduler: " + (f"{scheduler.name}" if self._scheduler else "Not specified")) # type: ignore logger.info("=====================================================") # Save dataset config self._dataset_setup = config.dataset_config common = config.dataset_config.common logger.info("========== Dataset Setup ==========") logger.info("[COMMON]") logger.info(f"├─ Seed: {common.seed}") logger.info(f"├─ Device: {common.device}") logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}") logger.info(f"├─ Masks subdirectory: {common.masks_subdir}") logger.info(f"└─ Predictions output dir: {common.predictions_dir}") logger.info(f"├─ Pretrained weights: {common.pretrained_weights or 'None'}") if config.dataset_config.is_training: training = config.dataset_config.training logger.info("[MODE] Training") logger.info(f"├─ Batch size: {training.batch_size}") logger.info(f"├─ Epochs: {training.num_epochs}") logger.info(f"├─ Validation frequency: {training.val_freq}") if training.is_split: logger.info(f"├─ Using pre-split directories:") logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}") logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}") logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}") else: logger.info(f"├─ Using unified dataset with splits:") logger.info(f"│ ├─ All data dir: {training.split.all_data_dir}") logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}") logger.info(f"└─ Dataset split:") logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}") logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}") logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}") else: testing = config.dataset_config.testing logger.info("[MODE] Inference") logger.info(f"├─ Test dir: {testing.test_dir}") logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})") logger.info(f"├─ Shuffle: {'yes' if testing.shuffle else 'no'}") self._wandb_config = config.wandb_config if self._wandb_config.use_wandb: logger.info("[W&B]") logger.info(f"├─ Project: {self._wandb_config.project}") if self._wandb_config.group: logger.info(f"├─ Group: {self._wandb_config.group}") if self._wandb_config.entity: logger.info(f"├─ Entity: {self._wandb_config.entity}") if self._wandb_config.name: logger.info(f"├─ Run name: {self._wandb_config.name}") if self._wandb_config.tags: logger.info(f"├─ Tags: {', '.join(self._wandb_config.tags)}") if self._wandb_config.notes: logger.info(f"├─ Notes: {self._wandb_config.notes}") logger.info(f"└─ Save code: {'yes' if self._wandb_config.save_code else 'no'}") else: logger.info("[W&B] Logging disabled") logger.info("===================================") def __set_seed(self, seed: Optional[int]) -> None: """ Sets the random seed for reproducibility across Python, NumPy, and PyTorch. Args: seed (Optional[int]): Seed value. If None, no seeding is performed. """ if seed is not None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False logger.info(f"Random seed set to {seed}") else: logger.info("Seed not set (None provided)") def __get_dataset( self, images_dir: str, masks_dir: Optional[str], transforms: Compose, size: Union[int, float], offset: int, shuffle: bool ) -> Dataset: """ Loads and returns a dataset object from image and optional mask directories. Args: images_dir (str): Path to directory or glob pattern for input images. masks_dir (Optional[str]): Path to directory or glob pattern for masks. transforms (Compose): Transformations to apply to each image or pair. size (Union[int, float]): Either an integer or a fraction of the dataset. offset (int): Number of images to skip from the start. shuffle (bool): Whether to shuffle the dataset before slicing. Returns: Dataset: A dataset containing image and optional mask paths. Raises: FileNotFoundError: If no images are found. ValueError: If masks are provided but do not match image count. ValueError: If dataset is too small for requested size or offset. """ # Collect sorted list of image paths images = sorted( glob.glob(os.path.join(images_dir, '*.tif')) + glob.glob(os.path.join(images_dir, '*.tiff')) ) if not images: raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'") if masks_dir is not None: # Collect and validate sorted list of mask paths masks = sorted( glob.glob(os.path.join(masks_dir, '*.tif')) + glob.glob(os.path.join(masks_dir, '*.tiff')) ) if len(images) != len(masks): raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})") # Convert float size (fraction) to absolute count size = size if isinstance(size, int) else int(size * len(images)) if size <= 0: raise ValueError(f"Size must be positive, got: {size}") if len(images) < size: raise ValueError(f"Not enough images ({len(images)}) for requested size ({size})") if len(images) < size + offset: raise ValueError(f"Offset ({offset}) + size ({size}) exceeds dataset length ({len(images)})") # Shuffle image-mask pairs if requested if shuffle: if masks_dir is not None: combined = list(zip(images, masks)) # type: ignore random.shuffle(combined) images, masks = zip(*combined) else: random.shuffle(images) # Apply offset and limit by size images = images[offset: offset + size] if masks_dir is not None: masks = masks[offset: offset + size] # type: ignore # Prepare data structure for Dataset class if masks_dir is not None: data = [ {"image": image, "mask": mask} for image, mask in zip(images, masks) # type: ignore ] else: data = [{"image": image} for image in images] return Dataset(data, transforms) def __print_with_logging(self, metrics: Dict[str, Union[float, np.ndarray]], step: int) -> None: """ Print metrics in a tabular format and log to W&B. Args: metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names to either a float or a ND numpy array. step (int): epoch index. """ rows: list[tuple[str, str]] = [] for key, val in metrics.items(): if isinstance(val, np.ndarray): # Convert array to string, e.g. '[0.2, 0.8, 0.5]' val_str = np.array2string(val, separator=', ') else: # Format scalar with 4 decimal places val_str = f"{val:.4f}" rows.append((key, val_str)) table = tabulate( tabular_data=rows, headers=["Metric", "Value"], floatfmt=".4f", tablefmt="fancy_grid" ) print(table, "\n") if self._wandb_config.use_wandb: # Keep only scalar values scalar_results: dict[str, float] = {} for key, val in metrics.items(): if isinstance(val, np.ndarray): continue # Ensure float type scalar_results[key] = float(val) wandb.log(scalar_results, step=step) def __save_metrics_to_csv( self, metrics: Dict[str, Union[float, np.ndarray]], output_path: str ) -> None: """ Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'. Args: metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names to scalar values or numpy arrays. output_path (str): Path to the output CSV file. """ with open(output_path, mode='w', newline='') as csv_file: writer = csv.writer(csv_file) writer.writerow(['Metric', 'Value']) for name, value in metrics.items(): # Convert numpy arrays to string representation if isinstance(value, np.ndarray): # Flatten and join with commas flat = value.flatten() val_str = ','.join([f"{v}" for v in flat]) else: val_str = f"{value}" writer.writerow([name, val_str]) def __run_epoch(self, mode: Literal["train", "valid", "test"], epoch: Optional[int] = None, save_results: bool = True, only_masks: bool = False ) -> Dict[str, Union[float, np.ndarray]]: """ Execute one epoch of training, validation, or testing. Args: mode (str): One of 'train', 'valid', or 'test'. epoch (int, optional): Current epoch number for logging. save_results (bool): If True, the predicted masks and test metrics will be saved. only_masks (bool): If True and save_results is True, only raw predicted masks are saved, without visualization overlays. Returns: Dict[str, Union[float, np.ndarray]]: Metrics for valid/test. """ # Ensure required components are available if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None): raise RuntimeError("Optimizer and loss function must be initialized for train/valid.") # Set model mode and choose the appropriate data loader if mode == "train": self._model.train() loader = self._train_dataloader else: self._model.eval() loader = self._valid_dataloader if mode == "valid" else self._test_dataloader # Check that the data loader is provided if loader is None: raise RuntimeError(f"DataLoader for mode '{mode}' is not set.") all_tp, all_fp, all_fn = [], [], [] # Prepare tqdm description with epoch info if available if epoch is not None: desc = f"Epoch {epoch}/{self._dataset_setup.training.num_epochs} [{mode}]" else: desc = f"Epoch ({mode})" # Iterate over batches batch_counter = 0 for batch in tqdm(loader, desc=desc): inputs = batch["image"].to(self._device) targets = batch["mask"].to(self._device) # Zero gradients for training if self._optimizer is not None: self._optimizer.zero_grad() # Mixed precision context if enabled with torch.amp.autocast( # type: ignore self._device.type, enabled=self._dataset_setup.common.use_amp ): # Only compute gradients in training phase with torch.set_grad_enabled(mode == "train"): # Forward pass raw_output = self.__run_inference(inputs, mode) if self._criterion is not None: # Convert label masks to flow representations (one-hot) flow_targets = self.__compute_flows_from_masks(targets) # Compute loss for this batch batch_loss = self._criterion( raw_output, torch.from_numpy(flow_targets).to(device=raw_output.device) ) # Post-process and compute F1 during validation and testing if mode in ("valid", "test"): preds, labels_post = self.__post_process_predictions( raw_output, targets ) # Collecting statistics on the batch tp, fp, fn, tp_masks, fp_masks, fn_masks = self.__compute_stats( predicted_masks=preds, ground_truth_masks=labels_post, # type: ignore iou_threshold=0.5, return_error_masks=(mode == "test") and save_results is True and not only_masks ) all_tp.append(tp) all_fp.append(fp) all_fn.append(fn) if mode == "test" and save_results is True: masks = (tp_masks, fp_masks, fn_masks) if not only_masks else None self.__save_prediction_masks( sample=batch, predicted_mask=preds, start_index=batch_counter, only_masks=only_masks, masks=masks # type: ignore ) # Backpropagation and optimizer step in training if mode == "train": if self._dataset_setup.common.use_amp and self._scaler is not None: self._scaler.scale(batch_loss).backward() # type: ignore self._scaler.unscale_(self._optimizer.optim) # type: ignore self._scaler.step(self._optimizer.optim) # type: ignore self._scaler.update() else: batch_loss.backward() # type: ignore if self._optimizer is not None: self._optimizer.step() batch_counter += inputs.shape[0] if self._criterion is not None: # Collect loss metrics epoch_metrics: Dict[str, Union[float, np.ndarray]] = { f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items() } # Reset internal loss metrics accumulator self._criterion.reset_metrics() else: epoch_metrics = {} # Include F1 and mAP for validation and testing if mode in ("valid", "test"): # Concatenating by batch: shape (num_batches*B, C) tp_array = np.vstack(all_tp) fp_array = np.vstack(all_fp) fn_array = np.vstack(all_fn) epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( tp_array, fp_array, fn_array, reduction="micro" ) epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric( tp_array, fp_array, fn_array, reduction="imagewise" ) epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( tp_array, fp_array, fn_array, reduction="macro" ) return epoch_metrics def __run_inference( self, inputs: torch.Tensor, mode: Literal["train", "valid", "test", "predict"] = "train" ) -> torch.Tensor: """ Perform model inference for different stages. Args: inputs (torch.Tensor): Input tensor of shape (B, C, H, W). stage (Literal[...]): One of "train", "valid", "test", "predict". Returns: torch.Tensor: Model outputs tensor. """ if mode != "train": # Use sliding window inference for non-training phases outputs = sliding_window_inference( inputs, roi_size=512, sw_batch_size=4, predictor=self._model, padding_mode="constant", mode="gaussian", overlap=0.5, ) else: # Direct forward pass during training outputs = self._model(inputs) return outputs # type: ignore def __post_process_predictions( self, raw_outputs: torch.Tensor, ground_truth: Optional[torch.Tensor] = None ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """ Post-process raw network outputs to extract instance segmentation masks. Args: raw_outputs (torch.Tensor): Raw model outputs of shape (B, С, H, W). ground_truth (torch.Tensor): Ground truth masks of shape (B, С, H, W). Returns: Tuple[np.ndarray, Optional[np.ndarray]]: - instance_masks: Instance-wise masks array of shape (B, С, H, W). - labels_np: Converted ground truth of shape (B, С, H, W) or None if ground_truth was not provided. """ # Move outputs to CPU and convert to numpy outputs_np = raw_outputs.cpu().numpy() # Split channels: gradient flows then class logits gradflow = outputs_np[:, :2 * self._model.num_classes] logits = outputs_np[:, -self._model.num_classes :] # Apply sigmoid to logits to get probabilities probabilities = self.__sigmoid(logits) batch_size, _, height, width = probabilities.shape # Prepare container for instance masks instance_masks = np.zeros((batch_size, self._model.num_classes, height, width), dtype=np.uint16) for idx in range(batch_size): instance_masks[idx] = self.__segment_instances( probability_map=probabilities[idx], flow=gradflow[idx], prob_threshold=0.5, flow_threshold=0.4, min_object_size=15 ) # Convert ground truth to numpy labels_np = ground_truth.cpu().numpy() if ground_truth is not None else None return instance_masks, labels_np def __compute_stats( self, predicted_masks: np.ndarray, ground_truth_masks: np.ndarray, iou_threshold: float = 0.5, return_error_masks: bool = False ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: """ Compute batch-wise true positives, false positives, and false negatives for instance segmentation, using a configurable IoU threshold. Args: predicted_masks (np.ndarray): Predicted instance masks of shape (B, C, H, W). ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W). iou_threshold (float): Intersection-over-Union threshold for matching predictions to ground truths (default: 0.5). return_error_masks (bool): Whether to also return binary error masks. Returns: Tuple(np.ndarray, np.ndarray, np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None): - tp: True positives per batch and class, shape (B, C) - fp: False positives per batch and class, shape (B, C) - fn: False negatives per batch and class, shape (B, C) - tp_maks: True positives mask per batch and class, shape (B, C, H, W) - fp_maks: False positives mask per batch and class, shape (B, C, H, W) - fn_maks: False negatives mask per batch and class, shape (B, C, H, W) """ stats = compute_batch_segmentation_tp_fp_fn( batch_ground_truth=ground_truth_masks, batch_prediction=predicted_masks, iou_threshold=iou_threshold, return_error_masks=return_error_masks, remove_boundary_objects=True ) tp = stats["tp"] fp = stats["fp"] fn = stats["fn"] tp_mask = stats["tp_mask"] if return_error_masks else None fp_mask = stats["fp_mask"] if return_error_masks else None fn_mask = stats["fn_mask"] if return_error_masks else None return tp, fp, fn, tp_mask, fp_mask, fn_mask def __compute_f1_metric( self, true_positives: np.ndarray, false_positives: np.ndarray, false_negatives: np.ndarray, reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro" ) -> Union[float, np.ndarray]: """ Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes. Args: true_positives: array of TP counts per sample and class. false_positives: array of FP counts per sample and class. false_negatives: array of FN counts per sample and class. reduction: - 'none': return F1 for each sample, class → shape (batch_size, num_classes) - 'micro': global F1 over all samples & classes - 'macro': average class-wise F1 (classes summed over batch) - 'imagewise': F1 per sample (summing over classes), then average over samples - 'per_class': F1 per class (summing over batch), return vector of shape (num_classes,) - 'weighted': class-wise F1 weighted by support (TP+FN) Returns: float for reductions 'micro', 'imagewise', 'macro', 'weighted'; or np.ndarray of shape (batch_size, num_classes) if reduction='none'. """ batch_size, num_classes = true_positives.shape # 1) No reduction: compute F1 for each (sample, class) if reduction == "none": f1_matrix = np.zeros((batch_size, num_classes), dtype=float) for i in range(batch_size): for c in range(num_classes): tp_val = int(true_positives[i, c]) fp_val = int(false_positives[i, c]) fn_val = int(false_negatives[i, c]) _, _, f1_val = compute_f1_score( tp_val, fp_val, fn_val ) f1_matrix[i, c] = f1_val return f1_matrix # 2) Micro: sum all TP/FP/FN and compute a single F1 if reduction == "micro": tp_total = int(true_positives.sum()) fp_total = int(false_positives.sum()) fn_total = int(false_negatives.sum()) _, _, f1_global = compute_f1_score( tp_total, fp_total, fn_total ) return f1_global # 3) Imagewise: compute per-sample F1 (sum over classes), then average if reduction == "imagewise": f1_per_image = np.zeros(batch_size, dtype=float) for i in range(batch_size): tp_i = int(true_positives[i].sum()) fp_i = int(false_positives[i].sum()) fn_i = int(false_negatives[i].sum()) _, _, f1_i = compute_f1_score( tp_i, fp_i, fn_i ) f1_per_image[i] = f1_i return float(f1_per_image.mean()) # Aggregate per class across the batch for per_class, macro, weighted tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,) fp_per_class = false_positives.sum(axis=0).astype(int) fn_per_class = false_negatives.sum(axis=0).astype(int) # 4) Per-class: compute F1 for each class and return vector if reduction == "per_class": f1_per_class = np.zeros(num_classes, dtype=float) for c in range(num_classes): _, _, f1_per_class[c] = compute_f1_score( tp_per_class[c], fp_per_class[c], fn_per_class[c] ) return f1_per_class # 5) Macro: average F1 across classes equally if reduction == "macro": f1_per_class = np.zeros(num_classes, dtype=float) for c in range(num_classes): _, _, f1_c = compute_f1_score( tp_per_class[c], fp_per_class[c], fn_per_class[c] ) f1_per_class[c] = f1_c return float(f1_per_class.mean()) # 6) Weighted: class-wise F1 weighted by support = TP + FN if reduction == "weighted": f1_per_class = np.zeros(num_classes, dtype=float) support = np.zeros(num_classes, dtype=float) for c in range(num_classes): tp_c = tp_per_class[c] fp_c = fp_per_class[c] fn_c = fn_per_class[c] _, _, f1_c = compute_f1_score( tp_c, fp_c, fn_c ) f1_per_class[c] = f1_c support[c] = tp_c + fn_c total_support = support.sum() if total_support == 0: # fallback to unweighted macro if no positives return float(f1_per_class.mean()) weights = support / total_support return float((f1_per_class * weights).sum()) raise ValueError(f"Unknown reduction mode: {reduction}") def __compute_average_precision_metric( self, true_positives: np.ndarray, false_positives: np.ndarray, false_negatives: np.ndarray, reduction: Literal["micro", "macro", "weighted", "imagewise", 'per_class', "none"] = "micro" ) -> Union[float, np.ndarray]: """ Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes. AP is defined here as: AP = TP / (TP + FP + FN) Args: true_positives: array of true positives per sample and class. false_positives: array of false positives per sample and class. false_negatives: array of false negatives per sample and class. reduction: - 'none': return AP for each sample and class → shape (batch_size, num_classes) - 'micro': global AP over all samples & classes - 'macro': average class-wise AP (each class summed over batch) - 'imagewise': AP per sample (summing stats over classes), then average over batch - 'per_class': AP per class (summing over batch), return vector of shape (num_classes,) - 'weighted': class-wise AP weighted by support (TP+FN) Returns: float for reductions 'micro', 'imagewise', 'macro', 'weighted'; or np.ndarray of shape (batch_size, num_classes) if reduction='none'. """ batch_size, num_classes = true_positives.shape # 1) No reduction: AP per (sample, class) if reduction == "none": ap_matrix = np.zeros((batch_size, num_classes), dtype=float) for i in range(batch_size): for c in range(num_classes): tp_val = int(true_positives[i, c]) fp_val = int(false_positives[i, c]) fn_val = int(false_negatives[i, c]) ap_val = compute_average_precision_score( tp_val, fp_val, fn_val ) ap_matrix[i, c] = ap_val return ap_matrix # 2) Micro: sum all TP/FP/FN and compute one AP if reduction == "micro": tp_total = int(true_positives.sum()) fp_total = int(false_positives.sum()) fn_total = int(false_negatives.sum()) return compute_average_precision_score( tp_total, fp_total, fn_total ) # 3) Imagewise: compute per-sample AP (sum over classes), then mean if reduction == "imagewise": ap_per_image = np.zeros(batch_size, dtype=float) for i in range(batch_size): tp_i = int(true_positives[i].sum()) fp_i = int(false_positives[i].sum()) fn_i = int(false_negatives[i].sum()) ap_per_image[i] = compute_average_precision_score( tp_i, fp_i, fn_i ) return float(ap_per_image.mean()) # For macro and weighted: first aggregate per class across batch tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,) fp_per_class = false_positives.sum(axis=0).astype(int) fn_per_class = false_negatives.sum(axis=0).astype(int) # 4) Per-class: compute F1 for each class and return vector if reduction == "per_class": ap_per_class = np.zeros(num_classes, dtype=float) for c in range(num_classes): ap_per_class[c] = compute_average_precision_score( tp_per_class[c], fp_per_class[c], fn_per_class[c] ) return ap_per_class # 5) Macro: average AP across classes equally if reduction == "macro": ap_per_class = np.zeros(num_classes, dtype=float) for c in range(num_classes): ap_per_class[c] = compute_average_precision_score( tp_per_class[c], fp_per_class[c], fn_per_class[c] ) return float(ap_per_class.mean()) # 6) Weighted: class-wise AP weighted by support = TP + FN if reduction == "weighted": ap_per_class = np.zeros(num_classes, dtype=float) support = np.zeros(num_classes, dtype=float) for c in range(num_classes): tp_c = tp_per_class[c] fp_c = fp_per_class[c] fn_c = fn_per_class[c] ap_per_class[c] = compute_average_precision_score( tp_c, fp_c, fn_c ) support[c] = tp_c + fn_c total_support = support.sum() if total_support == 0: # fallback to unweighted macro if no positive instances return float(ap_per_class.mean()) weights = support / total_support return float((ap_per_class * weights).sum()) raise ValueError(f"Unknown reduction mode: {reduction}") @staticmethod def __sigmoid(z: np.ndarray) -> np.ndarray: """ Numerically stable sigmoid activation for numpy arrays. Args: z (np.ndarray): Input array. Returns: np.ndarray: Sigmoid of the input. """ return expit(z) def __save_prediction_masks( self, sample: Dict[str, Any], predicted_mask: Union[np.ndarray, torch.Tensor], start_index: int = 0, only_masks: bool = False, masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None ) -> None: """ Save multi-channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders. Args: sample (Dict[str, Any]): Batch sample from MONAI LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict'). predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W). start_index (int): Starting index for naming when metadata is missing. only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations. masks (Tuple[np.ndarray, np.ndarray, np.ndarray] | None): A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None. """ # Base directories (created once per call) base_output_dir = self._dataset_setup.common.predictions_dir masks_dir = base_output_dir plots_dir = os.path.join(base_output_dir, "plots") evaluate_dir = os.path.join(plots_dir, "evaluate") os.makedirs(masks_dir, exist_ok=True) os.makedirs(plots_dir, exist_ok=True) os.makedirs(evaluate_dir, exist_ok=True) # Convert tensors to numpy if necessary def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: return x.cpu().numpy() if isinstance(x, torch.Tensor) else x pred_array = to_numpy(predicted_mask).astype(np.uint16) # Handle batch dimension for idx in range(pred_array.shape[0]): batch_sample: Dict[str, Any] = {} # copy per-sample image and meta img = to_numpy(sample["image"]) if img.ndim == 4: batch_sample["image"] = img[idx] if "mask" in sample: msk = to_numpy(sample["mask"]).astype(np.uint16) if msk.ndim == 4: batch_sample["mask"] = msk[idx] image_meta = sample.get("image_meta_dict") if isinstance(image_meta, dict) and "filename_or_obj" in image_meta: fname = image_meta["filename_or_obj"][idx] batch_sample["image_name"] = fname single_masks = ( (masks[0][idx], masks[1][idx], masks[2][idx]) if masks is not None else None ) self.__save_single_prediction_mask( sample=batch_sample, pred_array=pred_array[idx], start_index=start_index + idx, masks_dir=masks_dir, plots_dir=plots_dir, evaluate_dir=evaluate_dir, only_masks=only_masks, masks=single_masks, ) def __save_single_prediction_mask( self, sample: Dict[str, Any], pred_array: np.ndarray, start_index: int, masks_dir: str, plots_dir: str, evaluate_dir: str, only_masks: bool = False, masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None ) -> None: """ Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations. Assumes output directories already exist. Args: sample (Dict[str, Any]): Dictionary containing 'image', 'mask', and optional 'image_meta_dict' for metadata. pred_array (np.ndarray): Predicted mask array of shape (C,H,W). start_index (int): Base index for generating filenames when metadata is missing. masks_dir (str): Directory for saving TIFF masks. plots_dir (str): Directory for saving PNG visualizations. evaluate_dir (str): Directory for saving PNG visualizations of evaluation results. only_masks (bool): If True, saves only TIFF mask files; skips PNG plots. masks (Tuple[np.ndarray, np.ndarray, np.ndarray], optional): A tuple of true-positive, false-positive, and false-negative mask arrays, each of shape (C,H,W). Defaults to None. """ if pred_array.ndim == 2: pred_array = np.expand_dims(pred_array, axis=0) elif pred_array.ndim != 3: raise ValueError( f"Unsupported predicted_mask dimensions: {pred_array.ndim}." "Expected 2D (H,W) or 3D (C,H,W)." ) # Handle image array if present image_array: np.ndarray = sample["image"] if image_array.ndim == 2: image_array = np.expand_dims(image_array, axis=0) elif image_array.ndim != 3: raise ValueError( f"Unsupported image dimensions: {image_array.ndim}." "Expected 2D (H,W) or 3D (C,H,W)." ) true_mask_array: Optional[np.ndarray] = sample.get("mask") if isinstance(true_mask_array, np.ndarray): if true_mask_array.ndim == 2: true_mask_array = np.expand_dims(true_mask_array, axis=0) elif true_mask_array.ndim != 3: raise ValueError( f"Unsupported true_mask_array dimensions: {true_mask_array.ndim}." "Expected 2D (H,W) or 3D (C,H,W)." ) # Determine filename base image_meta = sample.get("image_name") if isinstance(image_meta, (str, os.PathLike)): base_name = os.path.splitext(os.path.basename(image_meta))[0] else: base_name = f"prediction_{start_index:04d}" # Save main mask TIFF mask_path = os.path.join(masks_dir, f"{base_name}_mask.tif") tiff.imwrite(mask_path, pred_array.astype(np.uint16), compression="zlib") if only_masks: return # Save channel-wise plots num_channels = pred_array.shape[0] for ch in range(num_channels): true_ch = true_mask_array[ch] if true_mask_array is not None else None self.__plot_mask( file_path=os.path.join(plots_dir, f"{base_name}_ch{ch}.png"), image_data=image_array, predicted_mask=pred_array[ch], true_mask=true_ch, ) if masks is not None and true_ch is not None: self.__save_mask_comparison_visuals( gt=true_ch, pred=pred_array[ch], tp_mask=masks[0][ch], fp_mask=masks[1][ch], fn_mask=masks[2][ch], file_path=os.path.join(evaluate_dir, f"{base_name}_ch{ch}.png") ) def __plot_mask( self, file_path: str, image_data: np.ndarray, predicted_mask: np.ndarray, true_mask: Optional[np.ndarray] = None, ) -> None: """ Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided. Args: file_path (str): Path where the visualization image will be saved. image_data (np.ndarray): The original input image array, expected shape (C, H, W). predicted_mask (np.ndarray): The predicted mask array, shape (H, W), depending on the task. true_mask (Optional[np.ndarray], optional): The ground-truth mask array. If provided, an additional row with true mask and overlap visualization will be added to the plot. Default is None. """ img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data if true_mask is None: fig, axs = plt.subplots(1, 3, figsize=(15,5)) plt.subplots_adjust(wspace=0.02, hspace=0) self.__plot_panels(axs, img, predicted_mask, 'red', ('Original Image','Predicted Mask','Predicted Contours')) else: fig, axs = plt.subplots(2,3,figsize=(15,10)) plt.subplots_adjust(wspace=0.02, hspace=0.02) # row 0: predicted self.__plot_panels(axs[0], img, predicted_mask, 'red', ('Original Image','Predicted Mask','Predicted Contours')) # row 1: true self.__plot_panels(axs[1], img, true_mask, 'blue', ('Original Image','True Mask','True Contours')) fig.savefig(file_path, bbox_inches='tight', dpi=300) plt.close(fig) def __plot_panels( self, axes, img: np.ndarray, mask: np.ndarray, contour_color: str, titles: Tuple[str, ...] ): """ Plot a row of three panels: original image, mask, and mask boundaries on image. Args: axes: list/array of three Axis objects. img: Image array (H, W or H, W, C). mask: Label mask (H, W). contour_color: Color for boundaries. titles: (title_img, title_mask, title_contours). """ # Panel 1: Original image ax0, ax1, ax2 = axes ax0.imshow(img, cmap='gray' if img.ndim == 2 else None) ax0.set_title(titles[0]); ax0.axis('off') # Compute boundaries once boundaries = find_boundaries(mask, mode='thick') # Panel 2: Mask with black boundaries cmap = plt.get_cmap("gist_ncar") cmap = mcolors.ListedColormap([ cmap(i/len(np.unique(mask))) for i in range(len(np.unique(mask))) ]) ax1.imshow(mask, cmap=cmap) ax1.contour(boundaries, colors='black', linewidths=0.5) ax1.set_title(titles[1]) ax1.axis('off') # Panel 3: Original image with black contour overlay ax2.imshow(img) # Draw boundaries as contour lines ax2.contour(boundaries, colors=contour_color, linewidths=0.5) ax2.set_title(titles[2]) ax2.axis('off') def __save_mask_comparison_visuals( self, gt: np.ndarray, pred: np.ndarray, tp_mask: np.ndarray, fp_mask: np.ndarray, fn_mask: np.ndarray, file_path: str ) -> None: """ Creates and saves a 1x3 subplot figure showing: 1) True mask with boundaries 2) Predicted mask without boundaries 3) Overlay mask combining FP (R), TP (G), FN (B) Args: gt (np.ndarray): Ground truth mask (H, W). pred (np.ndarray): Predicted mask (H, W). tp_mask (np.ndarray): True positive mask (H, W). fp_mask (np.ndarray): False positive mask (H, W). fn_mask (np.ndarray): False negative mask (H, W). file_path (str): Path where the visualization image will be saved. """ # Prepare overlay mask overlap_mask = np.zeros((*gt.shape[:2], 3), dtype=np.uint8) overlap_mask[..., 0] = np.where(fp_mask, 255, 0) overlap_mask[..., 1] = np.where(tp_mask, 255, 0) overlap_mask[..., 2] = np.where(fn_mask, 255, 0) # Set up figure fig, axes = plt.subplots(1, 3, figsize=(15, 5), gridspec_kw={'width_ratios': [1, 1, 1]}) plt.subplots_adjust(wspace=0.02, hspace=0.0, left=0.05, right=0.95, top=0.95, bottom=0.05) # Colormap for instances num_instances = max(np.max(gt), np.max(pred)) cmap = plt.get_cmap("gist_ncar") colors = [cmap(i / num_instances) for i in range(num_instances)] cmap = mcolors.ListedColormap(colors) # Plot true mask axes[0].imshow(gt, cmap=cmap) axes[0].contour(find_boundaries(gt, mode="thick"), colors="black", linewidths=0.5) axes[0].set_title("True Mask") axes[0].axis("off") # Plot predicted mask axes[1].imshow(pred, cmap=cmap) axes[1].contour(find_boundaries(pred, mode="thick"), colors="black", linewidths=0.5) axes[1].set_title("Predicted Mask") axes[1].axis("off") # Plot overlay axes[2].imshow(overlap_mask) axes[2].set_title("Overlay Mask (R-FP; G-TP; B-FN)") axes[2].axis("off") # Save plt.savefig(file_path, bbox_inches="tight", dpi=300) plt.close() def __compute_flows_from_masks( self, true_masks: Tensor ) -> np.ndarray: """ Convert segmentation masks to flow fields for training. Args: true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks. Returns: numpy array of concatenated [flow_vectors, renumbered_true_masks] per image. renumbered_true_masks is labels, flow_vectors[0] is Y flow, flow_vectors[1] is X flow. """ # Move to CPU numpy _true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16) batch_size = _true_masks.shape[0] # Ensure each label has a channel dimension if _true_masks.ndim == 3: # shape (batch, H, W) -> (batch, 1, H, W) _true_masks = _true_masks[:, np.newaxis, :, :] batch_size, *_ = _true_masks.shape # Renumber labels to ensure uniqueness renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0] for i in range(batch_size)]) # Compute vector flows per image flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i]) for i in range(batch_size)]) return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32) def __compute_flow_from_mask( self, mask: np.ndarray ) -> np.ndarray: """ Compute normalized flow vectors from a labeled mask. Args: mask: 3D array of instance-labeled mask of shape (C, H, W). Returns: flow: Array of shape (2 * C, H, W). """ if mask.max() == 0 or np.count_nonzero(mask) <= 1: # No flow to compute logger.warning("Empty mask!") C, H, W = mask.shape return np.zeros((2*C, H, W), dtype=np.float32) # Delegate to GPU or CPU routine if self._device.type == "cuda" or self._device.type == "mps": return self.__mask_to_flow_gpu(mask) else: return self.__mask_to_flow_cpu(mask) def __mask_to_flow_gpu(self, mask: np.ndarray) -> np.ndarray: """Convert masks to flows using diffusion from center pixel. Center of masks where diffusion starts is defined by pixel closest to median within the mask. Args: masks (3D array): Labelled masks of shape (C, H, W). Returns: np.ndarray: A 3D array where for each channel the flows for each pixel are represented along the X and Y axes. """ channels, height, width = mask.shape flows = np.zeros((2*channels, height, width), np.float32) for channel in range(channels): mask_channel = mask[channel] if mask_channel.max() > 0: padded_height, padded_width = height + 2, width + 2 # Pad the mask with a 1-pixel border masks_padded = torch.from_numpy(mask_channel.astype(np.int64)).to(self._device) masks_padded = F.pad(masks_padded, (1, 1, 1, 1)) # Get coordinates of all non-zero pixels in the padded mask y, x = torch.nonzero(masks_padded, as_tuple=True) y = y.int(); x = x.int() # ensure integer type # Generate 8-connected neighbors (including center) via broadcasted offsets offsets = torch.tensor([ [ 0, 0], # center [-1, 0], # up [ 1, 0], # down [ 0, -1], # left [ 0, 1], # right [-1, -1], # up-left [-1, 1], # up-right [ 1, -1], # down-left [ 1, 1], # down-right ], dtype=torch.int32, device=self._device) # (9, 2) # coords: (N, 2) coords = torch.stack((y, x), dim=1) # neighbors: (9, N, 2) neighbors = offsets[:, None, :] + coords[None, :, :] # transpose into (2, 9, N) for the GPU kernel neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index # Build connectivity mask: True where neighbor label == center label center_labels = masks_padded[y, x][None, :] # (1, N) neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N) is_neighbor = neighbor_labels == center_labels # (9, N) # Compute object slices and pack into array for get_centers slices = find_objects(mask_channel) slices_arr = np.array([ [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop] for i, sl in enumerate(slices, start=1) if sl is not None ], dtype=np.int16) # Compute centers (pixel indices) and extents via the provided helper centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr) # Move centers to GPU and shift by +1 for padding meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding # Determine number of diffusion iterations n_iter = 2 * ext.max() # Run the GPU diffusion kernel mu = self.__propagate_centers_gpu( neighbor_indices=neighbors, center_indices=meds_p.T, valid_neighbor_mask=is_neighbor, output_shape=(padded_height, padded_width), num_iterations=n_iter ) # Cast to float64 and normalize flow vectors mu = mu.astype(np.float64) mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60 # Remove the padding and write into final output flow_output = np.zeros((2, height, width), dtype=np.float32) ys_np = y.cpu().numpy() - 1 xs_np = x.cpu().numpy() - 1 flow_output[:, ys_np, xs_np] = mu.reshape(2, -1) flows[2*channel: 2*channel + 2] = flow_output return flows @staticmethod @njit(nogil=True) def __get_mask_centers_and_extents( label_map: np.ndarray, slices_arr: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """ Compute the centroids and extents of labeled regions in a 2D mask array. Args: label_map (np.ndarray): 2D array where each connected region has a unique integer label (1…K). slices_arr (np.ndarray): Array of shape (K, 5), where each row is (label_id, row_start, row_stop, col_start, col_stop). Returns: centers (np.ndarray): Integer array of shape (K, 2) with (row, col) center for each label. extents (np.ndarray): Integer array of shape (K,) giving the sum of height and width + 2 for each region. """ num_regions = slices_arr.shape[0] centers = np.zeros((num_regions, 2), dtype=np.int32) extents = np.zeros(num_regions, dtype=np.int32) for idx in prange(num_regions): # Unpack slice info label_id = slices_arr[idx, 0] row_start = slices_arr[idx, 1] row_stop = slices_arr[idx, 2] col_start = slices_arr[idx, 3] col_stop = slices_arr[idx, 4] # Extract binary submask for this label submask = (label_map[row_start:row_stop, col_start:col_stop] == label_id) # Get local coordinates of all pixels in the region ys, xs = np.nonzero(submask) # Compute the floating-point centroid within the submask y_mean = ys.mean() x_mean = xs.mean() # Find the pixel closest to the centroid by minimizing squared distance dist_sq = (ys - y_mean) ** 2 + (xs - x_mean) ** 2 closest_idx = dist_sq.argmin() # Convert to global coordinates center_row = ys[closest_idx] + row_start center_col = xs[closest_idx] + col_start centers[idx, 0] = center_row centers[idx, 1] = center_col # Compute extent as height + width + 2 (to include one-pixel border) height = row_stop - row_start width = col_stop - col_start extents[idx] = height + width + 2 return centers, extents def __propagate_centers_gpu( self, neighbor_indices: torch.Tensor, center_indices: torch.Tensor, valid_neighbor_mask: torch.Tensor, output_shape: Tuple[int, int], num_iterations: int = 200 ) -> np.ndarray: """ Propagates center points across a mask using GPU-based diffusion. Args: neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel. center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers. valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid. output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W). num_iterations (int, optional): Number of diffusion iterations. Defaults to 200. Returns: np.ndarray: Array of shape (2, N) with the computed flows. """ # Determine total number of elements and choose dtype accordingly total_elems = torch.prod(torch.tensor(output_shape)) if total_elems > 4e7 or self._device.type == "mps": diffusion_tensor = torch.zeros(output_shape, dtype=torch.float, device=self._device) else: diffusion_tensor = torch.zeros(output_shape, dtype=torch.double, device=self._device) # Unpack center row and column indices center_rows, center_cols = center_indices # Unpack neighbor row and column indices for 9 neighbors per pixel # Order: [0: center, 1: up, 2: down, 3: left, 4: right, # 5: up-left, 6: up-right, 7: down-left, 8: down-right] neigh_rows, neigh_cols = neighbor_indices # each of shape (9, N) # Perform diffusion iterations for _ in range(num_iterations): # Add source at each mask center diffusion_tensor[center_rows, center_cols] += 1 # Sample neighbor values for each pixel neighbor_vals = diffusion_tensor[neigh_rows, neigh_cols] # shape (9, N) # Zero out invalid neighbors neighbor_vals *= valid_neighbor_mask # Update the first neighbor (index 0) with the average of valid neighbor values diffusion_tensor[neigh_rows[0], neigh_cols[0]] = neighbor_vals.mean(dim=0) # Compute spatial gradients for 2D flow: dy and dx # Using neighbor indices: up = 1, down = 2, left = 3, right = 4 grad_samples = diffusion_tensor[ neigh_rows[[2, 1, 4, 3]], # indices [down, up, right, left] neigh_cols[[2, 1, 4, 3]] ] # shape (4, N) dy = grad_samples[0] - grad_samples[1] dx = grad_samples[2] - grad_samples[3] # Stack and convert to numpy flow field with shape (2, N) flow_field = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=0) return flow_field def __mask_to_flow_cpu(self, mask: np.ndarray) -> np.ndarray: """ Convert labeled masks to flow vectors by simulating diffusion from mask centers. Each mask's center is chosen as the pixel closest to its geometric centroid. A diffusion process is run on a padded local patch, and flows are derived as gradients (dy, dx) of the resulting density map. Args: masks (np.ndarray): 3D integer array of labels `(C x H x W)`, where 0 = background and positive integers = mask IDs. Returns: flow_field (np.ndarray): Array of shape `(2*C, H, W)` containing flow components [dy, dx] normalized per pixel. """ channels, height, width = mask.shape flows = np.zeros((2*channels, height, width), np.float32) for channel in range(channels): # Initialize flow_field with two channels: dy and dx flow_field = np.zeros((2, height, width), dtype=np.float64) mask_channel = mask[channel] # Find bounding box for each labeled mask mask_slices = find_objects(mask_channel) # centers: List[Tuple[int, int]] = [] # Iterate over mask labels in parallel for label_idx in prange(len(mask_slices)): slc = mask_slices[label_idx] if slc is None: continue # Extract row and column slice for this mask row_slice, col_slice = slc # Add 1-pixel border around the patch patch_height = (row_slice.stop - row_slice.start) + 2 patch_width = (col_slice.stop - col_slice.start) + 2 # Get local coordinates of mask pixels within the patch local_rows, local_cols = np.nonzero( mask_channel[row_slice, col_slice] == (label_idx + 1) ) # Shift coords by +1 for the border padding local_rows = local_rows.astype(np.int32) + 1 local_cols = local_cols.astype(np.int32) + 1 # Compute centroid and find nearest pixel as diffusion seed centroid_row = local_rows.mean() centroid_col = local_cols.mean() distances = (local_cols - centroid_col) ** 2 + (local_rows - centroid_row) ** 2 seed_index = distances.argmin() center_row = int(local_rows[seed_index]) center_col = int(local_cols[seed_index]) # Determine number of iterations total_iter = 2 * (patch_height + patch_width) # Initialize flat diffusion map for the local patch diffusion_map = np.zeros(patch_height * patch_width, dtype=np.float64) # Run diffusion from the seed center diffusion_map = self.__diffuse_from_center( diffusion_map, local_rows, local_cols, center_row, center_col, patch_width, total_iter ) # Compute flow as finite differences (gradient) on the diffusion map dy = ( diffusion_map[(local_rows + 1) * patch_width + local_cols] - diffusion_map[(local_rows - 1) * patch_width + local_cols] ) dx = ( diffusion_map[local_rows * patch_width + (local_cols + 1)] - diffusion_map[local_rows * patch_width + (local_cols - 1)] ) # Write flows back into the global flow_field array flow_field[0, row_slice.start + local_rows - 1, col_slice.start + local_cols - 1] = dy flow_field[1, row_slice.start + local_rows - 1, col_slice.start + local_cols - 1] = dx # Store center location in original image coordinates # centers.append( # (row_slice.start + center_row - 1, # col_slice.start + center_col - 1) # ) # Normalize each vector [dy,dx] by its magnitude magnitudes = np.sqrt((flow_field**2).sum(axis=0)) + 1e-60 flow_field /= magnitudes flows[2*channel: 2*channel + 2] = flow_field return flows @staticmethod @njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True) def __diffuse_from_center( diffusion_map: np.ndarray, row_coords: np.ndarray, col_coords: np.ndarray, center_row: int, center_col: int, patch_width: int, num_iterations: int ) -> np.ndarray: """ Perform diffusion of particles from a seed pixel across a local mask patch. At each iteration, one particle is added at the seed, and each mask pixel's value is updated to the average of itself and its 8-connected neighbors. Args: diffusion_map (np.ndarray): Flat array of length patch_height * patch_width. row_coords (np.ndarray): 1D array of row indices for mask pixels (local coords). col_coords (np.ndarray): 1D array of column indices for mask pixels (local coords). center_row (int): Row index of the seed point in local patch coords. center_col (int): Column index of the seed point in local patch coords. patch_width (int): Width (number of columns) in the local patch. num_iterations (int): Number of diffusion iterations to perform. Returns: np.ndarray: Updated diffusion_map after performing diffusion. """ # Compute linear indices for each mask pixel and its neighbors base_idx = row_coords * patch_width + col_coords up = (row_coords - 1) * patch_width + col_coords down = (row_coords + 1) * patch_width + col_coords left = row_coords * patch_width + (col_coords - 1) right = row_coords * patch_width + (col_coords + 1) up_left = (row_coords - 1) * patch_width + (col_coords - 1) up_right = (row_coords - 1) * patch_width + (col_coords + 1) down_left = (row_coords + 1) * patch_width + (col_coords - 1) down_right = (row_coords + 1) * patch_width + (col_coords + 1) for _ in range(num_iterations): # Inject one particle at the seed location diffusion_map[center_row * patch_width + center_col] += 1.0 # Update each mask pixel as the average over itself and neighbors diffusion_map[base_idx] = ( diffusion_map[base_idx] + diffusion_map[up] + diffusion_map[down] + diffusion_map[left] + diffusion_map[right] + diffusion_map[up_left] + diffusion_map[up_right] + diffusion_map[down_left] + diffusion_map[down_right] ) * (1.0 / 9.0) return diffusion_map def __segment_instances( self, probability_map: np.ndarray, flow: np.ndarray, prob_threshold: float = 0.0, flow_threshold: float = 0.4, num_iters: int = 200, min_object_size: int = 15 ) -> np.ndarray: """ Generate instance segmentation masks from probability and flow fields. Args: probability_map: 3D array `(C, H, W)` of cell probabilities. flow: 3D array `(2*C, H, W)` of forward flow vectors. prob_threshold: threshold to binarize probability_map. (Default 0.0) flow_threshold: threshold for filtering bad flow masks. (Default 0.4) num_iters: number of iterations for flow-following. (Default 200) min_object_size: minimum area to keep small instances. (Default 15) Returns: 3D array of uint16 instance labels for each channel. """ # Create a binary mask of likely cell locations probability_mask = probability_map > prob_threshold # If no cells exceed the threshold, return an empty mask if not np.any(probability_mask): logger.warning("No cell pixels found.") return np.zeros_like(probability_map, dtype=np.uint16) # Prepare output array for instance labels labeled_instances = np.zeros_like(probability_map, dtype=np.uint16) # Process each channel independently for channel_index in range(probability_mask.shape[0]): # Extract flow vectors for this channel (two components per channel) channel_flow_vectors = flow[2 * channel_index : 2 * channel_index + 2] # Extract binary mask for this channel channel_mask = probability_mask[channel_index] if not channel_mask.sum(): continue nonzero_coords = np.stack(np.nonzero(channel_mask)) # Follow the flow vectors to generate coordinate mappings flow_coordinates = self.__follow_flows( flow_field=channel_flow_vectors * channel_mask / 5.0, initial_coords=nonzero_coords, num_iters=num_iters ) if not torch.is_tensor(flow_coordinates): flow_coordinates = torch.from_numpy( flow_coordinates).to(self._device, dtype=torch.int32) else: flow_coordinates = flow_coordinates.int() # Obtain preliminary instance masks by clustering the coordinates channel_instances_mask = self.__get_mask( pixel_positions=flow_coordinates, valid_indices=nonzero_coords, original_shape=probability_map.shape[1:] ) # Filter out bad flow-derived instances if requested if channel_instances_mask.max() > 0 and flow_threshold > 0: channel_instances_mask = self.__remove_inconsistent_flow_masks( mask=channel_instances_mask, flow_network=channel_flow_vectors, error_threshold=flow_threshold ) # Remove small objects or holes below the minimum size if min_object_size > 0: # channel_instances_mask = morphology.remove_small_holes( # channel_instances_mask, area_threshold=min_object_size # ) # channel_instances_mask = morphology.remove_small_objects( # channel_instances_mask, min_size=min_object_size # ) channel_instances_mask = self.__fill_holes_and_prune_small_masks( channel_instances_mask, minimum_size=min_object_size ) labeled_instances[channel_index] = channel_instances_mask return labeled_instances def __follow_flows( self, flow_field: np.ndarray, initial_coords: np.ndarray, num_iters: int = 200 ) -> Union[np.ndarray, torch.Tensor]: """ Trace pixel positions through a flow field via iterative interpolation. Args: flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors. initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions. num_iters (int): Number of integration steps. Returns: np.ndarray or torch.Tensor: Final (y, x) positions of each point. """ dims = 2 # Extract spatial dimensions height, width = flow_field.shape[1:] # Choose GPU/MPS path if available if self._device.type in ("cuda", "mps"): # Prepare point tensor: shape [1, 1, num_points, 2] pts = torch.zeros((1, 1, initial_coords.shape[1], dims), dtype=torch.float32, device=self._device) # Prepare flow volume: shape [1, 2, height, width] flow_vol = torch.zeros((1, dims, height, width), dtype=torch.float32, device=self._device) # Load initial positions and flow into tensors (flip order for grid_sample) # dim 0 = x # dim 1 = y for i in range(dims): pts[0, 0, :, dims - i - 1] = ( torch.from_numpy(initial_coords[i]) .to(self._device, torch.float32) ) flow_vol[0, dims - i - 1] = ( torch.from_numpy(flow_field[i]) .to(self._device, torch.float32) ) # Prepare normalization factors for x and y (max index) max_indices = torch.tensor([width - 1, height - 1], dtype=torch.float32, device=self._device) # Reshape for broadcasting to point tensor dims max_idx_pt = max_indices.view(1, 1, 1, dims) # Reshape for broadcasting to flow volume dims max_idx_flow = max_indices.view(1, dims, 1, 1) # Normalize flow values to [-1, 1] range flow_vol = (flow_vol * 2) / max_idx_flow # Normalize points to [-1, 1] pts = (pts / max_idx_pt) * 2 - 1 # Iterate: sample flow and update points for _ in range(num_iters): sampled = torch.nn.functional.grid_sample( flow_vol, pts, align_corners=False ) # Update each coordinate and clamp to valid range for i in range(dims): pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0) # Denormalize back to original pixel coordinates pts = (pts + 1) * 0.5 * max_idx_pt # Swap channels back to (y, x) and flatten final_pts = pts[..., [1, 0]].squeeze() # Convert from (num_points, 2) to (2, num_points) return final_pts.T if final_pts.ndim > 1 else final_pts.unsqueeze(0).T # CPU fallback using numpy and scipy current_pos = initial_coords.copy().astype(np.float32) temp_delta = np.zeros_like(current_pos, dtype=np.float32) for _ in range(num_iters): # Interpolate flow at current positions self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta) # Update positions and clamp to image bounds current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1) current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1) return current_pos @staticmethod @njit([ "(int16[:,:,:], float32[:], float32[:], float32[:,:])", "(float32[:,:,:], float32[:], float32[:], float32[:,:])" ], cache=True) def __map_coordinates( image_data: np.ndarray, y_coords: np.ndarray, x_coords: np.ndarray, output: np.ndarray ) -> None: """ Perform in-place bilinear interpolation on an image volume. Args: image_data (np.ndarray): Input volume with shape (C, H, W). y_coords (np.ndarray): Array of new y positions (num_points). x_coords (np.ndarray): Array of new x positions (num_points). output (np.ndarray): Output array of shape (C, num_points) to fill. Returns: None. Results written directly into `output`. """ channels, height, width = image_data.shape # Compute integer (floor) and fractional parts for coords y_floor = y_coords.astype(np.int32) x_floor = x_coords.astype(np.int32) y_frac = y_coords - y_floor x_frac = x_coords - x_floor # Loop over each sample point for idx in range(y_floor.shape[0]): # Clamp base indices to valid range y0 = min(max(y_floor[idx], 0), height - 1) x0 = min(max(x_floor[idx], 0), width - 1) y1 = min(y0 + 1, height - 1) x1 = min(x0 + 1, width - 1) wy = y_frac[idx] wx = x_frac[idx] # Interpolate per channel for c in range(channels): v00 = np.float32(image_data[c, y0, x0]) v10 = np.float32(image_data[c, y0, x1]) v01 = np.float32(image_data[c, y1, x0]) v11 = np.float32(image_data[c, y1, x1]) # Bilinear interpolation formula output[c, idx] = ( v00 * (1 - wy) * (1 - wx) + v10 * (1 - wy) * wx + v01 * wy * (1 - wx) + v11 * wy * wx ) def __get_mask( self, pixel_positions: torch.Tensor, valid_indices: np.ndarray, original_shape: Tuple[int, ...], pad_radius: int = 20, max_size_fraction: float = 0.4 ) -> np.ndarray: """ Generate labeled masks by clustering pixel trajectories via histogram peaks and region growing. This function executes the following steps: 1. Pads and clamps pixel final positions to avoid border effects. 2. Builds a dense histogram of pixel counts over spatial bins. 3. Identifies local maxima in the histogram as seed points. 4. Extracts local patches around each seed and grows regions by iteratively adding neighbors that exceed an intensity threshold. 5. Maps grown patches back to original image indices. 6. Removes any masks that exceed a maximum size fraction of the image. Args: pixel_positions (torch.Tensor): Tensor of shape `[2, N_pixels]`, dtype=int, containing final pixel coordinates after dynamics for each dimension. valid_indices (np.ndarray): Integer array of shape `[2, N_pixels]` giving indices of pixels in the original image grid. original_shape (tuple of ints): Spatial dimensions of the original image, e.g. (H, W). pad_radius (int): Number of zero-padding pixels added on each side of the histogram. Defaults to 20. max_size_fraction (float): If any mask has a pixel count > max_size_fraction * total_pixels, it will be removed. Defaults to 0.4. Returns: np.ndarray: Integer mask array of shape `original_shape` with labels 0 (background) and 1..M. Raises: ValueError: If input dimensions are inconsistent or pixel_positions shape is invalid. """ # Validate inputs ndim = len(original_shape) if pixel_positions.ndim != 2 or pixel_positions.size(0) != ndim: msg = f"pixel_positions must be shape [{ndim}, N], got {tuple(pixel_positions.shape)}" logger.error(msg) raise ValueError(msg) if pad_radius < 0: msg = f"pad_radius must be non-negative, got {pad_radius}" logger.error(msg) raise ValueError(msg) # Step 1: Pad and clamp pixel positions padded_positions = pixel_positions.clone().to(torch.int64) + pad_radius for dim in range(ndim): max_val = original_shape[dim] + pad_radius - 1 padded_positions[dim] = torch.clamp(padded_positions[dim], min=0, max=max_val) # Build histogram dimensions hist_shape = tuple(s + 2 * pad_radius for s in original_shape) # Step 2: Create sparse tensor and densify to get per-pixel counts try: counts_sparse = torch.sparse_coo_tensor( padded_positions, torch.ones(padded_positions.shape[1], dtype=torch.int32, device=pixel_positions.device), size=hist_shape ) histogram = counts_sparse.to_dense() except Exception as e: logger.error("Failed to build dense histogram: %s", e) raise # Step 3: Find peaks via 5x5 max-pooling # k = 5 # pooled = F.max_pool2d( # histogram.float().unsqueeze(0).unsqueeze(1), # kernel_size=k, # stride=1, # padding=k // 2 # ).squeeze() pooled = self.__max_pool_nd( histogram.unsqueeze(0), kernel_size=5 ).squeeze() # Seeds are positions where histogram equals local max and count > threshold seed_positions = torch.nonzero((histogram - pooled == 0) * (histogram > 10)) if seed_positions.shape[0] == 0: logger.warning("No seeds found: returning empty mask") return np.zeros(original_shape, dtype=np.uint16) # Sort seeds by ascending count to process small peaks first seed_counts = histogram[tuple(seed_positions.T)] order = torch.argsort(seed_counts) seed_positions = seed_positions[order] del pooled, counts_sparse # Step 4: Extract local patches and perform region growing num_seeds = seed_positions.shape[0] # Tensor to hold local patches patches = torch.zeros((num_seeds, 11, 11), device=pixel_positions.device) for idx in range(num_seeds): coords = seed_positions[idx] slices = tuple(slice(c - 5, c + 6) for c in coords) patches[idx] = histogram[slices] del histogram # Initialize seed mask (center pixel of each patch) seed_masks = torch.zeros_like(patches, device=pixel_positions.device) seed_masks[:, 5, 5] = 1 # Iterative dilation and thresholding for _ in range(5): # seed_masks = F.max_pool2d( # seed_masks.float().unsqueeze(0), # kernel_size=3, # stride=1, # padding=1 # ).squeeze(0).int() seed_masks = self.__max_pool_nd(seed_masks, kernel_size=3) seed_masks *= (patches > 2) # Compute final mask coordinates final_coords = [] for idx in range(num_seeds): coords_local = torch.nonzero(seed_masks[idx]) # Shift back to global positions coords_global = coords_local + seed_positions[idx] - 5 final_coords.append(tuple(coords_global.T)) # Step 5: Paint masks into padded volume dtype = torch.int32 if num_seeds < 2**16 else torch.int64 mask_padded = torch.zeros(hist_shape, dtype=dtype, device=pixel_positions.device) for label_idx, coords in enumerate(final_coords, start=1): mask_padded[coords] = label_idx # Extract only the padded positions that correspond to original pixels mask_values = mask_padded[tuple(padded_positions)] mask_values = mask_values.cpu().numpy() # Step 6: Map to original image and remove oversized masks mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32) mask_final[tuple(valid_indices)] = mask_values # Prune masks that are too large labels, counts = fastremap.unique(mask_final, return_counts=True) total_pixels = np.prod(original_shape) oversized = labels[counts > (total_pixels * max_size_fraction)] if oversized.size > 0: mask_final = fastremap.mask(mask_final, oversized) fastremap.renumber(mask_final, in_place=True) return mask_final def __max_pool1d( self, input_tensor: Tensor, kernel_size: int = 5, axis: int = 1, output_tensor: Optional[Tensor] = None ) -> Tensor: """ Memory-efficient 1D max pooling along a specified axis using in-place updates. Requires: - stride = 1 - padding = kernel_size // 2 - odd kernel_size >= 3 Args: input_tensor (Tensor): Source tensor for pooling. kernel_size (int): Size of the pooling window (must be odd and >= 3). axis (int): Axis along which to compute 1D max pooling. output_tensor (Optional[Tensor]): Tensor to store the result. If None, a clone of input_tensor is used. Returns: Tensor: The pooled tensor, same shape as input_tensor. """ # Initialize or copy data into the output tensor if output_tensor is None: output = input_tensor.clone() else: output = output_tensor output.copy_(input_tensor) # Number of elements along the chosen axis and half-window size dimension_size = input_tensor.shape[axis] half_window = kernel_size // 2 # Slide window offsets from -half_window to +half_window for offset in range(-half_window, half_window + 1): # Compute slice indices depending on axis if axis == 1: target_slice = output[:, max(-offset, 0): min(dimension_size - offset, dimension_size)] source_slice = input_tensor[:, max(offset, 0): min(dimension_size + offset, dimension_size)] elif axis == 2: target_slice = output[:, :, max(-offset, 0): min(dimension_size - offset, dimension_size)] source_slice = input_tensor[:, :, max(offset, 0): min(dimension_size + offset, dimension_size)] elif axis == 3: target_slice = output[:, :, :, max(-offset, 0): min(dimension_size - offset, dimension_size)] source_slice = input_tensor[:, :, :, max(offset, 0): min(dimension_size + offset, dimension_size)] else: raise ValueError(f"Unsupported axis {axis} for 1D pooling") # In-place element-wise maximum torch.maximum(target_slice, source_slice, out=target_slice) return output def __max_pool_nd( self, input_tensor: Tensor, kernel_size: int = 5 ) -> Tensor: """ Memory-efficient N-dimensional max pooling for 2D or 3D spatial data. Applies 1D max pooling sequentially over each spatial axis. Args: input_tensor (Tensor): Input tensor with shape (batch_size, dim1, dim2, ..., dimN). kernel_size (int): Size of the pooling window (must be odd and >= 3). Returns: Tensor: The pooled tensor, same shape as input_tensor. """ # Determine number of spatial dimensions (excluding batch axis) num_spatial_dims = input_tensor.ndim - 1 # First pass: pool along axis=1 pooled = self.__max_pool1d(input_tensor, kernel_size=kernel_size, axis=1) # Second pass: pool along axis=2 pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=2) # If 3D data, apply a third pass along axis=3 if num_spatial_dims == 3: pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=3) elif num_spatial_dims != 2: raise ValueError("max_pool_nd only supports 2D or 3D spatial data") return pooled def __remove_inconsistent_flow_masks( self, mask: np.ndarray, flow_network: np.ndarray, error_threshold: float = 0.4 ) -> np.ndarray: """ Remove labeled masks that have inconsistent optical flows compared to network-predicted flows. This performs a quality control step by computing flows from the provided masks and comparing them to the flows predicted by the network. Masks with a mean squared flow error above `error_threshold` are discarded (set to 0). Args: mask (np.ndarray): Integer mask array with shape [H, W]. Values: 0 = no mask; 1,2,... = mask labels. flow_network (np.ndarray): Float array of network-predicted flows with shape [2, H, W]. error_threshold (float): Maximum allowed mean squared flow error per mask label. Defaults to 0.4. Returns: np.ndarray: The input mask with inconsistent masks removed (labels set to 0). Raises: MemoryError: If the mask size exceeds available GPU memory. """ # If mask is very large and running on CUDA, check memory num_pixels = mask.size if num_pixels > 10000 * 10000 and self._device.type == 'cuda': # Clear unused GPU cache torch.cuda.empty_cache() # Determine PyTorch version major, minor = map(int, torch.__version__.split('.')[:2]) # Determine current CUDA device index device_index = ( self._device.index if hasattr(self._device, 'index') else torch.cuda.current_device() ) # Get free and total memory if major == 1 and minor < 10: total_mem = torch.cuda.get_device_properties(device_index).total_memory used_mem = torch.cuda.memory_allocated(device_index) free_mem = total_mem - used_mem else: free_mem, total_mem = torch.cuda.mem_get_info(device_index) # Estimate required memory for mask-based flow computation # Assume float32 per pixel required_bytes = num_pixels * np.dtype(np.float32).itemsize if required_bytes > free_mem: logger.error( 'Image too large for GPU memory in flow QC step (required: %d B, available: %d B)', required_bytes, free_mem ) raise MemoryError('Insufficient GPU memory for flow QC computation') # Compute flow errors per mask label flow_errors, _ = self.__compute_flow_error(mask, flow_network) # Identify labels with error above threshold bad_labels = np.nonzero(flow_errors > error_threshold)[0] + 1 # Remove bad masks by setting their label to 0 mask[np.isin(mask, bad_labels)] = 0 return mask def __compute_flow_error( self, mask: np.ndarray, flow_network: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """ Compute mean squared error between network-predicted flows and flows derived from masks. Args: mask (np.ndarray): Integer masks, shape must match flow_network spatial dims. flow_network (np.ndarray): Network predicted flows of shape [axis, ...]. Returns: Tuple[np.ndarray, np.ndarray]: - flow_errors: 1D array (length = max label) of mean squared error per label. - computed_flows: Array of flows derived from the mask, same shape as flow_network. Raises: ValueError: If the spatial dimensions of `mask_array` and `flow_network` do not match. """ # Ensure mask and flow shapes match if flow_network.shape[1:] != mask.shape: logger.error( 'Shape mismatch: network flow shape %s vs mask shape %s', flow_network.shape[1:], mask.shape ) raise ValueError('Network flow and mask shapes must match') # Compute flows from mask labels (user-provided function) computed_flows = self.__compute_flow_from_mask(mask[None, ...]) # Prepare array for errors (one value per mask label) num_labels = int(mask.max()) flow_errors = np.zeros(num_labels, dtype=float) # Accumulate mean squared error over each flow axis for axis_index in range(computed_flows.shape[0]): # MSE per label: mean((computed - predicted/5)^2) flow_errors += mean( (computed_flows[axis_index] - flow_network[axis_index] / 5.0) ** 2, mask, index=np.arange(1, num_labels + 1) ) return flow_errors, computed_flows def __fill_holes_and_prune_small_masks( self, masks: np.ndarray, minimum_size: int = 15 ) -> np.ndarray: """ Fill holes in labeled masks and remove masks smaller than a given size. This function performs two steps: 1. Fills internal holes in each labeled mask using `fill_voids.fill`. 2. Discards any mask whose pixel count is below `minimum_size`. Args: masks (np.ndarray): Integer mask array of dimension 2 or 3 (shape [H, W] or [D, H, W]). Values: 0 = background; 1,2,... = mask labels. minimum_size (int): Minimum number of pixels required to keep a mask. Masks smaller than this will be removed. Set to -1 to skip size-based pruning. Defaults to 15. Returns: np.ndarray: Processed mask array with holes filled and small masks removed. Raises: ValueError: If `masks` is not a 2D or 3D integer array. """ # Validate input dimensions if masks.ndim not in (2, 3): msg = f"Expected 2D or 3D mask array, got {masks.ndim}D." logger.error(msg) raise ValueError(msg) # Initial pruning of too-small masks masks = self._prune_small_masks(masks, minimum_size) # Find bounding boxes for each mask label object_slices = find_objects(masks) new_label = 1 output_masks = np.zeros_like(masks, dtype=masks.dtype) # Loop over each original slice, fill holes, and assign new labels for original_label, slc in enumerate(object_slices, start=1): if slc is None: continue # Extract sub-volume or sub-image region = masks[slc] == original_label if not np.any(region): continue # Fill internal holes filled_region = fill_voids.fill(region) # Write back into output mask with sequential labels output_masks[slc][filled_region] = new_label new_label += 1 # Final pruning after hole filling output_masks = self._prune_small_masks(output_masks, minimum_size) return output_masks def _prune_small_masks( self, masks: np.ndarray, minimum_size: int ) -> np.ndarray: """ Remove labeled regions in `masks` whose pixel count is below `minimum_size`. Args: masks (np.ndarray): Integer mask array (any shape), 0=background. minimum_size (int): Minimum pixel count; labels smaller are removed. If <0, skip pruning. Returns: np.ndarray: Mask array with small labels suppressed and labels renumbered. """ if minimum_size < 0: return masks labels, counts = fastremap.unique(masks, return_counts=True) # Skip background label at index 0 non_bg_labels = labels[1:] non_bg_counts = counts[1:] # Identify labels to remove small_labels = non_bg_labels[non_bg_counts < minimum_size] if small_labels.size > 0: masks = fastremap.mask(masks, small_labels) fastremap.renumber(masks, in_place=True) return masks