import random

import numpy as np
from numba import njit, prange

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader

import fastremap

from skimage import morphology
from scipy.ndimage import mean, find_objects
import fill_voids

from monai.data.dataset import Dataset
from monai.transforms import *  # type: ignore

import os
import glob
from pprint import pformat
from typing import Optional, Tuple, List, Union

from config import Config
from core.models import *
from core.losses import *
from core.optimizers import *
from core.schedulers import *

from core.logger import get_logger


logger = get_logger()

class CellSegmentator:
    def __init__(self, config: Config) -> None:
        self.__set_seed(config.dataset_config.common.seed)
        self.__parse_config(config)

        self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu")

        self._train_dataloader: Optional[DataLoader] = None
        self._valid_dataloader: Optional[DataLoader] = None
        self._test_dataloader: Optional[DataLoader] = None
        self._predict_dataloader: Optional[DataLoader] = None

    def create_dataloaders(
        self,
        train_transforms: Optional[Compose] = None,
        valid_transforms: Optional[Compose] = None,
        test_transforms: Optional[Compose] = None,
        predict_transforms: Optional[Compose] = None
    ) -> None:
        """
        Creates train, validation, test, and prediction dataloaders based on dataset configuration
        and provided transforms.

        Args:
            train_transforms (Optional[Compose]): Transformations for training data.
            valid_transforms (Optional[Compose]): Transformations for validation data.
            test_transforms (Optional[Compose]): Transformations for testing data.
            predict_transforms (Optional[Compose]): Transformations for prediction data.

        Raises:
            ValueError: If required transforms are missing.
            RuntimeError: If critical dataset config values are missing.
        """
        if self._dataset_setup.is_training and train_transforms is None:
            raise ValueError("Training mode requires 'train_transforms' to be provided.")
        elif not self._dataset_setup.is_training and test_transforms is None and predict_transforms is None:
            raise ValueError("In inference mode, at least one of 'test_transforms' or 'predict_transforms' must be provided.")

        if self._dataset_setup.is_training:
            # Training mode: handle either pre-split datasets or splitting on the fly
            if self._dataset_setup.training.is_split:
                # Validate presence of validation transforms if validation directory and size are set
                if (
                    self._dataset_setup.training.pre_split.valid_dir and
                    self._dataset_setup.training.valid_size and
                    valid_transforms is None
                ):
                    raise ValueError("Validation transforms must be provided when using pre-split validation data.")

                # Use explicitly split directories
                train_dir = self._dataset_setup.training.pre_split.train_dir
                valid_dir = self._dataset_setup.training.pre_split.valid_dir
                test_dir = self._dataset_setup.training.pre_split.test_dir

                train_offset = self._dataset_setup.training.train_offset
                valid_offset = self._dataset_setup.training.valid_offset
                test_offset = self._dataset_setup.training.test_offset

                shuffle = False
            else:
                # Same validation for split mode with a single full data directory
                if (
                    self._dataset_setup.training.split.all_data_dir and
                    self._dataset_setup.training.valid_size and
                    valid_transforms is None
                ):
                    raise ValueError("Validation transforms must be provided when splitting dataset.")

                # Automatically split dataset from one directory
                train_dir = valid_dir = test_dir = self._dataset_setup.training.split.all_data_dir

                number_of_images = len(os.listdir(os.path.join(train_dir, 'images')))
                if number_of_images == 0:
                    raise FileNotFoundError(f"No images found in '{train_dir}/images'")

                # Calculate train/valid sizes
                train_size = (
                    self._dataset_setup.training.train_size
                    if isinstance(self._dataset_setup.training.train_size, int)
                    else int(number_of_images * self._dataset_setup.training.train_size)
                )
                valid_size = (
                    self._dataset_setup.training.valid_size
                    if isinstance(self._dataset_setup.training.valid_size, int)
                    else int(number_of_images * self._dataset_setup.training.valid_size)
                )

                train_offset = self._dataset_setup.training.train_offset
                valid_offset = self._dataset_setup.training.valid_offset + train_size
                test_offset = self._dataset_setup.training.test_offset + train_size + valid_size

                # Shuffle according to the configured split behavior
                shuffle = self._dataset_setup.training.split.shuffle

            # Train dataloader
            train_dataset = self.__get_dataset(
                images_dir=os.path.join(train_dir, 'images'),
                masks_dir=os.path.join(train_dir, 'masks'),
                transforms=train_transforms,  # type: ignore
                size=self._dataset_setup.training.train_size,
                offset=train_offset,
                shuffle=shuffle
            )
            self._train_dataloader = DataLoader(train_dataset, batch_size=self._dataset_setup.training.batch_size, shuffle=True)
            logger.info(f"Loaded training dataset with {len(train_dataset)} samples.")

            # Validation dataloader
            if valid_transforms is not None:
                if not valid_dir or not self._dataset_setup.training.valid_size:
                    raise RuntimeError("Validation directory or size is not properly configured.")
                valid_dataset = self.__get_dataset(
                    images_dir=os.path.join(valid_dir, 'images'),
                    masks_dir=os.path.join(valid_dir, 'masks'),
                    transforms=valid_transforms,
                    size=self._dataset_setup.training.valid_size,
                    offset=valid_offset,
                    shuffle=shuffle
                )
                self._valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded validation dataset with {len(valid_dataset)} samples.")

            # Test dataloader
            if test_transforms is not None:
                if not test_dir or not self._dataset_setup.training.test_size:
                    raise RuntimeError("Test directory or size is not properly configured.")
                test_dataset = self.__get_dataset(
                    images_dir=os.path.join(test_dir, 'images'),
                    masks_dir=os.path.join(test_dir, 'masks'),
                    transforms=test_transforms,
                    size=self._dataset_setup.training.test_size,
                    offset=test_offset,
                    shuffle=shuffle
                )
                self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded test dataset with {len(test_dataset)} samples.")

            # Prediction dataloader (reuses the test images, without masks)
            if predict_transforms is not None:
                if not test_dir or not self._dataset_setup.training.test_size:
                    raise RuntimeError("Prediction directory or size is not properly configured.")
                predict_dataset = self.__get_dataset(
                    images_dir=os.path.join(test_dir, 'images'),
                    masks_dir=None,
                    transforms=predict_transforms,
                    size=self._dataset_setup.training.test_size,
                    offset=test_offset,
                    shuffle=shuffle
                )
                self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.")

        else:
            # Inference mode (no training)
            test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images')
            test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks')

            if test_transforms is not None:
                test_dataset = self.__get_dataset(
                    images_dir=test_images,
                    masks_dir=test_masks,
                    transforms=test_transforms,
                    size=self._dataset_setup.testing.test_size,
                    offset=self._dataset_setup.testing.test_offset,
                    shuffle=self._dataset_setup.testing.shuffle
                )
                self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded test dataset with {len(test_dataset)} samples.")

            if predict_transforms is not None:
                predict_dataset = self.__get_dataset(
                    images_dir=test_images,
                    masks_dir=None,
                    transforms=predict_transforms,
                    size=self._dataset_setup.testing.test_size,
                    offset=self._dataset_setup.testing.test_offset,
                    shuffle=self._dataset_setup.testing.shuffle
                )
                self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.")
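
    # Usage sketch (hedged): `Config.load` is an assumption about the project's
    # config API, and the MONAI transforms below are only illustrative.
    #
    #   config = Config.load("config.yaml")
    #   segmentator = CellSegmentator(config)
    #   segmentator.create_dataloaders(
    #       train_transforms=Compose([
    #           LoadImaged(keys=["image", "mask"]),
    #           EnsureChannelFirstd(keys=["image", "mask"]),
    #       ]),
    #       valid_transforms=Compose([
    #           LoadImaged(keys=["image", "mask"]),
    #           EnsureChannelFirstd(keys=["image", "mask"]),
    #       ]),
    #   )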

    def train(self) -> None:
        pass

    def evaluate(self) -> None:
        pass

    def predict(self) -> None:
        pass

    def __parse_config(self, config: Config) -> None:
        """
        Parses the given configuration object to initialize model, criterion,
        optimizer, scheduler, and dataset setup.

        Args:
            config (Config): Configuration object with model, optimizer,
                scheduler, criterion, and dataset setup information.
        """
        model = config.model
        criterion = config.criterion
        optimizer = config.optimizer
        scheduler = config.scheduler

        logger.info("========== Parsed Configuration ==========")
        logger.info("Model Config:\n%s", pformat(model.dump(), indent=2))
        if criterion:
            logger.info("Criterion Config:\n%s", pformat(criterion.dump(), indent=2))
        if optimizer:
            logger.info("Optimizer Config:\n%s", pformat(optimizer.dump(), indent=2))
        if scheduler:
            logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2))
        logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2))
        logger.info("==========================================")

        # Initialize model using the model registry
        self._model = ModelRegistry.get_model_class(model.name)(model.params)

        # Initialize loss criterion if specified
        self._criterion = (
            CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params)
            if criterion is not None
            else None
        )

        # Initialize optimizer if specified
        self._optimizer = (
            OptimizerRegistry.get_optimizer_class(optimizer.name)(
                model_params=self._model.parameters(),
                optim_params=optimizer.params
            )
            if optimizer is not None
            else None
        )

        # Initialize scheduler only if both scheduler and optimizer are defined
        self._scheduler = (
            SchedulerRegistry.get_scheduler_class(scheduler.name)(
                optimizer=self._optimizer.optim,
                params=scheduler.params
            )
            if scheduler is not None and self._optimizer is not None and self._optimizer.optim is not None
            else None
        )

        logger.info("========== Model Components Initialization ==========")
        logger.info("├─ Model: " + (f"{model.name}" if self._model else "Not specified"))
        logger.info("├─ Criterion: " + (f"{criterion.name}" if self._criterion else "Not specified"))  # type: ignore
        logger.info("├─ Optimizer: " + (f"{optimizer.name}" if self._optimizer else "Not specified"))  # type: ignore
        logger.info("└─ Scheduler: " + (f"{scheduler.name}" if self._scheduler else "Not specified"))  # type: ignore
        logger.info("=====================================================")

        # Save dataset config
        self._dataset_setup = config.dataset_config
        common = config.dataset_config.common

        logger.info("========== Dataset Setup ==========")
        logger.info("[COMMON]")
        logger.info(f"├─ Seed: {common.seed}")
        logger.info(f"├─ Device: {common.device}")
        logger.info(f"└─ Predictions output dir: {common.predictions_dir}")

        if config.dataset_config.is_training:
            training = config.dataset_config.training
            logger.info("[MODE] Training")
            logger.info(f"├─ Batch size: {training.batch_size}")
            logger.info(f"├─ Epochs: {training.num_epochs}")
            logger.info(f"├─ Validation frequency: {training.val_freq}")
            logger.info(f"├─ Use AMP: {'yes' if training.use_amp else 'no'}")
            logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}")

            if training.is_split:
                logger.info("├─ Using pre-split directories:")
                logger.info(f"│   ├─ Train dir: {training.pre_split.train_dir}")
                logger.info(f"│   ├─ Valid dir: {training.pre_split.valid_dir}")
                logger.info(f"│   └─ Test dir: {training.pre_split.test_dir}")
            else:
                logger.info("├─ Using unified dataset with splits:")
                logger.info(f"│   ├─ All data dir: {training.split.all_data_dir}")
                logger.info(f"│   └─ Shuffle: {'yes' if training.split.shuffle else 'no'}")

            logger.info("└─ Dataset split:")
            logger.info(f"    ├─ Train size: {training.train_size}, offset: {training.train_offset}")
            logger.info(f"    ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}")
            logger.info(f"    └─ Test size: {training.test_size}, offset: {training.test_offset}")

        else:
            testing = config.dataset_config.testing
            logger.info("[MODE] Inference")
            logger.info(f"├─ Test dir: {testing.test_dir}")
            logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})")
            logger.info(f"├─ Shuffle: {'yes' if testing.shuffle else 'no'}")
            logger.info(f"├─ Use ensemble: {'yes' if testing.use_ensemble else 'no'}")
            logger.info("└─ Pretrained weights:")
            logger.info(f"    ├─ Single model: {testing.pretrained_weights}")
            logger.info(f"    ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
            logger.info(f"    └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")

        logger.info("===================================")

    def __set_seed(self, seed: Optional[int]) -> None:
        """
        Sets the random seed for reproducibility across Python, NumPy, and PyTorch.

        Args:
            seed (Optional[int]): Seed value. If None, no seeding is performed.
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            logger.info(f"Random seed set to {seed}")
        else:
            logger.info("Seed not set (None provided)")

    def __get_dataset(
        self,
        images_dir: str,
        masks_dir: Optional[str],
        transforms: Compose,
        size: Union[int, float],
        offset: int,
        shuffle: bool
    ) -> Dataset:
        """
        Loads and returns a dataset object from image and optional mask directories.

        Args:
            images_dir (str): Path to directory or glob pattern for input images.
            masks_dir (Optional[str]): Path to directory or glob pattern for masks.
            transforms (Compose): Transformations to apply to each image or pair.
            size (Union[int, float]): Either an absolute number of samples or a fraction of the dataset.
            offset (int): Number of images to skip from the start.
            shuffle (bool): Whether to shuffle the dataset before slicing.

        Returns:
            Dataset: A dataset containing image and optional mask paths.

        Raises:
            FileNotFoundError: If no images are found.
            ValueError: If masks are provided but do not match image count.
            ValueError: If dataset is too small for requested size or offset.
        """
        # Collect sorted list of image paths
        images = sorted(glob.glob(images_dir))
        if not images:
            raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'")

        if masks_dir is not None:
            # Collect and validate sorted list of mask paths
            masks = sorted(glob.glob(masks_dir))
            if len(images) != len(masks):
                raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})")

        # Convert float size (fraction) to absolute count
        size = size if isinstance(size, int) else int(size * len(images))

        if size <= 0:
            raise ValueError(f"Size must be positive, got: {size}")

        if len(images) < size:
            raise ValueError(f"Not enough images ({len(images)}) for requested size ({size})")

        if len(images) < size + offset:
            raise ValueError(f"Offset ({offset}) + size ({size}) exceeds dataset length ({len(images)})")

        # Shuffle image-mask pairs if requested
        if shuffle:
            if masks_dir is not None:
                combined = list(zip(images, masks))  # type: ignore
                random.shuffle(combined)
                images, masks = zip(*combined)
            else:
                random.shuffle(images)

        # Apply offset and limit by size
        images = images[offset: offset + size]
        if masks_dir is not None:
            masks = masks[offset: offset + size]  # type: ignore

        # Prepare data structure for the Dataset class
        if masks_dir is not None:
            data = [
                {"image": image, "mask": mask}
                for image, mask in zip(images, masks)  # type: ignore
            ]
        else:
            data = [{"image": image} for image in images]

        return Dataset(data, transforms)
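
    # Note (hedged): `images_dir` and `masks_dir` are passed straight to
    # `glob.glob`, so glob patterns such as "data/train/images/*.tif" select
    # files directly, while a bare directory path matches only the directory
    # entry itself. Slicing sketch: with 10 matched files, size=0.5 and
    # offset=2, size becomes int(0.5 * 10) = 5 and the dataset uses files[2:7].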

    def __compute_flows_from_masks(
        self,
        true_masks: Tensor
    ) -> np.ndarray:
        """
        Convert segmentation masks to flow fields for training.

        Args:
            true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks.

        Returns:
            A float32 numpy array of shape (batch, 4*C, H, W) holding, per image,
            the concatenation of [renumbered labels, binary foreground masks,
            flow Y components, flow X components] along the channel axis.
        """
        # Move to CPU numpy
        _true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16)

        # Ensure each label map has a channel dimension
        if _true_masks.ndim == 3:
            # shape (batch, H, W) -> (batch, 1, H, W)
            _true_masks = _true_masks[:, np.newaxis, :, :]

        batch_size = _true_masks.shape[0]

        # Renumber labels to ensure uniqueness
        renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0]
                                           for i in range(batch_size)])
        # Compute vector flows per image
        flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i])
                                 for i in range(batch_size)])

        return np.concatenate((renumbered, renumbered > 0.5, flow_vectors), axis=1).astype(np.float32)
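
    # Usage sketch (hedged; the pairing below is an assumption about the
    # intended training loop, which is not implemented in this file):
    #   flows = self.__compute_flows_from_masks(batch["mask"])  # (B, 4C, H, W)
    #   # flows[:, C:2C] -> foreground target, flows[:, 2C:] -> flow regression target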

    def __compute_flow_from_mask(
        self,
        mask: np.ndarray
    ) -> np.ndarray:
        """
        Compute normalized flow vectors from a labeled mask.

        Args:
            mask: 3D array of instance-labeled mask of shape (C, H, W).

        Returns:
            flow: Array of shape (2 * C, H, W).
        """
        if mask.max() == 0 or np.count_nonzero(mask) <= 1:
            # No flow to compute
            logger.warning("Empty mask!")
            C, H, W = mask.shape
            return np.zeros((2 * C, H, W), dtype=np.float32)

        # Delegate to the GPU or CPU routine
        if self._device.type == "cuda" or self._device.type == "mps":
            return self.__mask_to_flow_gpu(mask)
        else:
            return self.__mask_to_flow_cpu(mask)

    def __mask_to_flow_gpu(self, mask: np.ndarray) -> np.ndarray:
        """Convert masks to flows using diffusion from the center pixel.

        The center of each mask, where diffusion starts, is the pixel closest
        to its centroid.

        Args:
            mask (3D array): Labeled masks of shape (C, H, W).

        Returns:
            np.ndarray: A 3D array where, for each channel, the flows for each
                pixel are represented along the Y and X axes.
        """
        channels, height, width = mask.shape
        flows = np.zeros((2 * channels, height, width), np.float32)

        for channel in range(channels):
            padded_height, padded_width = height + 2, width + 2

            # Pad this channel's mask with a 1-pixel border
            masks_padded = torch.from_numpy(mask[channel].astype(np.int64)).to(self._device)
            masks_padded = F.pad(masks_padded, (1, 1, 1, 1))

            # Get coordinates of all non-zero pixels in the padded mask
            y, x = torch.nonzero(masks_padded, as_tuple=True)
            y = y.int(); x = x.int()  # ensure integer type

            # Generate 8-connected neighbors (including center) via broadcasted offsets
            offsets = torch.tensor([
                [ 0,  0],  # center
                [-1,  0],  # up
                [ 1,  0],  # down
                [ 0, -1],  # left
                [ 0,  1],  # right
                [-1, -1],  # up-left
                [-1,  1],  # up-right
                [ 1, -1],  # down-left
                [ 1,  1],  # down-right
            ], dtype=torch.int32, device=self._device)  # (9, 2)

            # coords: (N, 2)
            coords = torch.stack((y, x), dim=1)

            # neighbors: (9, N, 2)
            neighbors = offsets[:, None, :] + coords[None, :, :]

            # transpose into (2, 9, N) for the GPU kernel
            neighbors = neighbors.permute(2, 0, 1)  # first dim is y/x, second is neighbor index

            # Build connectivity mask: True where neighbor label == center label
            center_labels = masks_padded[y, x][None, :]                 # (1, N)
            neighbor_labels = masks_padded[neighbors[0], neighbors[1]]  # (9, N)
            is_neighbor = neighbor_labels == center_labels              # (9, N)

            # Compute object slices and pack them into an array for the centers helper.
            # find_objects returns the slice for label i+1 at index i, hence i + 1 below.
            slices = find_objects(mask[channel])
            slices_arr = np.array([
                [i + 1, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
                for i, sl in enumerate(slices) if sl is not None
            ], dtype=int)

            # Compute centers (pixel indices) and extents via the numba helper
            centers, ext = self.__get_mask_centers_and_extents(mask[channel], slices_arr)
            # Move centers to the device and shift by +1 for the padding
            meds_p = torch.from_numpy(centers).to(self._device).long() + 1  # (M, 2)

            # Determine the number of diffusion iterations
            n_iter = 2 * ext.max()

            # Run the GPU diffusion kernel
            mu = self.__propagate_centers_gpu(
                neighbor_indices=neighbors,
                center_indices=meds_p.T,
                valid_neighbor_mask=is_neighbor,
                output_shape=(padded_height, padded_width),
                num_iterations=n_iter
            )

            # Cast to float64 and normalize the flow vectors
            mu = mu.astype(np.float64)
            mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60

            # Remove the padding and write into the final output
            flow_output = np.zeros((2, height, width), dtype=np.float32)
            ys_np = y.cpu().numpy() - 1
            xs_np = x.cpu().numpy() - 1
            flow_output[:, ys_np, xs_np] = mu
            flows[2 * channel: 2 * channel + 2] = flow_output

        return flows
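
    # Idea sketch: repeatedly injecting unit "heat" at each instance center and
    # averaging every mask pixel with its 8-connected neighbours approximates a
    # diffusion process confined to the mask; the (dy, dx) gradient of the
    # resulting field points toward the center, and normalizing to unit length
    # yields the flow targets written into `flows` above.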

    @staticmethod
    @njit(nogil=True)
    def __get_mask_centers_and_extents(
        label_map: np.ndarray,
        slices_arr: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute the centroids and extents of labeled regions in a 2D mask array.

        Args:
            label_map (np.ndarray): 2D array where each connected region has a unique integer label (1…K).
            slices_arr (np.ndarray): Array of shape (K, 5), where each row is
                (label_id, row_start, row_stop, col_start, col_stop).

        Returns:
            centers (np.ndarray): Integer array of shape (K, 2) with the (row, col) center of each label.
            extents (np.ndarray): Integer array of shape (K,) giving height + width + 2 for each region.
        """
        num_regions = slices_arr.shape[0]
        centers = np.zeros((num_regions, 2), dtype=np.int32)
        extents = np.zeros(num_regions, dtype=np.int32)

        for idx in prange(num_regions):
            # Unpack slice info
            label_id = slices_arr[idx, 0]
            row_start = slices_arr[idx, 1]
            row_stop = slices_arr[idx, 2]
            col_start = slices_arr[idx, 3]
            col_stop = slices_arr[idx, 4]

            # Extract the binary submask for this label
            submask = (label_map[row_start:row_stop, col_start:col_stop] == label_id)

            # Get local coordinates of all pixels in the region
            ys, xs = np.nonzero(submask)

            # Compute the floating-point centroid within the submask
            y_mean = ys.mean()
            x_mean = xs.mean()

            # Find the pixel closest to the centroid by minimizing squared distance
            dist_sq = (ys - y_mean) ** 2 + (xs - x_mean) ** 2
            closest_idx = dist_sq.argmin()

            # Convert to global coordinates
            center_row = ys[closest_idx] + row_start
            center_col = xs[closest_idx] + col_start
            centers[idx, 0] = center_row
            centers[idx, 1] = center_col

            # Compute the extent as height + width + 2 (to include a one-pixel border)
            height = row_stop - row_start
            width = col_stop - col_start
            extents[idx] = height + width + 2

        return centers, extents
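
    # Worked example (hedged): for one rectangular label filling
    # label_map[1:4, 2:6], slices_arr is [[1, 1, 4, 2, 6]]; the centroid of the
    # 3x4 block snaps to the nearest mask pixel, and the extent is
    # (4 - 1) + (6 - 2) + 2 = 9, so the GPU caller would run 2 * 9 = 18
    # diffusion iterations for this region alone.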

    def __propagate_centers_gpu(
        self,
        neighbor_indices: torch.Tensor,
        center_indices: torch.Tensor,
        valid_neighbor_mask: torch.Tensor,
        output_shape: Tuple[int, int],
        num_iterations: int = 200
    ) -> np.ndarray:
        """
        Propagates center points across a mask using GPU-based diffusion.

        Args:
            neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel.
            center_indices (torch.Tensor): Tensor of shape (2, M) with row and column indices of the mask centers.
            valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating whether each neighbor is valid.
            output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W).
            num_iterations (int, optional): Number of diffusion iterations. Defaults to 200.

        Returns:
            np.ndarray: Array of shape (2, N) with the computed flows.
        """
        # Determine the total number of elements and choose the dtype accordingly
        total_elems = torch.prod(torch.tensor(output_shape))
        if total_elems > 4e7 or self._device.type == "mps":
            diffusion_tensor = torch.zeros(output_shape, dtype=torch.float, device=self._device)
        else:
            diffusion_tensor = torch.zeros(output_shape, dtype=torch.double, device=self._device)

        # Unpack center row and column indices
        center_rows, center_cols = center_indices

        # Unpack neighbor row and column indices for 9 neighbors per pixel
        # Order: [0: center, 1: up, 2: down, 3: left, 4: right,
        #         5: up-left, 6: up-right, 7: down-left, 8: down-right]
        neigh_rows, neigh_cols = neighbor_indices  # each of shape (9, N)

        # Perform diffusion iterations
        for _ in range(num_iterations):
            # Add a source at each mask center
            diffusion_tensor[center_rows, center_cols] += 1

            # Sample neighbor values for each pixel
            neighbor_vals = diffusion_tensor[neigh_rows, neigh_cols]  # (9, N)

            # Zero out invalid neighbors
            neighbor_vals *= valid_neighbor_mask

            # Update each center pixel (neighbor index 0) with the mean over its 9 samples
            diffusion_tensor[neigh_rows[0], neigh_cols[0]] = neighbor_vals.mean(dim=0)

        # Compute spatial gradients for 2D flow: dy and dx
        # Using neighbor indices: up = 1, down = 2, left = 3, right = 4
        grad_samples = diffusion_tensor[
            neigh_rows[[2, 1, 4, 3]],  # [down, up, right, left]
            neigh_cols[[2, 1, 4, 3]]
        ]  # (4, N)

        dy = grad_samples[0] - grad_samples[1]
        dx = grad_samples[2] - grad_samples[3]

        # Stack into a numpy flow field of shape (2, N)
        flow_field = np.stack((dy.cpu().numpy(), dx.cpu().numpy()), axis=0)

        return flow_field

    def __mask_to_flow_cpu(self, mask: np.ndarray) -> np.ndarray:
        """
        Convert labeled masks to flow vectors by simulating diffusion from mask centers.

        Each mask's center is chosen as the pixel closest to its geometric centroid.
        A diffusion process is run on a padded local patch, and flows are derived
        as gradients (dy, dx) of the resulting density map.

        Args:
            mask (np.ndarray): 3D integer array of labels `(C x H x W)`,
                where 0 = background and positive integers = mask IDs.

        Returns:
            flow_field (np.ndarray): Array of shape `(2*C, H, W)` containing
                flow components [dy, dx] normalized per pixel.
        """
        channels, height, width = mask.shape
        flows = np.zeros((2 * channels, height, width), np.float32)

        for channel in range(channels):
            # Initialize flow_field with two channels: dy and dx
            flow_field = np.zeros((2, height, width), dtype=np.float64)

            # Find a bounding box for each labeled mask in this channel
            mask_slices = find_objects(mask[channel])

            # Iterate over mask labels
            for label_idx in range(len(mask_slices)):
                slc = mask_slices[label_idx]
                if slc is None:
                    continue

                # Extract row and column slices for this mask
                row_slice, col_slice = slc
                # Add a 1-pixel border around the patch
                patch_height = (row_slice.stop - row_slice.start) + 2
                patch_width = (col_slice.stop - col_slice.start) + 2

                # Get local coordinates of mask pixels within the patch
                local_rows, local_cols = np.nonzero(
                    mask[channel][row_slice, col_slice] == (label_idx + 1)
                )
                # Shift coords by +1 for the border padding
                local_rows = local_rows.astype(np.int32) + 1
                local_cols = local_cols.astype(np.int32) + 1

                # Compute the centroid and pick the nearest pixel as the diffusion seed
                centroid_row = local_rows.mean()
                centroid_col = local_cols.mean()
                distances = (local_cols - centroid_col) ** 2 + (local_rows - centroid_row) ** 2
                seed_index = distances.argmin()
                center_row = int(local_rows[seed_index])
                center_col = int(local_cols[seed_index])

                # Determine the number of iterations
                total_iter = 2 * (patch_height + patch_width)

                # Initialize a flat diffusion map for the local patch
                diffusion_map = np.zeros(patch_height * patch_width, dtype=np.float64)
                # Run diffusion from the seed center
                diffusion_map = self.__diffuse_from_center(
                    diffusion_map,
                    local_rows,
                    local_cols,
                    center_row,
                    center_col,
                    patch_width,
                    total_iter
                )

                # Compute flow as finite differences (gradient) on the diffusion map
                dy = (
                    diffusion_map[(local_rows + 1) * patch_width + local_cols] -
                    diffusion_map[(local_rows - 1) * patch_width + local_cols]
                )
                dx = (
                    diffusion_map[local_rows * patch_width + (local_cols + 1)] -
                    diffusion_map[local_rows * patch_width + (local_cols - 1)]
                )

                # Write flows back into the global flow_field array
                flow_field[0,
                           row_slice.start + local_rows - 1,
                           col_slice.start + local_cols - 1] = dy
                flow_field[1,
                           row_slice.start + local_rows - 1,
                           col_slice.start + local_cols - 1] = dx

            # Normalize each [dy, dx] vector by its magnitude
            magnitudes = np.sqrt((flow_field**2).sum(axis=0)) + 1e-60
            flow_field /= magnitudes

            flows[2 * channel: 2 * channel + 2] = flow_field

        return flows

    @staticmethod
    @njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True)
    def __diffuse_from_center(
        diffusion_map: np.ndarray,
        row_coords: np.ndarray,
        col_coords: np.ndarray,
        center_row: int,
        center_col: int,
        patch_width: int,
        num_iterations: int
    ) -> np.ndarray:
        """
        Perform diffusion of particles from a seed pixel across a local mask patch.

        At each iteration, one particle is added at the seed, and each mask pixel's
        value is updated to the average of itself and its 8-connected neighbors.

        Args:
            diffusion_map (np.ndarray): Flat array of length patch_height * patch_width.
            row_coords (np.ndarray): 1D array of row indices for mask pixels (local coords).
            col_coords (np.ndarray): 1D array of column indices for mask pixels (local coords).
            center_row (int): Row index of the seed point in local patch coords.
            center_col (int): Column index of the seed point in local patch coords.
            patch_width (int): Width (number of columns) of the local patch.
            num_iterations (int): Number of diffusion iterations to perform.

        Returns:
            np.ndarray: Updated diffusion_map after performing diffusion.
        """
        # Compute linear indices for each mask pixel and its neighbors
        base_idx = row_coords * patch_width + col_coords
        up = (row_coords - 1) * patch_width + col_coords
        down = (row_coords + 1) * patch_width + col_coords
        left = row_coords * patch_width + (col_coords - 1)
        right = row_coords * patch_width + (col_coords + 1)
        up_left = (row_coords - 1) * patch_width + (col_coords - 1)
        up_right = (row_coords - 1) * patch_width + (col_coords + 1)
        down_left = (row_coords + 1) * patch_width + (col_coords - 1)
        down_right = (row_coords + 1) * patch_width + (col_coords + 1)

        for _ in range(num_iterations):
            # Inject one particle at the seed location
            diffusion_map[center_row * patch_width + center_col] += 1.0

            # Update each mask pixel as the average over itself and its neighbors
            diffusion_map[base_idx] = (
                diffusion_map[base_idx] +
                diffusion_map[up] + diffusion_map[down] +
                diffusion_map[left] + diffusion_map[right] +
                diffusion_map[up_left] + diffusion_map[up_right] +
                diffusion_map[down_left] + diffusion_map[down_right]
            ) * (1.0 / 9.0)

        return diffusion_map

    def __segment_instances(
        self,
        probability_map: np.ndarray,
        flow: np.ndarray,
        prob_threshold: float = 0.0,
        flow_threshold: float = 0.4,
        num_iters: int = 200,
        min_object_size: int = 0
    ) -> np.ndarray:
        """
        Generate instance segmentation masks from probability and flow fields.

        Args:
            probability_map: 3D array (channels, height, width) of cell probabilities.
            flow: 3D array (2*channels, height, width) of forward flow vectors.
            prob_threshold: Threshold to binarize probability_map. (Default 0.0)
            flow_threshold: Threshold for filtering bad flow masks. (Default 0.4)
            num_iters: Number of iterations for flow-following. (Default 200)
            min_object_size: Minimum area (in pixels) an instance must have to be kept. (Default 0)

        Returns:
            3D array of uint16 instance labels for each channel.
        """
        # Create a binary mask of likely cell locations
        probability_mask = probability_map > prob_threshold

        # If no cells exceed the threshold, return an empty mask
        if not np.any(probability_mask):
            logger.warning("No cell pixels found.")
            return np.zeros_like(probability_map, dtype=np.uint16)

        # Prepare the output array for instance labels
        labeled_instances = np.zeros_like(probability_map, dtype=np.uint16)

        # Process each channel independently
        for channel_index in range(probability_mask.shape[0]):
            # Extract flow vectors for this channel (two components per channel)
            channel_flow_vectors = flow[2 * channel_index: 2 * channel_index + 2]
            # Extract the binary mask for this channel
            channel_mask = probability_mask[channel_index]

            nonzero_coords = np.stack(np.nonzero(channel_mask))

            # Follow the flow vectors to generate coordinate mappings
            flow_coordinates = self.__follow_flows(
                flow_field=channel_flow_vectors * channel_mask / 5.0,
                initial_coords=nonzero_coords,
                num_iters=num_iters
            )
            # If flow following fails, leave this channel empty
            if flow_coordinates is None:
                labeled_instances[channel_index] = np.zeros(
                    probability_map.shape[1:], dtype=np.uint16
                )
                continue

            if not torch.is_tensor(flow_coordinates):
                flow_coordinates = torch.from_numpy(
                    flow_coordinates).to(self._device, dtype=torch.int32)
            else:
                flow_coordinates = flow_coordinates.int()

            # Obtain preliminary instance masks by clustering the coordinates
            channel_instances_mask = self.__get_mask(
                pixel_positions=flow_coordinates,
                valid_indices=nonzero_coords,
                original_shape=probability_map.shape[1:]
            )

            # Filter out bad flow-derived instances if requested
            if channel_instances_mask.max() > 0 and flow_threshold > 0:
                channel_instances_mask = self.__remove_inconsistent_flow_masks(
                    mask=channel_instances_mask,
                    flow_network=channel_flow_vectors,
                    error_threshold=flow_threshold
                )

                # Fill holes and remove objects below the minimum size
                if min_object_size > 0:
                    channel_instances_mask = self.__fill_holes_and_prune_small_masks(
                        channel_instances_mask, minimum_size=min_object_size
                    )

                labeled_instances[channel_index] = channel_instances_mask
            else:
                # No valid instances found; leave the channel empty
                labeled_instances[channel_index] = np.zeros(
                    probability_map.shape[1:], dtype=np.uint16
                )

        return labeled_instances
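
    # Usage sketch (hedged; the shapes are assumptions about the network
    # output): given one probability channel and its two flow channels for a
    # 256x256 image,
    #   labels = self.__segment_instances(
    #       probability_map=prob,   # (1, 256, 256) float array
    #       flow=flows,             # (2, 256, 256) float array
    #       prob_threshold=0.0,
    #       flow_threshold=0.4,
    #       min_object_size=15,
    #   )
    # returns a (1, 256, 256) uint16 array with labels 1..M, one per instance.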

    def __follow_flows(
        self,
        flow_field: np.ndarray,
        initial_coords: np.ndarray,
        num_iters: int = 200
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Trace pixel positions through a flow field via iterative interpolation.

        Args:
            flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors.
            initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions.
            num_iters (int): Number of integration steps.

        Returns:
            np.ndarray or torch.Tensor: Final (y, x) positions of each point, shape (2, num_points).
        """
        dims = 2
        # Extract spatial dimensions
        height, width = flow_field.shape[1:]

        # Choose the GPU/MPS path if available
        if self._device.type in ("cuda", "mps"):
            # Point tensor: shape [1, 1, num_points, 2]
            pts = torch.zeros((1, 1, initial_coords.shape[1], dims),
                              dtype=torch.float32, device=self._device)
            # Flow volume: shape [1, 2, height, width]
            flow_vol = torch.zeros((1, dims, height, width),
                                   dtype=torch.float32, device=self._device)

            # Load initial positions and flow into tensors
            # (axes flipped for grid_sample: dim 0 = x, dim 1 = y)
            for i in range(dims):
                pts[0, 0, :, dims - i - 1] = (
                    torch.from_numpy(initial_coords[i])
                    .to(self._device, torch.float32)
                )
                flow_vol[0, dims - i - 1] = (
                    torch.from_numpy(flow_field[i])
                    .to(self._device, torch.float32)
                )

            # Normalization factors for x and y (maximum valid index)
            max_indices = torch.tensor([width - 1, height - 1],
                                       dtype=torch.float32, device=self._device)
            # Reshape for broadcasting to the point tensor dims
            max_idx_pt = max_indices.view(1, 1, 1, dims)
            # Reshape for broadcasting to the flow volume dims
            max_idx_flow = max_indices.view(1, dims, 1, 1)

            # Normalize flow values to the [-1, 1] range
            flow_vol = (flow_vol * 2) / max_idx_flow
            # Normalize points to [-1, 1]
            pts = (pts / max_idx_pt) * 2 - 1

            # Iterate: sample the flow and update the points
            for _ in range(num_iters):
                sampled = torch.nn.functional.grid_sample(
                    flow_vol, pts, align_corners=False
                )
                # Update each coordinate and clamp to the valid range
                for i in range(dims):
                    pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0)

            # Denormalize back to original pixel coordinates
            pts = (pts + 1) * 0.5 * max_idx_pt
            # Swap channels back to (y, x) and flatten
            final_pts = pts[..., [1, 0]].squeeze()
            # Convert from (num_points, 2) to (2, num_points)
            return final_pts.T if final_pts.ndim > 1 else final_pts.unsqueeze(0).T

        # CPU fallback using numpy and the numba interpolation kernel
        current_pos = initial_coords.copy().astype(np.float32)
        temp_delta = np.zeros_like(current_pos, dtype=np.float32)

        for _ in range(num_iters):
            # Interpolate the flow at the current positions
            self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta)
            # Update positions and clamp to the image bounds
            current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1)
            current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1)

        return current_pos
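
    # Normalization sketch: the points and flow are rescaled with the factor
    # 2 / (size - 1), so a pixel column x on a width-W image maps to roughly
    # 2 * x / (W - 1) - 1 in grid_sample's normalized [-1, 1] space; each
    # integration step then adds the sampled local flow and clamps the result
    # back into the valid square.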

    @staticmethod
    @njit([
        "(int16[:,:,:], float32[:], float32[:], float32[:,:])",
        "(float32[:,:,:], float32[:], float32[:], float32[:,:])"
    ], cache=True)
    def __map_coordinates(
        image_data: np.ndarray,
        y_coords: np.ndarray,
        x_coords: np.ndarray,
        output: np.ndarray
    ) -> None:
        """
        Perform in-place bilinear interpolation on an image volume.

        Args:
            image_data (np.ndarray): Input volume with shape (C, H, W).
            y_coords (np.ndarray): Array of new y positions (num_points).
            x_coords (np.ndarray): Array of new x positions (num_points).
            output (np.ndarray): Output array of shape (C, num_points) to fill.

        Returns:
            None. Results are written directly into `output`.
        """
        channels, height, width = image_data.shape
        # Compute integer (floor) and fractional parts of the coordinates
        y_floor = y_coords.astype(np.int32)
        x_floor = x_coords.astype(np.int32)
        y_frac = y_coords - y_floor
        x_frac = x_coords - x_floor

        # Loop over each sample point
        for idx in range(y_floor.shape[0]):
            # Clamp base indices to the valid range
            y0 = min(max(y_floor[idx], 0), height - 1)
            x0 = min(max(x_floor[idx], 0), width - 1)
            y1 = min(y0 + 1, height - 1)
            x1 = min(x0 + 1, width - 1)

            wy = y_frac[idx]
            wx = x_frac[idx]

            # Interpolate per channel
            for c in range(channels):
                v00 = np.float32(image_data[c, y0, x0])
                v10 = np.float32(image_data[c, y0, x1])
                v01 = np.float32(image_data[c, y1, x0])
                v11 = np.float32(image_data[c, y1, x1])
                # Bilinear interpolation formula
                output[c, idx] = (
                    v00 * (1 - wy) * (1 - wx) +
                    v10 * (1 - wy) * wx +
                    v01 * wy * (1 - wx) +
                    v11 * wy * wx
                )
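
    # Worked example (hedged): with corner values v00=0, v10=1, v01=2, v11=3
    # and fractional offsets wy = wx = 0.5, the interpolated value is
    # 0*0.25 + 1*0.25 + 2*0.25 + 3*0.25 = 1.5, the average of the four corners.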

    def __get_mask(
        self,
        pixel_positions: torch.Tensor,
        valid_indices: np.ndarray,
        original_shape: Tuple[int, ...],
        pad_radius: int = 20,
        max_size_fraction: float = 0.4
    ) -> np.ndarray:
        """
        Generate labeled masks by clustering pixel trajectories via histogram peaks and region growing.

        This function executes the following steps:
          1. Pads and clamps final pixel positions to avoid border effects.
          2. Builds a dense histogram of pixel counts over spatial bins.
          3. Identifies local maxima in the histogram as seed points.
          4. Extracts local patches around each seed and grows regions by iteratively
             adding neighbors that exceed an intensity threshold.
          5. Maps grown patches back to original image indices.
          6. Removes any masks that exceed a maximum size fraction of the image.

        Args:
            pixel_positions (torch.Tensor): Integer tensor of shape `[2, N_pixels]` containing
                final pixel coordinates after the flow dynamics, one row per dimension.
            valid_indices (np.ndarray): Integer array of shape `[2, N_pixels]`
                giving the indices of those pixels in the original image grid.
            original_shape (tuple of ints): Spatial dimensions of the original image, e.g. (H, W).
            pad_radius (int): Number of zero-padding pixels added on each side of the histogram.
                Defaults to 20.
            max_size_fraction (float): Masks with a pixel count above
                max_size_fraction * total_pixels are removed. Defaults to 0.4.

        Returns:
            np.ndarray: Integer mask array of shape `original_shape` with labels 0 (background) and 1..M.

        Raises:
            ValueError: If input dimensions are inconsistent or pixel_positions shape is invalid.
        """
        # Validate inputs
        ndim = len(original_shape)
        if pixel_positions.ndim != 2 or pixel_positions.size(0) != ndim:
            msg = f"pixel_positions must be shape [{ndim}, N], got {tuple(pixel_positions.shape)}"
            logger.error(msg)
            raise ValueError(msg)
        if pad_radius < 0:
            msg = f"pad_radius must be non-negative, got {pad_radius}"
            logger.error(msg)
            raise ValueError(msg)

        # Step 1: Pad and clamp pixel positions
        padded_positions = pixel_positions.clone().to(torch.int64) + pad_radius
        for dim in range(ndim):
            max_val = original_shape[dim] + pad_radius - 1
            padded_positions[dim] = torch.clamp(padded_positions[dim], min=0, max=max_val)

        # Histogram dimensions
        hist_shape = tuple(s + 2 * pad_radius for s in original_shape)

        # Step 2: Create a sparse tensor and densify it to get per-bin counts
        # (float values so that max_pool2d, which has no integer kernels, can run on it)
        try:
            counts_sparse = torch.sparse_coo_tensor(
                padded_positions,
                torch.ones(padded_positions.shape[1], dtype=torch.float32, device=pixel_positions.device),
                size=hist_shape
            )
            histogram = counts_sparse.to_dense()
        except Exception as e:
            logger.error("Failed to build dense histogram: %s", e)
            raise

        # Step 3: Find peaks via 5x5 max-pooling
        k = 5
        pooled = F.max_pool2d(
            histogram.unsqueeze(0),
            kernel_size=k,
            stride=1,
            padding=k // 2
        ).squeeze()
        # Seeds are positions where the histogram equals the local max and the count > threshold
        seed_positions = torch.nonzero((histogram - pooled == 0) & (histogram > 10))
        if seed_positions.numel() == 0:
            logger.warning("No seeds found: returning empty mask")
            return np.zeros(original_shape, dtype=np.uint16)

        # Sort seeds by ascending count to process small peaks first
        seed_counts = histogram[tuple(seed_positions.T)]
        order = torch.argsort(seed_counts)
        seed_positions = seed_positions[order]
        del pooled, counts_sparse

        # Step 4: Extract local patches and perform region growing
        num_seeds = seed_positions.shape[0]
        # Tensor holding an 11x11 patch around each seed
        patches = torch.zeros((num_seeds, 11, 11), device=pixel_positions.device)
        for idx in range(num_seeds):
            coords = seed_positions[idx]
            slices = tuple(slice(c - 5, c + 6) for c in coords)
            patches[idx] = histogram[slices]
        del histogram

        # Initialize seed masks (center pixel of each patch)
        seed_masks = torch.zeros_like(patches)
        seed_masks[:, 5, 5] = 1
        # Iterative dilation and thresholding
        # (multiplication, not bitwise &, keeps the masks float for max_pool2d)
        for _ in range(5):
            seed_masks = F.max_pool2d(
                seed_masks,
                kernel_size=3,
                stride=1,
                padding=1
            )
            seed_masks = seed_masks * (patches > 2)
        # Compute final mask coordinates
        final_coords = []
        for idx in range(num_seeds):
            coords_local = torch.nonzero(seed_masks[idx])
            # Shift back to global positions
            coords_global = coords_local + seed_positions[idx] - 5
            final_coords.append(tuple(coords_global.T))

        # Step 5: Paint the masks into a padded volume
        dtype = torch.int32 if num_seeds < 2**16 else torch.int64
        mask_padded = torch.zeros(hist_shape, dtype=dtype, device=pixel_positions.device)
        for label_idx, coords in enumerate(final_coords, start=1):
            mask_padded[coords] = label_idx

        # Extract only the padded positions that correspond to original pixels
        mask_values = mask_padded[tuple(padded_positions)]
        mask_values = mask_values.cpu().numpy()

        # Step 6: Map back to the original image and remove oversized masks
        mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32)
        mask_final[tuple(valid_indices)] = mask_values

        # Prune masks that are too large
        labels, counts = fastremap.unique(mask_final, return_counts=True)
        total_pixels = np.prod(original_shape)
        oversized = labels[counts > (total_pixels * max_size_fraction)]
        if oversized.size > 0:
            mask_final = fastremap.mask(mask_final, oversized)
            fastremap.renumber(mask_final, in_place=True)

        return mask_final

    def __remove_inconsistent_flow_masks(
        self,
        mask: np.ndarray,
        flow_network: np.ndarray,
        error_threshold: float = 0.4
    ) -> np.ndarray:
        """
        Remove labeled masks whose mask-derived flows are inconsistent with the network-predicted flows.

        This is a quality-control step: flows are recomputed from the provided masks
        and compared to the flows predicted by the network. Masks with a mean squared
        flow error above `error_threshold` are discarded (set to 0).

        Args:
            mask (np.ndarray): Integer mask array with shape [H, W].
                Values: 0 = no mask; 1, 2, ... = mask labels.
            flow_network (np.ndarray): Float array of network-predicted flows with shape [2, H, W].
            error_threshold (float): Maximum allowed mean squared flow error per mask label.
                Defaults to 0.4.

        Returns:
            np.ndarray: The input mask with inconsistent labels removed (set to 0).

        Raises:
            MemoryError: If the mask size exceeds available GPU memory.
        """
        # For very large masks on CUDA, check that enough GPU memory is available
        num_pixels = mask.size
        if num_pixels > 10000 * 10000 and self._device.type == 'cuda':
            # Clear unused GPU cache
            torch.cuda.empty_cache()
            # Determine the PyTorch version
            major, minor = map(int, torch.__version__.split('.')[:2])
            # Determine the current CUDA device index
            device_index = (
                self._device.index
                if hasattr(self._device, 'index')
                else torch.cuda.current_device()
            )
            # Get free and total memory
            if major == 1 and minor < 10:
                total_mem = torch.cuda.get_device_properties(device_index).total_memory
                used_mem = torch.cuda.memory_allocated(device_index)
                free_mem = total_mem - used_mem
            else:
                free_mem, total_mem = torch.cuda.mem_get_info(device_index)
            # Estimate the memory required for the mask-based flow computation
            # (assume float32 per pixel)
            required_bytes = num_pixels * np.dtype(np.float32).itemsize
            if required_bytes > free_mem:
                logger.error(
                    'Image too large for GPU memory in flow QC step (required: %d B, available: %d B)',
                    required_bytes, free_mem
                )
                raise MemoryError('Insufficient GPU memory for flow QC computation')

        # Compute flow errors per mask label
        flow_errors, _ = self.__compute_flow_error(mask, flow_network)

        # Identify labels with an error above the threshold
        bad_labels = np.nonzero(flow_errors > error_threshold)[0] + 1

        # Remove bad masks by setting their labels to 0
        mask[np.isin(mask, bad_labels)] = 0
        return mask

    def __compute_flow_error(
        self,
        mask: np.ndarray,
        flow_network: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute the mean squared error between network-predicted flows and flows derived from masks.

        Args:
            mask (np.ndarray): Integer masks; the shape must match the spatial dims of flow_network.
            flow_network (np.ndarray): Network-predicted flows of shape [axis, H, W].

        Returns:
            Tuple[np.ndarray, np.ndarray]:
                - flow_errors: 1D array (length = max label) of mean squared error per label.
                - computed_flows: Flows derived from the mask, same shape as flow_network.

        Raises:
            ValueError: If the spatial dimensions of `mask` and `flow_network` do not match.
        """
        # Ensure the mask and flow shapes match
        if flow_network.shape[1:] != mask.shape:
            logger.error(
                'Shape mismatch: network flow shape %s vs mask shape %s',
                flow_network.shape[1:], mask.shape
            )
            raise ValueError('Network flow and mask shapes must match')

        # Compute flows from the mask labels
        computed_flows = self.__compute_flow_from_mask(mask[None, ...])

        # Prepare an array of errors (one value per mask label)
        num_labels = int(mask.max())
        flow_errors = np.zeros(num_labels, dtype=float)

        # Accumulate the mean squared error over each flow axis
        for axis_index in range(computed_flows.shape[0]):
            # MSE per label: mean((computed - predicted / 5)^2)
            flow_errors += mean(
                (computed_flows[axis_index] - flow_network[axis_index] / 5.0) ** 2,
                mask,
                index=np.arange(1, num_labels + 1)
            )

        return flow_errors, computed_flows

    def __fill_holes_and_prune_small_masks(
        self,
        masks: np.ndarray,
        minimum_size: int = 15
    ) -> np.ndarray:
        """
        Fill holes in labeled masks and remove masks smaller than a given size.

        This function performs two steps:
          1. Fills internal holes in each labeled mask using `fill_voids.fill`.
          2. Discards any mask whose pixel count is below `minimum_size`.

        Args:
            masks (np.ndarray): Integer mask array of dimension 2 or 3 (shape [H, W] or [D, H, W]).
                Values: 0 = background; 1, 2, ... = mask labels.
            minimum_size (int): Minimum number of pixels required to keep a mask.
                Masks smaller than this are removed. Set to -1 to skip size-based
                pruning. Defaults to 15.

        Returns:
            np.ndarray: Processed mask array with holes filled and small masks removed.

        Raises:
            ValueError: If `masks` is not a 2D or 3D integer array.
        """
        # Validate input dimensions
        if masks.ndim not in (2, 3):
            msg = f"Expected 2D or 3D mask array, got {masks.ndim}D."
            logger.error(msg)
            raise ValueError(msg)

        # Optionally remove masks smaller than minimum_size before filling
        if minimum_size >= 0:
            # Compute label counts (the background label 0 is counted too,
            # but is far too large to ever fall below minimum_size)
            labels, counts = fastremap.unique(masks, return_counts=True)
            # Identify labels to remove: those with count < minimum_size
            small_labels = labels[counts < minimum_size]
            if small_labels.size > 0:
                masks = fastremap.mask(masks, small_labels)
                fastremap.renumber(masks, in_place=True)

        # Find bounding boxes for each mask label
        object_slices = find_objects(masks)
        new_label = 1
        output_masks = np.zeros_like(masks, dtype=masks.dtype)

        # Loop over each original slice, fill holes, and assign sequential labels
        for original_label, slc in enumerate(object_slices, start=1):
            if slc is None:
                continue
            # Extract the sub-volume or sub-image for this label
            region = masks[slc] == original_label
            if not np.any(region):
                continue
            # Fill internal holes
            filled_region = fill_voids.fill(region)
            # Write back into the output mask with a sequential label
            output_masks[slc][filled_region] = new_label
            new_label += 1

        # Final pruning of small masks after filling (optional)
        if minimum_size >= 0:
            labels, counts = fastremap.unique(output_masks, return_counts=True)
            small_labels = labels[counts < minimum_size]
            if small_labels.size > 0:
                output_masks = fastremap.mask(output_masks, small_labels)
                fastremap.renumber(output_masks, in_place=True)

        return output_masks
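
# Inference pipeline sketch (hedged): once predict() is implemented, the
# expected per-image flow is presumably
#   probability map + flow field (network output)
#     -> __follow_flows (pixels drift toward instance centers)
#     -> __get_mask (histogram peaks + region growing -> labels)
#     -> __remove_inconsistent_flow_masks / __fill_holes_and_prune_small_masks,
# which is exactly what __segment_instances orchestrates above.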