From 60aebe5921fd8db60a1d49c6bd97942adf55fdaf Mon Sep 17 00:00:00 2001 From: laynholt Date: Mon, 28 Apr 2025 17:43:56 +0000 Subject: [PATCH] added methods for calculating gradient flows --- core/models/model_v.py | 13 +- core/segmentator.py | 1136 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 1092 insertions(+), 57 deletions(-) diff --git a/core/models/model_v.py b/core/models/model_v.py index 6474dc1..a525daf 100644 --- a/core/models/model_v.py +++ b/core/models/model_v.py @@ -53,6 +53,13 @@ class ModelV(MAnet): self.gradflow_head = DeepSegmentationHead( in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes ) + + # self.gradflow_head = nn.ModuleList([ + # DeepSegmentationHead( + # in_channels=params.decoder_channels[-1], out_channels=2 + # ) for _ in range(params.out_classes) + # ]) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the network""" @@ -66,9 +73,13 @@ class ModelV(MAnet): # Generate masks for cell probability and gradient flows cellprob_mask = self.cellprob_head(decoder_output) gradflow_mask = self.gradflow_head(decoder_output) + + # gradflow_masks = torch.cat( + # [head(decoder_output) for head in self.flow_heads], dim=1 # [B, 2*C, H, W] + # ) # Concatenate the masks for output - masks = torch.cat([gradflow_mask, cellprob_mask], dim=1) + masks = torch.cat((gradflow_mask, cellprob_mask), dim=1) return masks diff --git a/core/segmentator.py b/core/segmentator.py index 4882974..461b98e 100644 --- a/core/segmentator.py +++ b/core/segmentator.py @@ -1,14 +1,26 @@ import torch import random import numpy as np +from numba import njit, prange + +import torch +from torch import Tensor +import torch.nn.functional as F +from torch.utils.data import DataLoader + +import fastremap + +from skimage import morphology +from scipy.ndimage import mean, find_objects +import fill_voids + from monai.data.dataset import Dataset from monai.transforms import * # type: ignore -from torch.utils.data import DataLoader import os import glob from pprint import pformat -from typing import Optional, Union +from typing import Optional, Tuple, List, Union from config import Config from core.models import * @@ -27,6 +39,8 @@ class CellSegmentator: self.__set_seed(config.dataset_config.common.seed) self.__parse_config(config) + self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu") + self._train_dataloader: Optional[DataLoader] = None self._valid_dataloader: Optional[DataLoader] = None self._test_dataloader: Optional[DataLoader] = None @@ -227,21 +241,19 @@ class CellSegmentator: optimizer = config.optimizer scheduler = config.scheduler - # Log the full configuration dictionary - full_config_dict = { - "model": model.dump(), - "criterion": criterion.dump() if criterion else None, - "optimizer": optimizer.dump() if optimizer else None, - "scheduler": scheduler.dump() if scheduler else None, - "dataset_config": config.dataset_config.model_dump() - } logger.info("========== Parsed Configuration ==========") - logger.info(pformat(full_config_dict, width=120)) + logger.info("Model Config:\n%s", pformat(model.dump(), indent=2)) + if criterion: + logger.info("Criterion Config:\n%s", pformat(criterion.dump(), indent=2)) + if optimizer: + logger.info("Optimizer Config:\n%s", pformat(optimizer.dump(), indent=2)) + if scheduler: + logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2)) + logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2)) logger.info("==========================================") # Initialize model using the model registry self._model = ModelRegistry.get_model_class(model.name)(model.params) - logger.info(f"Initialized model: {model.name}") # Initialize loss criterion if specified self._criterion = ( @@ -249,10 +261,6 @@ class CellSegmentator: if criterion is not None else None ) - if self._criterion is not None and criterion is not None: - logger.info(f"Initialized criterion: {criterion.name}") - else: - logger.info("Criterion: not specified") # Initialize optimizer if specified self._optimizer = ( @@ -263,10 +271,6 @@ class CellSegmentator: if optimizer is not None else None ) - if self._optimizer is not None and optimizer is not None: - logger.info(f"Initialized optimizer: {optimizer.name}") - else: - logger.info("Optimizer: not specified") # Initialize scheduler only if both scheduler and optimizer are defined self._scheduler = ( @@ -277,55 +281,62 @@ class CellSegmentator: if scheduler is not None and self._optimizer is not None and self._optimizer.optim is not None else None ) - if self._scheduler is not None and scheduler is not None: - logger.info(f"Initialized scheduler: {scheduler.name}") - else: - logger.info("Scheduler: not specified") + + logger.info("========== Model Components Initialization ==========") + logger.info("├─ Model: " + (f"{model.name}" if self._model else "Not specified")) + logger.info("├─ Criterion: " + (f"{criterion.name}" if self._criterion else "Not specified")) # type: ignore + logger.info("├─ Optimizer: " + (f"{optimizer.name}" if self._optimizer else "Not specified")) # type: ignore + logger.info("└─ Scheduler: " + (f"{scheduler.name}" if self._scheduler else "Not specified")) # type: ignore + logger.info("=====================================================") + # Save dataset config self._dataset_setup = config.dataset_config - logger.info("Dataset setup loaded") common = config.dataset_config.common - logger.info(f"Seed: {common.seed}") - logger.info(f"Device: {common.device}") - logger.info(f"Predictions output dir: {common.predictions_dir}") + + logger.info("========== Dataset Setup ==========") + logger.info("[COMMON]") + logger.info(f"├─ Seed: {common.seed}") + logger.info(f"├─ Device: {common.device}") + logger.info(f"└─ Predictions output dir: {common.predictions_dir}") if config.dataset_config.is_training: training = config.dataset_config.training - logger.info("Mode: Training") - logger.info(f" Batch size: {training.batch_size}") - logger.info(f" Epochs: {training.num_epochs}") - logger.info(f" Validation frequency: {training.val_freq}") - logger.info(f" Use AMP: {'yes' if training.use_amp else 'no'}") - logger.info(f" Pretrained weights: {training.pretrained_weights}") + logger.info("[MODE] Training") + logger.info(f"├─ Batch size: {training.batch_size}") + logger.info(f"├─ Epochs: {training.num_epochs}") + logger.info(f"├─ Validation frequency: {training.val_freq}") + logger.info(f"├─ Use AMP: {'yes' if training.use_amp else 'no'}") + logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}") if training.is_split: - logger.info(" Using pre-split directories:") - logger.info(f" Train dir: {training.pre_split.train_dir}") - logger.info(f" Valid dir: {training.pre_split.valid_dir}") - logger.info(f" Test dir: {training.pre_split.test_dir}") + logger.info(f"├─ Using pre-split directories:") + logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}") + logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}") + logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}") else: - logger.info(" Using unified dataset with splits:") - logger.info(f" All data dir: {training.split.all_data_dir}") - logger.info(f" Shuffle: {'yes' if training.split.shuffle else 'no'}") - - logger.info(" Dataset split:") - logger.info(f" Train size: {training.train_size}, offset: {training.train_offset}") - logger.info(f" Valid size: {training.valid_size}, offset: {training.valid_offset}") - logger.info(f" Test size: {training.test_size}, offset: {training.test_offset}") - + logger.info(f"├─ Using unified dataset with splits:") + logger.info(f"│ ├─ All data dir: {training.split.all_data_dir}") + logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}") + + logger.info(f"└─ Dataset split:") + logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}") + logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}") + logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}") + else: testing = config.dataset_config.testing - logger.info("Mode: Inference") - logger.info(f" Test dir: {testing.test_dir}") - logger.info(f" Test size: {testing.test_size} (offset: {testing.test_offset})") - logger.info(f" Shuffle: {'yes' if testing.shuffle else 'no'}") - logger.info(f" Use ensemble: {'yes' if testing.use_ensemble else 'no'}") - logger.info(f" Pretrained weights:") - logger.info(f" Single model: {testing.pretrained_weights}") - logger.info(f" Ensemble model 1: {testing.ensemble_pretrained_weights1}") - logger.info(f" Ensemble model 2: {testing.ensemble_pretrained_weights2}") + logger.info("[MODE] Inference") + logger.info(f"├─ Test dir: {testing.test_dir}") + logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})") + logger.info(f"├─ Shuffle: {'yes' if testing.shuffle else 'no'}") + logger.info(f"├─ Use ensemble: {'yes' if testing.use_ensemble else 'no'}") + logger.info(f"└─ Pretrained weights:") + logger.info(f" ├─ Single model: {testing.pretrained_weights}") + logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}") + logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}") + logger.info("===================================") def __set_seed(self, seed: Optional[int]) -> None: @@ -421,3 +432,1016 @@ class CellSegmentator: data = [{"image": image} for image in images] return Dataset(data, transforms) + + + def __compute_flows_from_masks( + self, + true_masks: Tensor + ) -> np.ndarray: + """ + Convert segmentation masks to flow fields for training. + + Args: + true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks. + + Returns: + numpy array of concatenated [renumbered_true_masks, binary_masks, flow_vectors] per image. + renumbered_true_masks is labels, binary_masks is cell distance transform, flow_vectors[2] is Y flow, flows[k][3] is X flow, + and flow_vectors[4] is heat distribution. + """ + # Move to CPU numpy + _true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16) + batch_size = _true_masks.shape[0] + + # Ensure each label has a channel dimension + if _true_masks.ndim == 3: + # shape (batch, H, W) -> (batch, 1, H, W) + _true_masks = _true_masks[:, np.newaxis, :, :] + + batch_size, *_ = _true_masks.shape + + # Renumber labels to ensure uniqueness + renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0] + for i in range(batch_size)]) + # Compute vector flows per image + flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i]) + for i in range(batch_size)]) + + return np.concatenate((renumbered, renumbered > 0.5, flow_vectors), axis=1).astype(np.float32) + + + def __compute_flow_from_mask( + self, + mask: np.ndarray + ) -> np.ndarray: + """ + Compute normalized flow vectors from a labeled mask. + + Args: + mask: 3D array of instance-labeled mask of shape (C, H, W). + + Returns: + flow: Array of shape (2 * C, H, W). + """ + if mask.max() == 0 or np.count_nonzero(mask) <= 1: + # No flow to compute + logger.warning("Empty mask!") + C, H, W = mask.shape + return np.zeros((2*C, H, W), dtype=np.float32) + + # Delegate to GPU or CPU routine + if self._device.type == "cuda" or self._device.type == "mps": + return self.__mask_to_flow_gpu(mask) + else: + return self.__mask_to_flow_cpu(mask) + + + def __mask_to_flow_gpu(self, mask: np.ndarray) -> np.ndarray: + """Convert masks to flows using diffusion from center pixel. + + Center of masks where diffusion starts is defined by pixel closest to median within the mask. + + Args: + masks (3D array): Labelled masks of shape (C, H, W). + + Returns: + np.ndarray: A 3D array where for each channel the flows for each pixel + are represented along the X and Y axes. + """ + + channels, height, width = mask.shape + flows = np.zeros((2*channels, height, width), np.float32) + + for channel in range(channels): + padded_height, padded_width = height + 2, width + 2 + + # Pad the mask with a 1-pixel border + masks_padded = torch.from_numpy(mask.astype(np.int64)).to(self._device) + masks_padded = F.pad(masks_padded, (1, 1, 1, 1)) + + # Get coordinates of all non-zero pixels in the padded mask + y, x = torch.nonzero(masks_padded, as_tuple=True) + y = y.int(); x = x.int() # ensure integer type + + # Generate 8-connected neighbors (including center) via broadcasted offsets + offsets = torch.tensor([ + [ 0, 0], # center + [-1, 0], # up + [ 1, 0], # down + [ 0, -1], # left + [ 0, 1], # right + [-1, -1], # up-left + [-1, 1], # up-right + [ 1, -1], # down-left + [ 1, 1], # down-right + ], dtype=torch.int32, device=self._device) # (9, 2) + + # coords: (N, 2) + coords = torch.stack((y, x), dim=1) + + # neighbors: (9, N, 2) + neighbors = offsets[:, None, :] + coords[None, :, :] + + # transpose into (2, 9, N) for the GPU kernel + neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index + + # Build connectivity mask: True where neighbor label == center label + center_labels = masks_padded[y, x][None, :] # (1, N) + neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N) + is_neighbor = neighbor_labels == center_labels # (9, N) + + # Compute object slices and pack into array for get_centers + slices = find_objects(mask) + slices_arr = np.array([ + [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop] + for i, sl in enumerate(slices) if sl is not None + ], dtype=int) + + # Compute centers (pixel indices) and extents via the provided helper + centers, ext = self.__get_mask_centers_and_extents(mask, slices_arr) + # Move centers to GPU and shift by +1 for padding + meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding + + # Determine number of diffusion iterations + n_iter = 2 * ext.max() + + # Run the GPU diffusion kernel + mu = self.__propagate_centers_gpu( + neighbor_indices=neighbors, + center_indices=meds_p.T, + valid_neighbor_mask=is_neighbor, + output_shape=(padded_height, padded_width), + num_iterations=n_iter + ) + + # Cast to float64 and normalize flow vectors + mu = mu.astype(np.float64) + mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60 + + # Remove the padding and write into final output + flow_output = np.zeros((2, height, width), dtype=np.float32) + ys_np = y.cpu().numpy() - 1 + xs_np = x.cpu().numpy() - 1 + flow_output[:, ys_np, xs_np] = mu + flows[2*channel: 2*channel + 2] = flow_output + + return flows + + + @staticmethod + @njit(nogil=True) + def __get_mask_centers_and_extents( + label_map: np.ndarray, + slices_arr: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute the centroids and extents of labeled regions in a 2D mask array. + + Args: + label_map (np.ndarray): 2D array where each connected region has a unique integer label (1…K). + slices_arr (np.ndarray): Array of shape (K, 5), where each row is + (label_id, row_start, row_stop, col_start, col_stop). + + Returns: + centers (np.ndarray): Integer array of shape (K, 2) with (row, col) center for each label. + extents (np.ndarray): Integer array of shape (K,) giving the sum of height and width + 2 for each region. + """ + num_regions = slices_arr.shape[0] + centers = np.zeros((num_regions, 2), dtype=np.int32) + extents = np.zeros(num_regions, dtype=np.int32) + + for idx in prange(num_regions): + # Unpack slice info + label_id = slices_arr[idx, 0] + row_start = slices_arr[idx, 1] + row_stop = slices_arr[idx, 2] + col_start = slices_arr[idx, 3] + col_stop = slices_arr[idx, 4] + + # Extract binary submask for this label + submask = (label_map[row_start:row_stop, col_start:col_stop] == label_id) + + # Get local coordinates of all pixels in the region + ys, xs = np.nonzero(submask) + + # Compute the floating-point centroid within the submask + y_mean = ys.mean() + x_mean = xs.mean() + + # Find the pixel closest to the centroid by minimizing squared distance + dist_sq = (ys - y_mean) ** 2 + (xs - x_mean) ** 2 + closest_idx = dist_sq.argmin() + + # Convert to global coordinates + center_row = ys[closest_idx] + row_start + center_col = xs[closest_idx] + col_start + centers[idx, 0] = center_row + centers[idx, 1] = center_col + + # Compute extent as height + width + 2 (to include one-pixel border) + height = row_stop - row_start + width = col_stop - col_start + extents[idx] = height + width + 2 + + return centers, extents + + + def __propagate_centers_gpu( + self, + neighbor_indices: torch.Tensor, + center_indices: torch.Tensor, + valid_neighbor_mask: torch.Tensor, + output_shape: Tuple[int, int], + num_iterations: int = 200 + ) -> np.ndarray: + """ + Propagates center points across a mask using GPU-based diffusion. + + Args: + neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel. + center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers. + valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid. + output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W). + num_iterations (int, optional): Number of diffusion iterations. Defaults to 200. + + Returns: + np.ndarray: Array of shape (2, N) with the computed flows. + """ + # Determine total number of elements and choose dtype accordingly + total_elems = torch.prod(torch.tensor(output_shape)) + if total_elems > 4e7 or self._device.type == "mps": + diffusion_tensor = torch.zeros(output_shape, dtype=torch.float, device=self._device) + else: + diffusion_tensor = torch.zeros(output_shape, dtype=torch.double, device=self._device) + + # Unpack center row and column indices + center_rows, center_cols = center_indices + + # Unpack neighbor row and column indices for 9 neighbors per pixel + # Order: [0: center, 1: up, 2: down, 3: left, 4: right, + # 5: up-left, 6: up-right, 7: down-left, 8: down-right] + neigh_rows, neigh_cols = neighbor_indices # each of shape (9, N) + + # Perform diffusion iterations + for _ in range(num_iterations): + # Add source at each mask center + diffusion_tensor[center_rows, center_cols] += 1 + + # Sample neighbor values for each pixel + neighbor_vals = diffusion_tensor[neigh_rows, neigh_cols] # shape (9, N) + + # Zero out invalid neighbors + neighbor_vals *= valid_neighbor_mask + + # Update the first neighbor (index 0) with the average of valid neighbor values + diffusion_tensor[neigh_rows[0], neigh_cols[0]] = neighbor_vals.mean(dim=0) + + # Compute spatial gradients for 2D flow: dy and dx + # Using neighbor indices: up = 1, down = 2, left = 3, right = 4 + grad_samples = diffusion_tensor[ + neigh_rows[[2, 1, 4, 3]], # indices [down, up, right, left] + neigh_cols[[2, 1, 4, 3]] + ] # shape (4, N) + + dy = grad_samples[0] - grad_samples[1] + dx = grad_samples[2] - grad_samples[3] + + # Stack and convert to numpy flow field with shape (2, N) + flow_field = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=0) + + return flow_field + + + def __mask_to_flow_cpu(self, mask: np.ndarray) -> np.ndarray: + """ + Convert labeled masks to flow vectors by simulating diffusion from mask centers. + + Each mask's center is chosen as the pixel closest to its geometric centroid. + A diffusion process is run on a padded local patch, and flows are derived + as gradients (dy, dx) of the resulting density map. + + Args: + masks (np.ndarray): 3D integer array of labels `(C x H x W)`, + where 0 = background and positive integers = mask IDs. + + Returns: + flow_field (np.ndarray): Array of shape `(2*C, H, W)` containing + flow components [dy, dx] normalized per pixel. + """ + channels, height, width = mask.shape + flows = np.zeros((2*channels, height, width), np.float32) + + for channel in range(channels): + # Initialize flow_field with two channels: dy and dx + flow_field = np.zeros((2, height, width), dtype=np.float64) + + # Find bounding box for each labeled mask + mask_slices = find_objects(mask) + # centers: List[Tuple[int, int]] = [] + + # Iterate over mask labels in parallel + for label_idx in prange(len(mask_slices)): + slc = mask_slices[label_idx] + if slc is None: + continue + + # Extract row and column slice for this mask + row_slice, col_slice = slc + # Add 1-pixel border around the patch + patch_height = (row_slice.stop - row_slice.start) + 2 + patch_width = (col_slice.stop - col_slice.start) + 2 + + # Get local coordinates of mask pixels within the patch + local_rows, local_cols = np.nonzero( + mask[row_slice, col_slice] == (label_idx + 1) + ) + # Shift coords by +1 for the border padding + local_rows = local_rows.astype(np.int32) + 1 + local_cols = local_cols.astype(np.int32) + 1 + + # Compute centroid and find nearest pixel as diffusion seed + centroid_row = local_rows.mean() + centroid_col = local_cols.mean() + distances = (local_cols - centroid_col) ** 2 + (local_rows - centroid_row) ** 2 + seed_index = distances.argmin() + center_row = int(local_rows[seed_index]) + center_col = int(local_cols[seed_index]) + + # Determine number of iterations + total_iter = 2 * (patch_height + patch_width) + + # Initialize flat diffusion map for the local patch + diffusion_map = np.zeros(patch_height * patch_width, dtype=np.float64) + # Run diffusion from the seed center + diffusion_map = self.__diffuse_from_center( + diffusion_map, + local_rows, + local_cols, + center_row, + center_col, + patch_width, + total_iter + ) + + # Compute flow as finite differences (gradient) on the diffusion map + dy = ( + diffusion_map[(local_rows + 1) * patch_width + local_cols] - + diffusion_map[(local_rows - 1) * patch_width + local_cols] + ) + dx = ( + diffusion_map[local_rows * patch_width + (local_cols + 1)] - + diffusion_map[local_rows * patch_width + (local_cols - 1)] + ) + + # Write flows back into the global flow_field array + flow_field[0, + row_slice.start + local_rows - 1, + col_slice.start + local_cols - 1] = dy + flow_field[1, + row_slice.start + local_rows - 1, + col_slice.start + local_cols - 1] = dx + + # Store center location in original image coordinates + # centers.append( + # (row_slice.start + center_row - 1, + # col_slice.start + center_col - 1) + # ) + + # Normalize each vector [dy,dx] by its magnitude + magnitudes = np.sqrt((flow_field**2).sum(axis=0)) + 1e-60 + flow_field /= magnitudes + + flows[2*channel: 2*channel + 2] = flow_field + + return flows + + + @staticmethod + @njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True) + def __diffuse_from_center( + diffusion_map: np.ndarray, + row_coords: np.ndarray, + col_coords: np.ndarray, + center_row: int, + center_col: int, + patch_width: int, + num_iterations: int + ) -> np.ndarray: + """ + Perform diffusion of particles from a seed pixel across a local mask patch. + + At each iteration, one particle is added at the seed, and each mask pixel's + value is updated to the average of itself and its 8-connected neighbors. + + Args: + diffusion_map (np.ndarray): Flat array of length patch_height * patch_width. + row_coords (np.ndarray): 1D array of row indices for mask pixels (local coords). + col_coords (np.ndarray): 1D array of column indices for mask pixels (local coords). + center_row (int): Row index of the seed point in local patch coords. + center_col (int): Column index of the seed point in local patch coords. + patch_width (int): Width (number of columns) in the local patch. + num_iterations (int): Number of diffusion iterations to perform. + + Returns: + np.ndarray: Updated diffusion_map after performing diffusion. + """ + # Compute linear indices for each mask pixel and its neighbors + base_idx = row_coords * patch_width + col_coords + up = (row_coords - 1) * patch_width + col_coords + down = (row_coords + 1) * patch_width + col_coords + left = row_coords * patch_width + (col_coords - 1) + right = row_coords * patch_width + (col_coords + 1) + up_left = (row_coords - 1) * patch_width + (col_coords - 1) + up_right = (row_coords - 1) * patch_width + (col_coords + 1) + down_left = (row_coords + 1) * patch_width + (col_coords - 1) + down_right = (row_coords + 1) * patch_width + (col_coords + 1) + + for _ in range(num_iterations): + # Inject one particle at the seed location + diffusion_map[center_row * patch_width + center_col] += 1.0 + + # Update each mask pixel as the average over itself and neighbors + diffusion_map[base_idx] = ( + diffusion_map[base_idx] + + diffusion_map[up] + diffusion_map[down] + + diffusion_map[left] + diffusion_map[right] + + diffusion_map[up_left] + diffusion_map[up_right] + + diffusion_map[down_left] + diffusion_map[down_right] + ) * (1.0 / 9.0) + + return diffusion_map + + + def __segment_instances( + self, + probability_map: np.ndarray, + flow: np.ndarray, + prob_threshold: float = 0.0, + flow_threshold: float = 0.4, + num_iters: int = 200, + min_object_size: int = 0 + ) -> np.ndarray: + """ + Generate instance segmentation masks from probability and flow fields. + + Args: + probability_map: 3D array (channels, height, width) of cell probabilities. + flow: 3D array (2*channels, height, width) of forward flow vectors. + prob_threshold: threshold to binarize probability_map. (Default 0.0) + flow_threshold: threshold for filtering bad flow masks. (Default 0.4) + num_iters: number of iterations for flow-following. (Default 200) + min_object_size: minimum area to keep small instances. (Default 0) + + Returns: + 3D array of uint16 instance labels for each channel. + """ + # Create a binary mask of likely cell locations + probability_mask = probability_map > prob_threshold + + # If no cells exceed the threshold, return an empty mask + if not np.any(probability_mask): + logger.warning("No cell pixels found.") + return np.zeros_like(probability_map, dtype=np.uint16) + + # Prepare output array for instance labels + labeled_instances = np.zeros_like(probability_map, dtype=np.uint16) + + # Process each channel independently + for channel_index in range(probability_mask.shape[0]): + # Extract flow vectors for this channel (two components per channel) + channel_flow_vectors = flow[2 * channel_index : 2 * channel_index + 2] + # Extract binary mask for this channel + channel_mask = probability_mask[channel_index] + + nonzero_coords = np.stack(np.nonzero(channel_mask)) + + # Follow the flow vectors to generate coordinate mappings + flow_coordinates = self.__follow_flows( + flow_field=channel_flow_vectors * channel_mask / 5.0, + initial_coords=nonzero_coords, + num_iters=num_iters + ) + # If flow following fails, leave this channel empty + if flow_coordinates is None: + labeled_instances[channel_index] = np.zeros( + probability_map.shape[1:], dtype=np.uint16 + ) + continue + + if not torch.is_tensor(flow_coordinates): + flow_coordinates = torch.from_numpy( + flow_coordinates).to(self._device, dtype=torch.int32) + else: + flow_coordinates = flow_coordinates.int() + + # Obtain preliminary instance masks by clustering the coordinates + channel_instances_mask = self.__get_mask( + pixel_positions=flow_coordinates, + valid_indices=nonzero_coords, + original_shape=probability_map.shape[1:] + ) + + # Filter out bad flow-derived instances if requested + if channel_instances_mask.max() > 0 and flow_threshold > 0: + channel_instances_mask = self.__remove_inconsistent_flow_masks( + mask=channel_instances_mask, + flow_network=channel_flow_vectors, + error_threshold=flow_threshold + ) + + # Remove small objects or holes below the minimum size + if min_object_size > 0: + # channel_instances_mask = morphology.remove_small_holes( + # channel_instances_mask, area_threshold=min_object_size + # ) + # channel_instances_mask = morphology.remove_small_objects( + # channel_instances_mask, min_size=min_object_size + # ) + channel_instances_mask = self.__fill_holes_and_prune_small_masks( + channel_instances_mask, minimum_size=min_object_size + ) + + labeled_instances[channel_index] = channel_instances_mask + else: + # No valid instances found, leave the channel empty + labeled_instances[channel_index] = np.zeros( + probability_map.shape[1:], dtype=np.uint16 + ) + + return labeled_instances + + + def __follow_flows( + self, + flow_field: np.ndarray, + initial_coords: np.ndarray, + num_iters: int = 200 + ) -> Union[np.ndarray, torch.Tensor]: + """ + Trace pixel positions through a flow field via iterative interpolation. + + Args: + flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors. + initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions. + num_iters (int): Number of integration steps. + + Returns: + np.ndarray or torch.Tensor: Final (y, x) positions of each point. + """ + dims = 2 + # Extract spatial dimensions + height, width = flow_field.shape[1:] + + # Choose GPU/MPS path if available + if self._device.type in ("cuda", "mps"): + # Prepare point tensor: shape [1, 1, num_points, 2] + pts = torch.zeros((1, 1, initial_coords.shape[1], dims), + dtype=torch.float32, device=self._device) + # Prepare flow volume: shape [1, 2, height, width] + flow_vol = torch.zeros((1, dims, height, width), + dtype=torch.float32, device=self._device) + + # Load initial positions and flow into tensors (flip order for grid_sample) + # dim 0 = x + # dim 1 = y + for i in range(dims): + pts[0, 0, :, dims - i - 1] = ( + torch.from_numpy(initial_coords[i]) + .to(self._device, torch.float32) + ) + flow_vol[0, dims - i - 1] = ( + torch.from_numpy(flow_field[i]) + .to(self._device, torch.float32) + ) + + # Prepare normalization factors for x and y (max index) + max_indices = torch.tensor([width - 1, height - 1], + dtype=torch.float32, device=self._device) + # Reshape for broadcasting to point tensor dims + max_idx_pt = max_indices.view(1, 1, 1, dims) + # Reshape for broadcasting to flow volume dims + max_idx_flow = max_indices.view(1, dims, 1, 1) + + # Normalize flow values to [-1, 1] range + flow_vol = (flow_vol * 2) / max_idx_flow + # Normalize points to [-1, 1] + pts = (pts / max_idx_pt) * 2 - 1 + + # Iterate: sample flow and update points + for _ in range(num_iters): + sampled = torch.nn.functional.grid_sample( + flow_vol, pts, align_corners=False + ) + # Update each coordinate and clamp to valid range + for i in range(dims): + pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0) + + # Denormalize back to original pixel coordinates + pts = (pts + 1) * 0.5 * max_idx_pt + # Swap channels back to (y, x) and flatten + final_pts = pts[..., [1, 0]].squeeze() + # Convert from (num_points, 2) to (2, num_points) + return final_pts.T if final_pts.ndim > 1 else final_pts.unsqueeze(0).T + + # CPU fallback using numpy and scipy + current_pos = initial_coords.copy().astype(np.float32) + temp_delta = np.zeros_like(current_pos, dtype=np.float32) + + for _ in range(num_iters): + # Interpolate flow at current positions + self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta) + # Update positions and clamp to image bounds + current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1) + current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1) + + return current_pos + + + @staticmethod + @njit([ + "(int16[:,:,:], float32[:], float32[:], float32[:,:])", + "(float32[:,:,:], float32[:], float32[:], float32[:,:])" + ], cache=True) + def __map_coordinates( + image_data: np.ndarray, + y_coords: np.ndarray, + x_coords: np.ndarray, + output: np.ndarray + ) -> None: + """ + Perform in-place bilinear interpolation on an image volume. + + Args: + image_data (np.ndarray): Input volume with shape (C, H, W). + y_coords (np.ndarray): Array of new y positions (num_points). + x_coords (np.ndarray): Array of new x positions (num_points). + output (np.ndarray): Output array of shape (C, num_points) to fill. + + Returns: + None. Results written directly into `output`. + """ + channels, height, width = image_data.shape + # Compute integer (floor) and fractional parts for coords + y_floor = y_coords.astype(np.int32) + x_floor = x_coords.astype(np.int32) + y_frac = y_coords - y_floor + x_frac = x_coords - x_floor + + # Loop over each sample point + for idx in range(y_floor.shape[0]): + # Clamp base indices to valid range + y0 = min(max(y_floor[idx], 0), height - 1) + x0 = min(max(x_floor[idx], 0), width - 1) + y1 = min(y0 + 1, height - 1) + x1 = min(x0 + 1, width - 1) + + wy = y_frac[idx] + wx = x_frac[idx] + + # Interpolate per channel + for c in range(channels): + v00 = np.float32(image_data[c, y0, x0]) + v10 = np.float32(image_data[c, y0, x1]) + v01 = np.float32(image_data[c, y1, x0]) + v11 = np.float32(image_data[c, y1, x1]) + # Bilinear interpolation formula + output[c, idx] = ( + v00 * (1 - wy) * (1 - wx) + + v10 * (1 - wy) * wx + + v01 * wy * (1 - wx) + + v11 * wy * wx + ) + + + def __get_mask( + self, + pixel_positions: torch.Tensor, + valid_indices: np.ndarray, + original_shape: Tuple[int, ...], + pad_radius: int = 20, + max_size_fraction: float = 0.4 + ) -> np.ndarray: + """ + Generate labeled masks by clustering pixel trajectories via histogram peaks and region growing. + + This function executes the following steps: + 1. Pads and clamps pixel final positions to avoid border effects. + 2. Builds a dense histogram of pixel counts over spatial bins. + 3. Identifies local maxima in the histogram as seed points. + 4. Extracts local patches around each seed and grows regions by iteratively adding neighbors + that exceed an intensity threshold. + 5. Maps grown patches back to original image indices. + 6. Removes any masks that exceed a maximum size fraction of the image. + + Args: + pixel_positions (torch.Tensor): Tensor of shape `[2, N_pixels]`, dtype=int, containing + final pixel coordinates after dynamics for each dimension. + valid_indices (np.ndarray): Integer array of shape `[2, N_pixels]` + giving indices of pixels in the original image grid. + original_shape (tuple of ints): Spatial dimensions of the original image, e.g. (H, W). + pad_radius (int): Number of zero-padding pixels added on each side of the histogram. + Defaults to 20. + max_size_fraction (float): If any mask has a pixel count > max_size_fraction * total_pixels, + it will be removed. Defaults to 0.4. + + Returns: + np.ndarray: Integer mask array of shape `original_shape` with labels 0 (background) and 1..M. + + Raises: + ValueError: If input dimensions are inconsistent or pixel_positions shape is invalid. + """ + # Validate inputs + ndim = len(original_shape) + if pixel_positions.ndim != 2 or pixel_positions.size(0) != ndim: + msg = f"pixel_positions must be shape [{ndim}, N], got {tuple(pixel_positions.shape)}" + logger.error(msg) + raise ValueError(msg) + if pad_radius < 0: + msg = f"pad_radius must be non-negative, got {pad_radius}" + logger.error(msg) + raise ValueError(msg) + + # Step 1: Pad and clamp pixel positions + padded_positions = pixel_positions.clone().to(torch.int64) + pad_radius + for dim in range(ndim): + max_val = original_shape[dim] + pad_radius - 1 + padded_positions[dim] = torch.clamp(padded_positions[dim], min=0, max=max_val) + + # Build histogram dimensions + hist_shape = tuple(s + 2 * pad_radius for s in original_shape) + + # Step 2: Create sparse tensor and densify to get per-pixel counts + try: + counts_sparse = torch.sparse_coo_tensor( + padded_positions, + torch.ones(padded_positions.shape[1], dtype=torch.int32, device=pixel_positions.device), + size=hist_shape + ) + histogram = counts_sparse.to_dense() + except Exception as e: + logger.error("Failed to build dense histogram: %s", e) + raise + + # Step 3: Find peaks via 5x5 max-pooling + k = 5 + pooled = F.max_pool2d( + histogram.unsqueeze(0), + kernel_size=k, + stride=1, + padding=k // 2 + ).squeeze() + # Seeds are positions where histogram equals local max and count > threshold + seed_positions = torch.nonzero((histogram - pooled == 0) & (histogram > 10)) + if seed_positions.numel() == 0: + logger.warning("No seeds found: returning empty mask") + return np.zeros(original_shape, dtype=np.uint16) + + # Sort seeds by ascending count to process small peaks first + seed_counts = histogram[tuple(seed_positions.T)] + order = torch.argsort(seed_counts) + seed_positions = seed_positions[order] + del pooled, counts_sparse + + # Step 4: Extract local patches and perform region growing + num_seeds = seed_positions.shape[0] + # Tensor to hold local patches + patches = torch.zeros((num_seeds, 11, 11), device=pixel_positions.device) + for idx in range(num_seeds): + coords = seed_positions[idx] + slices = tuple(slice(c - 5, c + 6) for c in coords) + patches[idx] = histogram[slices] + del histogram + + # Initialize seed mask (center pixel of each patch) + seed_masks = torch.zeros_like(patches, device=pixel_positions.device) + seed_masks[:, 5, 5] = 1 + # Iterative dilation and thresholding + for _ in range(5): + seed_masks = F.max_pool2d( + seed_masks, + kernel_size=3, + stride=1, + padding=1 + ) + seed_masks = seed_masks & (patches > 2) + # Compute final mask coordinates + final_coords = [] + for idx in range(num_seeds): + coords_local = torch.nonzero(seed_masks[idx]) + # Shift back to global positions + coords_global = coords_local + seed_positions[idx] - 5 + final_coords.append(tuple(coords_global.T)) + + # Step 5: Paint masks into padded volume + dtype = torch.int32 if num_seeds < 2**16 else torch.int64 + mask_padded = torch.zeros(hist_shape, dtype=dtype, device=pixel_positions.device) + for label_idx, coords in enumerate(final_coords, start=1): + mask_padded[coords] = label_idx + + # Extract only the padded positions that correspond to original pixels + mask_values = mask_padded[tuple(padded_positions)] + mask_values = mask_values.cpu().numpy() + + # Step 6: Map to original image and remove oversized masks + mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32) + mask_final[valid_indices] = mask_values + + # Prune masks that are too large + labels, counts = fastremap.unique(mask_final, return_counts=True) + total_pixels = np.prod(original_shape) + oversized = labels[counts > (total_pixels * max_size_fraction)] + if oversized.size > 0: + mask_final = fastremap.mask(mask_final, oversized) + fastremap.renumber(mask_final, in_place=True) + + return mask_final + + + def __remove_inconsistent_flow_masks( + self, + mask: np.ndarray, + flow_network: np.ndarray, + error_threshold: float = 0.4 + ) -> np.ndarray: + """ + Remove labeled masks that have inconsistent optical flows compared to network-predicted flows. + + This performs a quality control step by computing flows from the provided masks + and comparing them to the flows predicted by the network. Masks with a mean squared + flow error above `error_threshold` are discarded (set to 0). + + Args: + mask (np.ndarray): Integer mask array with shape [H, W]. + Values: 0 = no mask; 1,2,... = mask labels. + flow_network (np.ndarray): Float array of network-predicted flows with shape + [2, H, W]. + error_threshold (float): Maximum allowed mean squared flow error per mask label. + Defaults to 0.4. + + Returns: + np.ndarray: The input mask with inconsistent masks removed (labels set to 0). + + Raises: + MemoryError: If the mask size exceeds available GPU memory. + """ + # If mask is very large and running on CUDA, check memory + num_pixels = mask.size + if ( + num_pixels > 10000 * 10000 + + and self._device.type == 'cuda' + ): + # Clear unused GPU cache + torch.cuda.empty_cache() + # Determine PyTorch version + major, minor = map(int, torch.__version__.split('.')[:2]) + # Determine current CUDA device index + device_index = ( + self._device.index + if hasattr(self._device, 'index') + else torch.cuda.current_device() + ) + # Get free and total memory + if major == 1 and minor < 10: + total_mem = torch.cuda.get_device_properties(device_index).total_memory + used_mem = torch.cuda.memory_allocated(device_index) + free_mem = total_mem - used_mem + else: + free_mem, total_mem = torch.cuda.mem_get_info(device_index) + # Estimate required memory for mask-based flow computation + # Assume float32 per pixel + required_bytes = num_pixels * np.dtype(np.float32).itemsize + if required_bytes > free_mem: + logger.error( + 'Image too large for GPU memory in flow QC step (required: %d B, available: %d B)', + required_bytes, free_mem + ) + raise MemoryError('Insufficient GPU memory for flow QC computation') + + # Compute flow errors per mask label + flow_errors, _ = self.__compute_flow_error(mask, flow_network) + + # Identify labels with error above threshold + bad_labels = np.nonzero(flow_errors > error_threshold)[0] + 1 + + # Remove bad masks by setting their label to 0 + mask[np.isin(mask, bad_labels)] = 0 + return mask + + + def __compute_flow_error( + self, + mask: np.ndarray, + flow_network: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute mean squared error between network-predicted flows and flows derived from masks. + + Args: + mask (np.ndarray): Integer masks, shape must match flow_network spatial dims. + flow_network (np.ndarray): Network predicted flows of shape [axis, ...]. + + Returns: + Tuple[np.ndarray, np.ndarray]: + - flow_errors: 1D array (length = max label) of mean squared error per label. + - computed_flows: Array of flows derived from the mask, same shape as flow_network. + + Raises: + ValueError: If the spatial dimensions of `mask_array` and `flow_network` do not match. + """ + # Ensure mask and flow shapes match + if flow_network.shape[1:] != mask.shape: + logger.error( + 'Shape mismatch: network flow shape %s vs mask shape %s', + flow_network.shape[1:], mask.shape + ) + raise ValueError('Network flow and mask shapes must match') + + # Compute flows from mask labels (user-provided function) + computed_flows = self.__compute_flow_from_mask(mask[None, ...]) + + # Prepare array for errors (one value per mask label) + num_labels = int(mask.max()) + flow_errors = np.zeros(num_labels, dtype=float) + + # Accumulate mean squared error over each flow axis + for axis_index in range(computed_flows.shape[0]): + # MSE per label: mean((computed - predicted/5)^2) + flow_errors += mean( + (computed_flows[axis_index] - flow_network[axis_index] / 5.0) ** 2, + mask, + index=np.arange(1, num_labels + 1) + ) + + return flow_errors, computed_flows + + + def __fill_holes_and_prune_small_masks( + self, + masks: np.ndarray, + minimum_size: int = 15 + ) -> np.ndarray: + """ + Fill holes in labeled masks and remove masks smaller than a given size. + + This function performs two steps: + 1. Fills internal holes in each labeled mask using `fill_voids.fill`. + 2. Discards any mask whose pixel count is below `minimum_size`. + + Args: + masks (np.ndarray): Integer mask array of dimension 2 or 3 (shape [H, W] or [D, H, W]). + Values: 0 = background; 1,2,... = mask labels. + minimum_size (int): Minimum number of pixels required to keep a mask. + Masks smaller than this will be removed. + Set to -1 to skip size-based pruning. Defaults to 15. + + Returns: + np.ndarray: Processed mask array with holes filled and small masks removed. + + Raises: + ValueError: If `masks` is not a 2D or 3D integer array. + """ + # Validate input dimensions + if masks.ndim not in (2, 3): + msg = f"Expected 2D or 3D mask array, got {masks.ndim}D." + logger.error(msg) + raise ValueError(msg) + + # Optionally remove masks smaller than minimum_size + if minimum_size >= 0: + # Compute label counts (skipping background at index 0) + labels, counts = fastremap.unique(masks, return_counts=True) + # Identify labels to remove: those with count < minimum_size + small_labels = labels[counts < minimum_size] + if small_labels.size > 0: + masks = fastremap.mask(masks, small_labels) + fastremap.renumber(masks, in_place=True) + + # Find bounding boxes for each mask label + object_slices = find_objects(masks) + new_label = 1 + output_masks = np.zeros_like(masks, dtype=masks.dtype) + + # Loop over each original slice, fill holes, and assign new labels + for original_label, slc in enumerate(object_slices, start=1): + if slc is None: + continue + # Extract sub-volume or sub-image + region = masks[slc] == original_label + if not np.any(region): + continue + # Fill internal holes + filled_region = fill_voids.fill(region) + # Write back into output mask with sequential labels + output_masks[slc][filled_region] = new_label + new_label += 1 + + # Final pruning of small masks after filling (optional) + if minimum_size >= 0: + labels, counts = fastremap.unique(output_masks, return_counts=True) + small_labels = labels[counts < minimum_size] + if small_labels.size > 0: + output_masks = fastremap.mask(output_masks, small_labels) + fastremap.renumber(output_masks, in_place=True) + + return output_masks \ No newline at end of file