import torch import random import numpy as np from numba import njit, prange import torch from torch import Tensor import torch.nn.functional as F from torch.utils.data import DataLoader import fastremap from skimage import morphology from scipy.ndimage import mean, find_objects import fill_voids from monai.data.dataset import Dataset from monai.transforms import * # type: ignore import os import glob from pprint import pformat from typing import Optional, Tuple, List, Union from config import Config from core.models import * from core.losses import * from core.optimizers import * from core.schedulers import * from core.logger import get_logger logger = get_logger() class CellSegmentator: def __init__(self, config: Config) -> None: self.__set_seed(config.dataset_config.common.seed) self.__parse_config(config) self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu") self._train_dataloader: Optional[DataLoader] = None self._valid_dataloader: Optional[DataLoader] = None self._test_dataloader: Optional[DataLoader] = None self._predict_dataloader: Optional[DataLoader] = None def create_dataloaders( self, train_transforms: Optional[Compose] = None, valid_transforms: Optional[Compose] = None, test_transforms: Optional[Compose] = None, predict_transforms: Optional[Compose] = None ) -> None: """ Creates train, validation, test, and prediction dataloaders based on dataset configuration and provided transforms. Args: train_transforms (Optional[Compose]): Transformations for training data. valid_transforms (Optional[Compose]): Transformations for validation data. test_transforms (Optional[Compose]): Transformations for testing data. predict_transforms (Optional[Compose]): Transformations for prediction data. Raises: ValueError: If required transforms are missing. RuntimeError: If critical dataset config values are missing. """ if self._dataset_setup.is_training and train_transforms is None: raise ValueError("Training mode requires 'train_transforms' to be provided.") elif not self._dataset_setup.is_training and test_transforms is None and predict_transforms is None: raise ValueError("In inference mode, at least one of 'test_transforms' or 'predict_transforms' must be provided.") if self._dataset_setup.is_training: # Training mode: handle either pre-split datasets or splitting on the fly if self._dataset_setup.training.is_split: # Validate presence of validation transforms if validation directory and size are set if ( self._dataset_setup.training.pre_split.valid_dir and self._dataset_setup.training.valid_size and valid_transforms is None ): raise ValueError("Validation transforms must be provided when using pre-split validation data.") # Use explicitly split directories train_dir = self._dataset_setup.training.pre_split.train_dir valid_dir = self._dataset_setup.training.pre_split.valid_dir test_dir = self._dataset_setup.training.pre_split.test_dir train_offset = self._dataset_setup.training.train_offset valid_offset = self._dataset_setup.training.valid_offset test_offset = self._dataset_setup.training.test_offset shuffle = False else: # Same validation for split mode with full data directory if ( self._dataset_setup.training.split.all_data_dir and self._dataset_setup.training.valid_size and valid_transforms is None ): raise ValueError("Validation transforms must be provided when splitting dataset.") # Automatically split dataset from one directory train_dir = valid_dir = test_dir = self._dataset_setup.training.split.all_data_dir number_of_images = len(os.listdir(os.path.join(train_dir, 'images'))) if number_of_images == 0: raise FileNotFoundError(f"No images found in '{train_dir}/images'") # Calculate train/valid sizes train_size = ( self._dataset_setup.training.train_size if isinstance(self._dataset_setup.training.train_size, int) else int(number_of_images * self._dataset_setup.training.train_size) ) valid_size = ( self._dataset_setup.training.valid_size if isinstance(self._dataset_setup.training.valid_size, int) else int(number_of_images * self._dataset_setup.training.valid_size) ) train_offset = self._dataset_setup.training.train_offset valid_offset = self._dataset_setup.training.valid_offset + train_size test_offset = self._dataset_setup.training.test_offset + train_size + valid_size shuffle = True # Train dataloader train_dataset = self.__get_dataset( images_dir=os.path.join(train_dir, 'images'), masks_dir=os.path.join(train_dir, 'masks'), transforms=train_transforms, # type: ignore size=self._dataset_setup.training.train_size, offset=train_offset, shuffle=shuffle ) self._train_dataloader = DataLoader(train_dataset, batch_size=self._dataset_setup.training.batch_size, shuffle=True) logger.info(f"Loaded training dataset with {len(train_dataset)} samples.") # Validation dataloader if valid_transforms is not None: if not valid_dir or not self._dataset_setup.training.valid_size: raise RuntimeError("Validation directory or size is not properly configured.") valid_dataset = self.__get_dataset( images_dir=os.path.join(valid_dir, 'images'), masks_dir=os.path.join(valid_dir, 'masks'), transforms=valid_transforms, size=self._dataset_setup.training.valid_size, offset=valid_offset, shuffle=shuffle ) self._valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded validation dataset with {len(valid_dataset)} samples.") # Test dataloader if test_transforms is not None: if not test_dir or not self._dataset_setup.training.test_size: raise RuntimeError("Test directory or size is not properly configured.") test_dataset = self.__get_dataset( images_dir=os.path.join(test_dir, 'images'), masks_dir=os.path.join(test_dir, 'masks'), transforms=test_transforms, size=self._dataset_setup.training.test_size, offset=test_offset, shuffle=shuffle ) self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded test dataset with {len(test_dataset)} samples.") # Prediction dataloader if predict_transforms is not None: if not test_dir or not self._dataset_setup.training.test_size: raise RuntimeError("Prediction directory or size is not properly configured.") predict_dataset = self.__get_dataset( images_dir=os.path.join(test_dir, 'images'), masks_dir=None, transforms=predict_transforms, size=self._dataset_setup.training.test_size, offset=test_offset, shuffle=shuffle ) self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.") else: # Inference mode (no training) test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images') test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks') if test_transforms is not None: test_dataset = self.__get_dataset( images_dir=test_images, masks_dir=test_masks, transforms=test_transforms, size=self._dataset_setup.testing.test_size, offset=self._dataset_setup.testing.test_offset, shuffle=self._dataset_setup.testing.shuffle ) self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded test dataset with {len(test_dataset)} samples.") if predict_transforms is not None: predict_dataset = self.__get_dataset( images_dir=test_images, masks_dir=None, transforms=predict_transforms, size=self._dataset_setup.testing.test_size, offset=self._dataset_setup.testing.test_offset, shuffle=self._dataset_setup.testing.shuffle ) self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.") def train(self) -> None: pass def evaluate(self) -> None: pass def predict(self) -> None: pass def __parse_config(self, config: Config) -> None: """ Parses the given configuration object to initialize model, criterion, optimizer, scheduler, and dataset setup. Args: config (Config): Configuration object with model, optimizer, scheduler, criterion, and dataset setup information. """ model = config.model criterion = config.criterion optimizer = config.optimizer scheduler = config.scheduler logger.info("========== Parsed Configuration ==========") logger.info("Model Config:\n%s", pformat(model.dump(), indent=2)) if criterion: logger.info("Criterion Config:\n%s", pformat(criterion.dump(), indent=2)) if optimizer: logger.info("Optimizer Config:\n%s", pformat(optimizer.dump(), indent=2)) if scheduler: logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2)) logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2)) logger.info("==========================================") # Initialize model using the model registry self._model = ModelRegistry.get_model_class(model.name)(model.params) # Initialize loss criterion if specified self._criterion = ( CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params) if criterion is not None else None ) # Initialize optimizer if specified self._optimizer = ( OptimizerRegistry.get_optimizer_class(optimizer.name)( model_params=self._model.parameters(), optim_params=optimizer.params ) if optimizer is not None else None ) # Initialize scheduler only if both scheduler and optimizer are defined self._scheduler = ( SchedulerRegistry.get_scheduler_class(scheduler.name)( optimizer=self._optimizer.optim, params=scheduler.params ) if scheduler is not None and self._optimizer is not None and self._optimizer.optim is not None else None ) logger.info("========== Model Components Initialization ==========") logger.info("├─ Model: " + (f"{model.name}" if self._model else "Not specified")) logger.info("├─ Criterion: " + (f"{criterion.name}" if self._criterion else "Not specified")) # type: ignore logger.info("├─ Optimizer: " + (f"{optimizer.name}" if self._optimizer else "Not specified")) # type: ignore logger.info("└─ Scheduler: " + (f"{scheduler.name}" if self._scheduler else "Not specified")) # type: ignore logger.info("=====================================================") # Save dataset config self._dataset_setup = config.dataset_config common = config.dataset_config.common logger.info("========== Dataset Setup ==========") logger.info("[COMMON]") logger.info(f"├─ Seed: {common.seed}") logger.info(f"├─ Device: {common.device}") logger.info(f"└─ Predictions output dir: {common.predictions_dir}") if config.dataset_config.is_training: training = config.dataset_config.training logger.info("[MODE] Training") logger.info(f"├─ Batch size: {training.batch_size}") logger.info(f"├─ Epochs: {training.num_epochs}") logger.info(f"├─ Validation frequency: {training.val_freq}") logger.info(f"├─ Use AMP: {'yes' if training.use_amp else 'no'}") logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}") if training.is_split: logger.info(f"├─ Using pre-split directories:") logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}") logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}") logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}") else: logger.info(f"├─ Using unified dataset with splits:") logger.info(f"│ ├─ All data dir: {training.split.all_data_dir}") logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}") logger.info(f"└─ Dataset split:") logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}") logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}") logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}") else: testing = config.dataset_config.testing logger.info("[MODE] Inference") logger.info(f"├─ Test dir: {testing.test_dir}") logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})") logger.info(f"├─ Shuffle: {'yes' if testing.shuffle else 'no'}") logger.info(f"├─ Use ensemble: {'yes' if testing.use_ensemble else 'no'}") logger.info(f"└─ Pretrained weights:") logger.info(f" ├─ Single model: {testing.pretrained_weights}") logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}") logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}") logger.info("===================================") def __set_seed(self, seed: Optional[int]) -> None: """ Sets the random seed for reproducibility across Python, NumPy, and PyTorch. Args: seed (Optional[int]): Seed value. If None, no seeding is performed. """ if seed is not None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False logger.info(f"Random seed set to {seed}") else: logger.info("Seed not set (None provided)") def __get_dataset( self, images_dir: str, masks_dir: Optional[str], transforms: Compose, size: Union[int, float], offset: int, shuffle: bool ) -> Dataset: """ Loads and returns a dataset object from image and optional mask directories. Args: images_dir (str): Path to directory or glob pattern for input images. masks_dir (Optional[str]): Path to directory or glob pattern for masks. transforms (Compose): Transformations to apply to each image or pair. size (Union[int, float]): Either an integer or a fraction of the dataset. offset (int): Number of images to skip from the start. shuffle (bool): Whether to shuffle the dataset before slicing. Returns: Dataset: A dataset containing image and optional mask paths. Raises: FileNotFoundError: If no images are found. ValueError: If masks are provided but do not match image count. ValueError: If dataset is too small for requested size or offset. """ # Collect sorted list of image paths images = sorted(glob.glob(images_dir)) if not images: raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'") if masks_dir is not None: # Collect and validate sorted list of mask paths masks = sorted(glob.glob(masks_dir)) if len(images) != len(masks): raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})") # Convert float size (fraction) to absolute count size = size if isinstance(size, int) else int(size * len(images)) if size <= 0: raise ValueError(f"Size must be positive, got: {size}") if len(images) < size: raise ValueError(f"Not enough images ({len(images)}) for requested size ({size})") if len(images) < size + offset: raise ValueError(f"Offset ({offset}) + size ({size}) exceeds dataset length ({len(images)})") # Shuffle image-mask pairs if requested if shuffle: if masks_dir is not None: combined = list(zip(images, masks)) # type: ignore random.shuffle(combined) images, masks = zip(*combined) else: random.shuffle(images) # Apply offset and limit by size images = images[offset: offset + size] if masks_dir is not None: masks = masks[offset: offset + size] # type: ignore # Prepare data structure for Dataset class if masks_dir is not None: data = [ {"image": image, "mask": mask} for image, mask in zip(images, masks) # type: ignore ] else: data = [{"image": image} for image in images] return Dataset(data, transforms) def __compute_flows_from_masks( self, true_masks: Tensor ) -> np.ndarray: """ Convert segmentation masks to flow fields for training. Args: true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks. Returns: numpy array of concatenated [renumbered_true_masks, binary_masks, flow_vectors] per image. renumbered_true_masks is labels, binary_masks is cell distance transform, flow_vectors[2] is Y flow, flows[k][3] is X flow, and flow_vectors[4] is heat distribution. """ # Move to CPU numpy _true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16) batch_size = _true_masks.shape[0] # Ensure each label has a channel dimension if _true_masks.ndim == 3: # shape (batch, H, W) -> (batch, 1, H, W) _true_masks = _true_masks[:, np.newaxis, :, :] batch_size, *_ = _true_masks.shape # Renumber labels to ensure uniqueness renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0] for i in range(batch_size)]) # Compute vector flows per image flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i]) for i in range(batch_size)]) return np.concatenate((renumbered, renumbered > 0.5, flow_vectors), axis=1).astype(np.float32) def __compute_flow_from_mask( self, mask: np.ndarray ) -> np.ndarray: """ Compute normalized flow vectors from a labeled mask. Args: mask: 3D array of instance-labeled mask of shape (C, H, W). Returns: flow: Array of shape (2 * C, H, W). """ if mask.max() == 0 or np.count_nonzero(mask) <= 1: # No flow to compute logger.warning("Empty mask!") C, H, W = mask.shape return np.zeros((2*C, H, W), dtype=np.float32) # Delegate to GPU or CPU routine if self._device.type == "cuda" or self._device.type == "mps": return self.__mask_to_flow_gpu(mask) else: return self.__mask_to_flow_cpu(mask) def __mask_to_flow_gpu(self, mask: np.ndarray) -> np.ndarray: """Convert masks to flows using diffusion from center pixel. Center of masks where diffusion starts is defined by pixel closest to median within the mask. Args: masks (3D array): Labelled masks of shape (C, H, W). Returns: np.ndarray: A 3D array where for each channel the flows for each pixel are represented along the X and Y axes. """ channels, height, width = mask.shape flows = np.zeros((2*channels, height, width), np.float32) for channel in range(channels): padded_height, padded_width = height + 2, width + 2 # Pad the mask with a 1-pixel border masks_padded = torch.from_numpy(mask.astype(np.int64)).to(self._device) masks_padded = F.pad(masks_padded, (1, 1, 1, 1)) # Get coordinates of all non-zero pixels in the padded mask y, x = torch.nonzero(masks_padded, as_tuple=True) y = y.int(); x = x.int() # ensure integer type # Generate 8-connected neighbors (including center) via broadcasted offsets offsets = torch.tensor([ [ 0, 0], # center [-1, 0], # up [ 1, 0], # down [ 0, -1], # left [ 0, 1], # right [-1, -1], # up-left [-1, 1], # up-right [ 1, -1], # down-left [ 1, 1], # down-right ], dtype=torch.int32, device=self._device) # (9, 2) # coords: (N, 2) coords = torch.stack((y, x), dim=1) # neighbors: (9, N, 2) neighbors = offsets[:, None, :] + coords[None, :, :] # transpose into (2, 9, N) for the GPU kernel neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index # Build connectivity mask: True where neighbor label == center label center_labels = masks_padded[y, x][None, :] # (1, N) neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N) is_neighbor = neighbor_labels == center_labels # (9, N) # Compute object slices and pack into array for get_centers slices = find_objects(mask) slices_arr = np.array([ [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop] for i, sl in enumerate(slices) if sl is not None ], dtype=int) # Compute centers (pixel indices) and extents via the provided helper centers, ext = self.__get_mask_centers_and_extents(mask, slices_arr) # Move centers to GPU and shift by +1 for padding meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding # Determine number of diffusion iterations n_iter = 2 * ext.max() # Run the GPU diffusion kernel mu = self.__propagate_centers_gpu( neighbor_indices=neighbors, center_indices=meds_p.T, valid_neighbor_mask=is_neighbor, output_shape=(padded_height, padded_width), num_iterations=n_iter ) # Cast to float64 and normalize flow vectors mu = mu.astype(np.float64) mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60 # Remove the padding and write into final output flow_output = np.zeros((2, height, width), dtype=np.float32) ys_np = y.cpu().numpy() - 1 xs_np = x.cpu().numpy() - 1 flow_output[:, ys_np, xs_np] = mu flows[2*channel: 2*channel + 2] = flow_output return flows @staticmethod @njit(nogil=True) def __get_mask_centers_and_extents( label_map: np.ndarray, slices_arr: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """ Compute the centroids and extents of labeled regions in a 2D mask array. Args: label_map (np.ndarray): 2D array where each connected region has a unique integer label (1…K). slices_arr (np.ndarray): Array of shape (K, 5), where each row is (label_id, row_start, row_stop, col_start, col_stop). Returns: centers (np.ndarray): Integer array of shape (K, 2) with (row, col) center for each label. extents (np.ndarray): Integer array of shape (K,) giving the sum of height and width + 2 for each region. """ num_regions = slices_arr.shape[0] centers = np.zeros((num_regions, 2), dtype=np.int32) extents = np.zeros(num_regions, dtype=np.int32) for idx in prange(num_regions): # Unpack slice info label_id = slices_arr[idx, 0] row_start = slices_arr[idx, 1] row_stop = slices_arr[idx, 2] col_start = slices_arr[idx, 3] col_stop = slices_arr[idx, 4] # Extract binary submask for this label submask = (label_map[row_start:row_stop, col_start:col_stop] == label_id) # Get local coordinates of all pixels in the region ys, xs = np.nonzero(submask) # Compute the floating-point centroid within the submask y_mean = ys.mean() x_mean = xs.mean() # Find the pixel closest to the centroid by minimizing squared distance dist_sq = (ys - y_mean) ** 2 + (xs - x_mean) ** 2 closest_idx = dist_sq.argmin() # Convert to global coordinates center_row = ys[closest_idx] + row_start center_col = xs[closest_idx] + col_start centers[idx, 0] = center_row centers[idx, 1] = center_col # Compute extent as height + width + 2 (to include one-pixel border) height = row_stop - row_start width = col_stop - col_start extents[idx] = height + width + 2 return centers, extents def __propagate_centers_gpu( self, neighbor_indices: torch.Tensor, center_indices: torch.Tensor, valid_neighbor_mask: torch.Tensor, output_shape: Tuple[int, int], num_iterations: int = 200 ) -> np.ndarray: """ Propagates center points across a mask using GPU-based diffusion. Args: neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel. center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers. valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid. output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W). num_iterations (int, optional): Number of diffusion iterations. Defaults to 200. Returns: np.ndarray: Array of shape (2, N) with the computed flows. """ # Determine total number of elements and choose dtype accordingly total_elems = torch.prod(torch.tensor(output_shape)) if total_elems > 4e7 or self._device.type == "mps": diffusion_tensor = torch.zeros(output_shape, dtype=torch.float, device=self._device) else: diffusion_tensor = torch.zeros(output_shape, dtype=torch.double, device=self._device) # Unpack center row and column indices center_rows, center_cols = center_indices # Unpack neighbor row and column indices for 9 neighbors per pixel # Order: [0: center, 1: up, 2: down, 3: left, 4: right, # 5: up-left, 6: up-right, 7: down-left, 8: down-right] neigh_rows, neigh_cols = neighbor_indices # each of shape (9, N) # Perform diffusion iterations for _ in range(num_iterations): # Add source at each mask center diffusion_tensor[center_rows, center_cols] += 1 # Sample neighbor values for each pixel neighbor_vals = diffusion_tensor[neigh_rows, neigh_cols] # shape (9, N) # Zero out invalid neighbors neighbor_vals *= valid_neighbor_mask # Update the first neighbor (index 0) with the average of valid neighbor values diffusion_tensor[neigh_rows[0], neigh_cols[0]] = neighbor_vals.mean(dim=0) # Compute spatial gradients for 2D flow: dy and dx # Using neighbor indices: up = 1, down = 2, left = 3, right = 4 grad_samples = diffusion_tensor[ neigh_rows[[2, 1, 4, 3]], # indices [down, up, right, left] neigh_cols[[2, 1, 4, 3]] ] # shape (4, N) dy = grad_samples[0] - grad_samples[1] dx = grad_samples[2] - grad_samples[3] # Stack and convert to numpy flow field with shape (2, N) flow_field = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=0) return flow_field def __mask_to_flow_cpu(self, mask: np.ndarray) -> np.ndarray: """ Convert labeled masks to flow vectors by simulating diffusion from mask centers. Each mask's center is chosen as the pixel closest to its geometric centroid. A diffusion process is run on a padded local patch, and flows are derived as gradients (dy, dx) of the resulting density map. Args: masks (np.ndarray): 3D integer array of labels `(C x H x W)`, where 0 = background and positive integers = mask IDs. Returns: flow_field (np.ndarray): Array of shape `(2*C, H, W)` containing flow components [dy, dx] normalized per pixel. """ channels, height, width = mask.shape flows = np.zeros((2*channels, height, width), np.float32) for channel in range(channels): # Initialize flow_field with two channels: dy and dx flow_field = np.zeros((2, height, width), dtype=np.float64) # Find bounding box for each labeled mask mask_slices = find_objects(mask) # centers: List[Tuple[int, int]] = [] # Iterate over mask labels in parallel for label_idx in prange(len(mask_slices)): slc = mask_slices[label_idx] if slc is None: continue # Extract row and column slice for this mask row_slice, col_slice = slc # Add 1-pixel border around the patch patch_height = (row_slice.stop - row_slice.start) + 2 patch_width = (col_slice.stop - col_slice.start) + 2 # Get local coordinates of mask pixels within the patch local_rows, local_cols = np.nonzero( mask[row_slice, col_slice] == (label_idx + 1) ) # Shift coords by +1 for the border padding local_rows = local_rows.astype(np.int32) + 1 local_cols = local_cols.astype(np.int32) + 1 # Compute centroid and find nearest pixel as diffusion seed centroid_row = local_rows.mean() centroid_col = local_cols.mean() distances = (local_cols - centroid_col) ** 2 + (local_rows - centroid_row) ** 2 seed_index = distances.argmin() center_row = int(local_rows[seed_index]) center_col = int(local_cols[seed_index]) # Determine number of iterations total_iter = 2 * (patch_height + patch_width) # Initialize flat diffusion map for the local patch diffusion_map = np.zeros(patch_height * patch_width, dtype=np.float64) # Run diffusion from the seed center diffusion_map = self.__diffuse_from_center( diffusion_map, local_rows, local_cols, center_row, center_col, patch_width, total_iter ) # Compute flow as finite differences (gradient) on the diffusion map dy = ( diffusion_map[(local_rows + 1) * patch_width + local_cols] - diffusion_map[(local_rows - 1) * patch_width + local_cols] ) dx = ( diffusion_map[local_rows * patch_width + (local_cols + 1)] - diffusion_map[local_rows * patch_width + (local_cols - 1)] ) # Write flows back into the global flow_field array flow_field[0, row_slice.start + local_rows - 1, col_slice.start + local_cols - 1] = dy flow_field[1, row_slice.start + local_rows - 1, col_slice.start + local_cols - 1] = dx # Store center location in original image coordinates # centers.append( # (row_slice.start + center_row - 1, # col_slice.start + center_col - 1) # ) # Normalize each vector [dy,dx] by its magnitude magnitudes = np.sqrt((flow_field**2).sum(axis=0)) + 1e-60 flow_field /= magnitudes flows[2*channel: 2*channel + 2] = flow_field return flows @staticmethod @njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True) def __diffuse_from_center( diffusion_map: np.ndarray, row_coords: np.ndarray, col_coords: np.ndarray, center_row: int, center_col: int, patch_width: int, num_iterations: int ) -> np.ndarray: """ Perform diffusion of particles from a seed pixel across a local mask patch. At each iteration, one particle is added at the seed, and each mask pixel's value is updated to the average of itself and its 8-connected neighbors. Args: diffusion_map (np.ndarray): Flat array of length patch_height * patch_width. row_coords (np.ndarray): 1D array of row indices for mask pixels (local coords). col_coords (np.ndarray): 1D array of column indices for mask pixels (local coords). center_row (int): Row index of the seed point in local patch coords. center_col (int): Column index of the seed point in local patch coords. patch_width (int): Width (number of columns) in the local patch. num_iterations (int): Number of diffusion iterations to perform. Returns: np.ndarray: Updated diffusion_map after performing diffusion. """ # Compute linear indices for each mask pixel and its neighbors base_idx = row_coords * patch_width + col_coords up = (row_coords - 1) * patch_width + col_coords down = (row_coords + 1) * patch_width + col_coords left = row_coords * patch_width + (col_coords - 1) right = row_coords * patch_width + (col_coords + 1) up_left = (row_coords - 1) * patch_width + (col_coords - 1) up_right = (row_coords - 1) * patch_width + (col_coords + 1) down_left = (row_coords + 1) * patch_width + (col_coords - 1) down_right = (row_coords + 1) * patch_width + (col_coords + 1) for _ in range(num_iterations): # Inject one particle at the seed location diffusion_map[center_row * patch_width + center_col] += 1.0 # Update each mask pixel as the average over itself and neighbors diffusion_map[base_idx] = ( diffusion_map[base_idx] + diffusion_map[up] + diffusion_map[down] + diffusion_map[left] + diffusion_map[right] + diffusion_map[up_left] + diffusion_map[up_right] + diffusion_map[down_left] + diffusion_map[down_right] ) * (1.0 / 9.0) return diffusion_map def __segment_instances( self, probability_map: np.ndarray, flow: np.ndarray, prob_threshold: float = 0.0, flow_threshold: float = 0.4, num_iters: int = 200, min_object_size: int = 0 ) -> np.ndarray: """ Generate instance segmentation masks from probability and flow fields. Args: probability_map: 3D array (channels, height, width) of cell probabilities. flow: 3D array (2*channels, height, width) of forward flow vectors. prob_threshold: threshold to binarize probability_map. (Default 0.0) flow_threshold: threshold for filtering bad flow masks. (Default 0.4) num_iters: number of iterations for flow-following. (Default 200) min_object_size: minimum area to keep small instances. (Default 0) Returns: 3D array of uint16 instance labels for each channel. """ # Create a binary mask of likely cell locations probability_mask = probability_map > prob_threshold # If no cells exceed the threshold, return an empty mask if not np.any(probability_mask): logger.warning("No cell pixels found.") return np.zeros_like(probability_map, dtype=np.uint16) # Prepare output array for instance labels labeled_instances = np.zeros_like(probability_map, dtype=np.uint16) # Process each channel independently for channel_index in range(probability_mask.shape[0]): # Extract flow vectors for this channel (two components per channel) channel_flow_vectors = flow[2 * channel_index : 2 * channel_index + 2] # Extract binary mask for this channel channel_mask = probability_mask[channel_index] nonzero_coords = np.stack(np.nonzero(channel_mask)) # Follow the flow vectors to generate coordinate mappings flow_coordinates = self.__follow_flows( flow_field=channel_flow_vectors * channel_mask / 5.0, initial_coords=nonzero_coords, num_iters=num_iters ) # If flow following fails, leave this channel empty if flow_coordinates is None: labeled_instances[channel_index] = np.zeros( probability_map.shape[1:], dtype=np.uint16 ) continue if not torch.is_tensor(flow_coordinates): flow_coordinates = torch.from_numpy( flow_coordinates).to(self._device, dtype=torch.int32) else: flow_coordinates = flow_coordinates.int() # Obtain preliminary instance masks by clustering the coordinates channel_instances_mask = self.__get_mask( pixel_positions=flow_coordinates, valid_indices=nonzero_coords, original_shape=probability_map.shape[1:] ) # Filter out bad flow-derived instances if requested if channel_instances_mask.max() > 0 and flow_threshold > 0: channel_instances_mask = self.__remove_inconsistent_flow_masks( mask=channel_instances_mask, flow_network=channel_flow_vectors, error_threshold=flow_threshold ) # Remove small objects or holes below the minimum size if min_object_size > 0: # channel_instances_mask = morphology.remove_small_holes( # channel_instances_mask, area_threshold=min_object_size # ) # channel_instances_mask = morphology.remove_small_objects( # channel_instances_mask, min_size=min_object_size # ) channel_instances_mask = self.__fill_holes_and_prune_small_masks( channel_instances_mask, minimum_size=min_object_size ) labeled_instances[channel_index] = channel_instances_mask else: # No valid instances found, leave the channel empty labeled_instances[channel_index] = np.zeros( probability_map.shape[1:], dtype=np.uint16 ) return labeled_instances def __follow_flows( self, flow_field: np.ndarray, initial_coords: np.ndarray, num_iters: int = 200 ) -> Union[np.ndarray, torch.Tensor]: """ Trace pixel positions through a flow field via iterative interpolation. Args: flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors. initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions. num_iters (int): Number of integration steps. Returns: np.ndarray or torch.Tensor: Final (y, x) positions of each point. """ dims = 2 # Extract spatial dimensions height, width = flow_field.shape[1:] # Choose GPU/MPS path if available if self._device.type in ("cuda", "mps"): # Prepare point tensor: shape [1, 1, num_points, 2] pts = torch.zeros((1, 1, initial_coords.shape[1], dims), dtype=torch.float32, device=self._device) # Prepare flow volume: shape [1, 2, height, width] flow_vol = torch.zeros((1, dims, height, width), dtype=torch.float32, device=self._device) # Load initial positions and flow into tensors (flip order for grid_sample) # dim 0 = x # dim 1 = y for i in range(dims): pts[0, 0, :, dims - i - 1] = ( torch.from_numpy(initial_coords[i]) .to(self._device, torch.float32) ) flow_vol[0, dims - i - 1] = ( torch.from_numpy(flow_field[i]) .to(self._device, torch.float32) ) # Prepare normalization factors for x and y (max index) max_indices = torch.tensor([width - 1, height - 1], dtype=torch.float32, device=self._device) # Reshape for broadcasting to point tensor dims max_idx_pt = max_indices.view(1, 1, 1, dims) # Reshape for broadcasting to flow volume dims max_idx_flow = max_indices.view(1, dims, 1, 1) # Normalize flow values to [-1, 1] range flow_vol = (flow_vol * 2) / max_idx_flow # Normalize points to [-1, 1] pts = (pts / max_idx_pt) * 2 - 1 # Iterate: sample flow and update points for _ in range(num_iters): sampled = torch.nn.functional.grid_sample( flow_vol, pts, align_corners=False ) # Update each coordinate and clamp to valid range for i in range(dims): pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0) # Denormalize back to original pixel coordinates pts = (pts + 1) * 0.5 * max_idx_pt # Swap channels back to (y, x) and flatten final_pts = pts[..., [1, 0]].squeeze() # Convert from (num_points, 2) to (2, num_points) return final_pts.T if final_pts.ndim > 1 else final_pts.unsqueeze(0).T # CPU fallback using numpy and scipy current_pos = initial_coords.copy().astype(np.float32) temp_delta = np.zeros_like(current_pos, dtype=np.float32) for _ in range(num_iters): # Interpolate flow at current positions self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta) # Update positions and clamp to image bounds current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1) current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1) return current_pos @staticmethod @njit([ "(int16[:,:,:], float32[:], float32[:], float32[:,:])", "(float32[:,:,:], float32[:], float32[:], float32[:,:])" ], cache=True) def __map_coordinates( image_data: np.ndarray, y_coords: np.ndarray, x_coords: np.ndarray, output: np.ndarray ) -> None: """ Perform in-place bilinear interpolation on an image volume. Args: image_data (np.ndarray): Input volume with shape (C, H, W). y_coords (np.ndarray): Array of new y positions (num_points). x_coords (np.ndarray): Array of new x positions (num_points). output (np.ndarray): Output array of shape (C, num_points) to fill. Returns: None. Results written directly into `output`. """ channels, height, width = image_data.shape # Compute integer (floor) and fractional parts for coords y_floor = y_coords.astype(np.int32) x_floor = x_coords.astype(np.int32) y_frac = y_coords - y_floor x_frac = x_coords - x_floor # Loop over each sample point for idx in range(y_floor.shape[0]): # Clamp base indices to valid range y0 = min(max(y_floor[idx], 0), height - 1) x0 = min(max(x_floor[idx], 0), width - 1) y1 = min(y0 + 1, height - 1) x1 = min(x0 + 1, width - 1) wy = y_frac[idx] wx = x_frac[idx] # Interpolate per channel for c in range(channels): v00 = np.float32(image_data[c, y0, x0]) v10 = np.float32(image_data[c, y0, x1]) v01 = np.float32(image_data[c, y1, x0]) v11 = np.float32(image_data[c, y1, x1]) # Bilinear interpolation formula output[c, idx] = ( v00 * (1 - wy) * (1 - wx) + v10 * (1 - wy) * wx + v01 * wy * (1 - wx) + v11 * wy * wx ) def __get_mask( self, pixel_positions: torch.Tensor, valid_indices: np.ndarray, original_shape: Tuple[int, ...], pad_radius: int = 20, max_size_fraction: float = 0.4 ) -> np.ndarray: """ Generate labeled masks by clustering pixel trajectories via histogram peaks and region growing. This function executes the following steps: 1. Pads and clamps pixel final positions to avoid border effects. 2. Builds a dense histogram of pixel counts over spatial bins. 3. Identifies local maxima in the histogram as seed points. 4. Extracts local patches around each seed and grows regions by iteratively adding neighbors that exceed an intensity threshold. 5. Maps grown patches back to original image indices. 6. Removes any masks that exceed a maximum size fraction of the image. Args: pixel_positions (torch.Tensor): Tensor of shape `[2, N_pixels]`, dtype=int, containing final pixel coordinates after dynamics for each dimension. valid_indices (np.ndarray): Integer array of shape `[2, N_pixels]` giving indices of pixels in the original image grid. original_shape (tuple of ints): Spatial dimensions of the original image, e.g. (H, W). pad_radius (int): Number of zero-padding pixels added on each side of the histogram. Defaults to 20. max_size_fraction (float): If any mask has a pixel count > max_size_fraction * total_pixels, it will be removed. Defaults to 0.4. Returns: np.ndarray: Integer mask array of shape `original_shape` with labels 0 (background) and 1..M. Raises: ValueError: If input dimensions are inconsistent or pixel_positions shape is invalid. """ # Validate inputs ndim = len(original_shape) if pixel_positions.ndim != 2 or pixel_positions.size(0) != ndim: msg = f"pixel_positions must be shape [{ndim}, N], got {tuple(pixel_positions.shape)}" logger.error(msg) raise ValueError(msg) if pad_radius < 0: msg = f"pad_radius must be non-negative, got {pad_radius}" logger.error(msg) raise ValueError(msg) # Step 1: Pad and clamp pixel positions padded_positions = pixel_positions.clone().to(torch.int64) + pad_radius for dim in range(ndim): max_val = original_shape[dim] + pad_radius - 1 padded_positions[dim] = torch.clamp(padded_positions[dim], min=0, max=max_val) # Build histogram dimensions hist_shape = tuple(s + 2 * pad_radius for s in original_shape) # Step 2: Create sparse tensor and densify to get per-pixel counts try: counts_sparse = torch.sparse_coo_tensor( padded_positions, torch.ones(padded_positions.shape[1], dtype=torch.int32, device=pixel_positions.device), size=hist_shape ) histogram = counts_sparse.to_dense() except Exception as e: logger.error("Failed to build dense histogram: %s", e) raise # Step 3: Find peaks via 5x5 max-pooling k = 5 pooled = F.max_pool2d( histogram.unsqueeze(0), kernel_size=k, stride=1, padding=k // 2 ).squeeze() # Seeds are positions where histogram equals local max and count > threshold seed_positions = torch.nonzero((histogram - pooled == 0) & (histogram > 10)) if seed_positions.numel() == 0: logger.warning("No seeds found: returning empty mask") return np.zeros(original_shape, dtype=np.uint16) # Sort seeds by ascending count to process small peaks first seed_counts = histogram[tuple(seed_positions.T)] order = torch.argsort(seed_counts) seed_positions = seed_positions[order] del pooled, counts_sparse # Step 4: Extract local patches and perform region growing num_seeds = seed_positions.shape[0] # Tensor to hold local patches patches = torch.zeros((num_seeds, 11, 11), device=pixel_positions.device) for idx in range(num_seeds): coords = seed_positions[idx] slices = tuple(slice(c - 5, c + 6) for c in coords) patches[idx] = histogram[slices] del histogram # Initialize seed mask (center pixel of each patch) seed_masks = torch.zeros_like(patches, device=pixel_positions.device) seed_masks[:, 5, 5] = 1 # Iterative dilation and thresholding for _ in range(5): seed_masks = F.max_pool2d( seed_masks, kernel_size=3, stride=1, padding=1 ) seed_masks = seed_masks & (patches > 2) # Compute final mask coordinates final_coords = [] for idx in range(num_seeds): coords_local = torch.nonzero(seed_masks[idx]) # Shift back to global positions coords_global = coords_local + seed_positions[idx] - 5 final_coords.append(tuple(coords_global.T)) # Step 5: Paint masks into padded volume dtype = torch.int32 if num_seeds < 2**16 else torch.int64 mask_padded = torch.zeros(hist_shape, dtype=dtype, device=pixel_positions.device) for label_idx, coords in enumerate(final_coords, start=1): mask_padded[coords] = label_idx # Extract only the padded positions that correspond to original pixels mask_values = mask_padded[tuple(padded_positions)] mask_values = mask_values.cpu().numpy() # Step 6: Map to original image and remove oversized masks mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32) mask_final[valid_indices] = mask_values # Prune masks that are too large labels, counts = fastremap.unique(mask_final, return_counts=True) total_pixels = np.prod(original_shape) oversized = labels[counts > (total_pixels * max_size_fraction)] if oversized.size > 0: mask_final = fastremap.mask(mask_final, oversized) fastremap.renumber(mask_final, in_place=True) return mask_final def __remove_inconsistent_flow_masks( self, mask: np.ndarray, flow_network: np.ndarray, error_threshold: float = 0.4 ) -> np.ndarray: """ Remove labeled masks that have inconsistent optical flows compared to network-predicted flows. This performs a quality control step by computing flows from the provided masks and comparing them to the flows predicted by the network. Masks with a mean squared flow error above `error_threshold` are discarded (set to 0). Args: mask (np.ndarray): Integer mask array with shape [H, W]. Values: 0 = no mask; 1,2,... = mask labels. flow_network (np.ndarray): Float array of network-predicted flows with shape [2, H, W]. error_threshold (float): Maximum allowed mean squared flow error per mask label. Defaults to 0.4. Returns: np.ndarray: The input mask with inconsistent masks removed (labels set to 0). Raises: MemoryError: If the mask size exceeds available GPU memory. """ # If mask is very large and running on CUDA, check memory num_pixels = mask.size if ( num_pixels > 10000 * 10000 and self._device.type == 'cuda' ): # Clear unused GPU cache torch.cuda.empty_cache() # Determine PyTorch version major, minor = map(int, torch.__version__.split('.')[:2]) # Determine current CUDA device index device_index = ( self._device.index if hasattr(self._device, 'index') else torch.cuda.current_device() ) # Get free and total memory if major == 1 and minor < 10: total_mem = torch.cuda.get_device_properties(device_index).total_memory used_mem = torch.cuda.memory_allocated(device_index) free_mem = total_mem - used_mem else: free_mem, total_mem = torch.cuda.mem_get_info(device_index) # Estimate required memory for mask-based flow computation # Assume float32 per pixel required_bytes = num_pixels * np.dtype(np.float32).itemsize if required_bytes > free_mem: logger.error( 'Image too large for GPU memory in flow QC step (required: %d B, available: %d B)', required_bytes, free_mem ) raise MemoryError('Insufficient GPU memory for flow QC computation') # Compute flow errors per mask label flow_errors, _ = self.__compute_flow_error(mask, flow_network) # Identify labels with error above threshold bad_labels = np.nonzero(flow_errors > error_threshold)[0] + 1 # Remove bad masks by setting their label to 0 mask[np.isin(mask, bad_labels)] = 0 return mask def __compute_flow_error( self, mask: np.ndarray, flow_network: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """ Compute mean squared error between network-predicted flows and flows derived from masks. Args: mask (np.ndarray): Integer masks, shape must match flow_network spatial dims. flow_network (np.ndarray): Network predicted flows of shape [axis, ...]. Returns: Tuple[np.ndarray, np.ndarray]: - flow_errors: 1D array (length = max label) of mean squared error per label. - computed_flows: Array of flows derived from the mask, same shape as flow_network. Raises: ValueError: If the spatial dimensions of `mask_array` and `flow_network` do not match. """ # Ensure mask and flow shapes match if flow_network.shape[1:] != mask.shape: logger.error( 'Shape mismatch: network flow shape %s vs mask shape %s', flow_network.shape[1:], mask.shape ) raise ValueError('Network flow and mask shapes must match') # Compute flows from mask labels (user-provided function) computed_flows = self.__compute_flow_from_mask(mask[None, ...]) # Prepare array for errors (one value per mask label) num_labels = int(mask.max()) flow_errors = np.zeros(num_labels, dtype=float) # Accumulate mean squared error over each flow axis for axis_index in range(computed_flows.shape[0]): # MSE per label: mean((computed - predicted/5)^2) flow_errors += mean( (computed_flows[axis_index] - flow_network[axis_index] / 5.0) ** 2, mask, index=np.arange(1, num_labels + 1) ) return flow_errors, computed_flows def __fill_holes_and_prune_small_masks( self, masks: np.ndarray, minimum_size: int = 15 ) -> np.ndarray: """ Fill holes in labeled masks and remove masks smaller than a given size. This function performs two steps: 1. Fills internal holes in each labeled mask using `fill_voids.fill`. 2. Discards any mask whose pixel count is below `minimum_size`. Args: masks (np.ndarray): Integer mask array of dimension 2 or 3 (shape [H, W] or [D, H, W]). Values: 0 = background; 1,2,... = mask labels. minimum_size (int): Minimum number of pixels required to keep a mask. Masks smaller than this will be removed. Set to -1 to skip size-based pruning. Defaults to 15. Returns: np.ndarray: Processed mask array with holes filled and small masks removed. Raises: ValueError: If `masks` is not a 2D or 3D integer array. """ # Validate input dimensions if masks.ndim not in (2, 3): msg = f"Expected 2D or 3D mask array, got {masks.ndim}D." logger.error(msg) raise ValueError(msg) # Optionally remove masks smaller than minimum_size if minimum_size >= 0: # Compute label counts (skipping background at index 0) labels, counts = fastremap.unique(masks, return_counts=True) # Identify labels to remove: those with count < minimum_size small_labels = labels[counts < minimum_size] if small_labels.size > 0: masks = fastremap.mask(masks, small_labels) fastremap.renumber(masks, in_place=True) # Find bounding boxes for each mask label object_slices = find_objects(masks) new_label = 1 output_masks = np.zeros_like(masks, dtype=masks.dtype) # Loop over each original slice, fill holes, and assign new labels for original_label, slc in enumerate(object_slices, start=1): if slc is None: continue # Extract sub-volume or sub-image region = masks[slc] == original_label if not np.any(region): continue # Fill internal holes filled_region = fill_voids.fill(region) # Write back into output mask with sequential labels output_masks[slc][filled_region] = new_label new_label += 1 # Final pruning of small masks after filling (optional) if minimum_size >= 0: labels, counts = fastremap.unique(output_masks, return_counts=True) small_labels = labels[counts < minimum_size] if small_labels.size > 0: output_masks = fastremap.mask(output_masks, small_labels) fastremap.renumber(output_masks, in_place=True) return output_masks