You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

2320 lines
96 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import random
import numpy as np
from numba import njit, prange
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
import fastremap
import fill_voids
from skimage import morphology
from skimage.segmentation import find_boundaries
from scipy.ndimage import mean, find_objects
from monai.data.dataset import Dataset
from monai.transforms import * # type: ignore
from monai.inferers.utils import sliding_window_inference
from monai.metrics.cumulative_average import CumulativeAverage
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os
import glob
import copy
import tifffile as tiff
from pprint import pformat
from tabulate import tabulate
from typing import Any, Dict, Literal, Optional, Tuple, List, Union
from tqdm import tqdm
import wandb
from config import Config
from core.models import *
from core.losses import *
from core.optimizers import *
from core.schedulers import *
from core.utils import (
compute_batch_segmentation_tp_fp_fn,
compute_f1_score,
compute_average_precision_score
)
from core.logger import get_logger
# Module-wide logger obtained from the project's logging helper (core.logger.get_logger).
logger = get_logger()
class CellSegmentator:
def __init__(self, config: Config) -> None:
    """
    Build a segmentator from a parsed configuration.

    Seeds the RNGs, resolves the compute device, instantiates model,
    criterion, optimizer and scheduler (via ``__parse_config``), moves the
    model to the device and creates an AMP gradient scaler when training with
    mixed precision. Dataloaders start unset; build them with
    ``create_dataloaders``.

    Args:
        config (Config): Full configuration (model, criterion, optimizer,
            scheduler and dataset setup).
    """
    self.__set_seed(config.dataset_config.common.seed)
    # Resolve the device BEFORE parsing the config: __parse_config may call
    # load_from_checkpoint, which passes self._device as map_location.
    self._device: torch.device = torch.device(config.dataset_config.common.device or "cpu")
    self.__parse_config(config)
    # The train/eval/predict loops move every batch to self._device, so the
    # model must live there too. Module.to() moves parameters in place, so the
    # optimizer created in __parse_config keeps referencing the same tensors.
    self._model.to(self._device)
    # Gradient scaling is only needed when training with mixed precision.
    self._scaler = (
        torch.amp.GradScaler(self._device.type)  # type: ignore
        if self._dataset_setup.is_training and self._dataset_setup.common.use_amp
        else None
    )
    # Dataloaders are created later by create_dataloaders().
    self._train_dataloader: Optional[DataLoader] = None
    self._valid_dataloader: Optional[DataLoader] = None
    self._test_dataloader: Optional[DataLoader] = None
    self._predict_dataloader: Optional[DataLoader] = None
def create_dataloaders(
    self,
    train_transforms: Optional[Compose] = None,
    valid_transforms: Optional[Compose] = None,
    test_transforms: Optional[Compose] = None,
    predict_transforms: Optional[Compose] = None
) -> None:
    """
    Create the train, validation, test and prediction dataloaders.

    Training mode reads either pre-split directories
    (``training.pre_split.{train,valid,test}_dir``) or a single directory
    (``training.split.all_data_dir``) that is cut into consecutive
    train/valid/test slices using sizes and offsets. In the single-directory
    case the file list is shuffled before slicing when
    ``training.split.shuffle`` is set. Inference mode reads
    ``testing.test_dir``. Each directory is expected to contain an ``images``
    and (except for prediction) a ``masks`` subdirectory. A loader is only
    built for the stages whose transforms are provided.

    Args:
        train_transforms (Optional[Compose]): Transforms for training data
            (required in training mode).
        valid_transforms (Optional[Compose]): Transforms for validation data.
        test_transforms (Optional[Compose]): Transforms for test data.
        predict_transforms (Optional[Compose]): Transforms for prediction data
            (no masks are loaded).

    Raises:
        ValueError: If transforms required by the current mode/config are missing.
        FileNotFoundError: If the unified data directory contains no images.
        RuntimeError: If a requested stage has no directory or size configured.
    """
    setup = self._dataset_setup
    training_mode = setup.is_training

    if training_mode and train_transforms is None:
        raise ValueError("Training mode requires 'train_transforms' to be provided.")
    if not training_mode and test_transforms is None and predict_transforms is None:
        raise ValueError("In inference mode, at least one of 'test_transforms' or 'predict_transforms' must be provided.")

    def build(label: str, images_dir: str, masks_dir: Optional[str], transforms: Any,
              size: Any, offset: int, shuffle_files: bool,
              batch_size: int, shuffle_batches: bool) -> DataLoader:
        # Create the dataset, wrap it in a DataLoader and report its size.
        dataset = self.__get_dataset(
            images_dir=images_dir,
            masks_dir=masks_dir,
            transforms=transforms,
            size=size,
            offset=offset,
            shuffle=shuffle_files
        )
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle_batches)
        logger.info(f"Loaded {label} dataset with {len(dataset)} samples.")
        return loader

    if not training_mode:
        # Inference mode: test and prediction share the same image directory.
        testing = setup.testing
        images_path = os.path.join(testing.test_dir, 'images')
        masks_path = os.path.join(testing.test_dir, 'masks')
        if test_transforms is not None:
            self._test_dataloader = build(
                "test", images_path, masks_path, test_transforms,
                testing.test_size, testing.test_offset, testing.shuffle, 1, False
            )
        if predict_transforms is not None:
            self._predict_dataloader = build(
                "prediction", images_path, None, predict_transforms,
                testing.test_size, testing.test_offset, testing.shuffle, 1, False
            )
        return

    training = setup.training
    if training.is_split:
        # Explicit train/valid/test directories; no shuffling across splits.
        if (
            training.pre_split.valid_dir and
            training.valid_size and
            valid_transforms is None
        ):
            raise ValueError("Validation transforms must be provided when using pre-split validation data.")
        train_dir = training.pre_split.train_dir
        valid_dir = training.pre_split.valid_dir
        test_dir = training.pre_split.test_dir
        train_offset = training.train_offset
        valid_offset = training.valid_offset
        test_offset = training.test_offset
        shuffle_files = False
    else:
        # One directory split into consecutive slices: [train | valid | test].
        if (
            training.split.all_data_dir and
            training.valid_size and
            valid_transforms is None
        ):
            raise ValueError("Validation transforms must be provided when splitting dataset.")
        train_dir = valid_dir = test_dir = training.split.all_data_dir
        image_count = len(os.listdir(os.path.join(train_dir, 'images')))
        if image_count == 0:
            raise FileNotFoundError(f"No images found in '{train_dir}/images'")

        def absolute(size: Any) -> int:
            # Sizes may be absolute counts (int) or fractions of the dataset.
            return size if isinstance(size, int) else int(image_count * size)

        n_train = absolute(training.train_size)
        n_valid = absolute(training.valid_size)
        train_offset = training.train_offset
        valid_offset = training.valid_offset + n_train
        test_offset = training.test_offset + n_train + n_valid
        # Respect the configured flag; __parse_config logs this same value as
        # "Shuffle: yes/no", so hard-coding True contradicted the log.
        shuffle_files = bool(training.split.shuffle)

    self._train_dataloader = build(
        "training",
        os.path.join(train_dir, 'images'),
        os.path.join(train_dir, 'masks'),
        train_transforms,
        training.train_size,
        train_offset,
        shuffle_files,
        training.batch_size,
        True
    )

    if valid_transforms is not None:
        if not valid_dir or not training.valid_size:
            raise RuntimeError("Validation directory or size is not properly configured.")
        self._valid_dataloader = build(
            "validation",
            os.path.join(valid_dir, 'images'),
            os.path.join(valid_dir, 'masks'),
            valid_transforms,
            training.valid_size,
            valid_offset,
            shuffle_files,
            1,
            False
        )

    if test_transforms is not None:
        if not test_dir or not training.test_size:
            raise RuntimeError("Test directory or size is not properly configured.")
        self._test_dataloader = build(
            "test",
            os.path.join(test_dir, 'images'),
            os.path.join(test_dir, 'masks'),
            test_transforms,
            training.test_size,
            test_offset,
            shuffle_files,
            1,
            False
        )

    if predict_transforms is not None:
        # Prediction reuses the test split, without masks.
        if not test_dir or not training.test_size:
            raise RuntimeError("Prediction directory or size is not properly configured.")
        self._predict_dataloader = build(
            "prediction",
            os.path.join(test_dir, 'images'),
            None,
            predict_transforms,
            training.test_size,
            test_offset,
            shuffle_files,
            1,
            False
        )
def print_data_info(
    self,
    loader_type: Literal["train", "valid", "test", "predict"],
    index: Optional[int] = None
) -> None:
    """
    Log shape and value statistics for one sample of a dataloader's dataset.

    The sample is fetched through the dataset, so its transforms are applied.
    Problems (loader not built, empty dataset, index out of range, sample
    without an image) are reported via ``logger.error`` instead of raising.

    Args:
        loader_type: One of "train", "valid", "test", "predict".
        index: The sample index; if None, a random valid index is selected.
    """
    # Retrieve the dataloader attribute, e.g., self._train_dataloader
    loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None)
    if loader is None:
        logger.error(f"Dataloader '{loader_type}' is not initialized.")
        return

    dataset = loader.dataset
    total = len(dataset)  # type: ignore
    if total == 0:
        logger.error(f"Dataset for '{loader_type}' is empty.")
        return

    idx = index if index is not None else random.randint(0, total - 1)
    if not (0 <= idx < total):
        logger.error(f"Index {idx} is out of range [0, {total}).")
        return

    sample = dataset[idx]

    def lookup(*keys: str) -> Any:
        # __get_dataset builds samples keyed "image"/"mask", while the training
        # and prediction loops read "img"/"label" (presumably renamed by the
        # transforms); accept either convention.
        for key in keys:
            if key in sample:
                return sample[key]
        return None

    def describe(array: Any) -> str:
        # Summarise a tensor/array as "shape: ..., min: ..., max: ..., mean: ..., std: ...".
        if isinstance(array, torch.Tensor):
            array = array.detach().cpu().numpy()
        array = np.asarray(array)
        return (
            f"shape: {array.shape}, min: {float(array.min()):.4f}, "
            f"max: {float(array.max()):.4f}, mean: {float(array.mean()):.4f}, "
            f"std: {float(array.std()):.4f}"
        )

    image = lookup("image", "img")
    if image is None:
        logger.error(f"Sample {idx} of '{loader_type}' has no 'image' or 'img' entry.")
        return

    lines = [
        "=" * 40,
        f"Dataloader: '{loader_type}', sample index: {idx} / {total - 1}",
        f"Image — {describe(image)}"
    ]
    # For 'predict', no mask is available
    if loader_type != "predict":
        mask = lookup("mask", "label")
        if mask is not None:
            lines.append(f"Mask — {describe(mask)}")
        else:
            lines.append("Mask — not available for this sample.")
    lines.append("=" * 40)

    for line in lines:
        logger.info(line)
def train(self) -> None:
    """
    Train the model for ``training.num_epochs`` epochs, then test it.

    Each epoch runs a training pass (metrics printed and logged at
    ``step=epoch``) and then steps the scheduler. Every ``training.val_freq``
    epochs a validation pass runs if a validation loader exists, and the
    weights with the highest validation F1 are kept. After the loop those best
    weights are restored and, if a test loader exists, a test pass runs.

    Raises:
        RuntimeError: If the dataset setup is not in training mode.
    """
    if not self._dataset_setup.is_training:
        raise RuntimeError("Dataset setup indicates training is disabled.")

    # Only CUDA devices can be queried through torch.cuda; any other device
    # type (cpu, mps, ...) is reported by its type name.
    if self._device.type == "cuda":
        cuda_index = (
            self._device.index
            if self._device.index is not None
            else torch.cuda.current_device()
        )
        device_name = torch.cuda.get_device_name(cuda_index)
    else:
        device_name = self._device.type

    logger.info(f"\n{'=' * 50}\n")
    logger.info(f"Training starts on device: {device_name}")
    logger.info(f"\n{'=' * 50}")

    num_epochs = self._dataset_setup.training.num_epochs
    val_freq = self._dataset_setup.training.val_freq
    best_f1_score = 0.0
    best_weights = None

    for epoch in range(1, num_epochs + 1):
        train_metrics = self.__run_epoch("train", epoch)
        self.__print_with_logging(train_metrics, epoch)

        # Step the scheduler once per epoch, after training
        if self._scheduler is not None:
            self._scheduler.step()

        # Periodic validation; keep the weights with the best F1
        if epoch % val_freq == 0 and self._valid_dataloader is not None:
            valid_metrics = self.__run_epoch("valid", epoch)
            self.__print_with_logging(valid_metrics, epoch)
            f1 = valid_metrics.get("valid_f1_score", 0.0)
            if f1 > best_f1_score:
                best_f1_score = f1
                # Deep copy so later updates do not mutate the snapshot
                best_weights = copy.deepcopy(self._model.state_dict())
                logger.info(f"Updated best model weights with F1 score: {f1:.4f}")

    # Restore best model weights if available
    if best_weights is not None:
        self._model.load_state_dict(best_weights)

    if self._test_dataloader is not None:
        test_metrics = self.__run_epoch("test")
        # W&B ignores data logged at a step lower than one already logged, so
        # the test metrics go to the last training step instead of step 0.
        self.__print_with_logging(test_metrics, num_epochs)
def evaluate(self) -> None:
    """
    Evaluate the model on the test dataloader and report the metrics.

    The metrics of a single test epoch are printed as a table and sent to
    W&B at step 0.
    """
    metrics = self.__run_epoch("test")
    self.__print_with_logging(metrics, step=0)
def predict(self) -> None:
    """
    Run inference on the prediction dataloader and save the instance masks.

    The model is switched to eval mode. Each batch goes through
    ``__run_inference`` in "predict" mode under ``torch.no_grad`` (and autocast
    when AMP is enabled), is post-processed into instance masks, and is saved
    with ``__save_prediction_masks``. A running sample count is used as the
    file index.

    Raises:
        RuntimeError: If the prediction dataloader has not been created.
    """
    if self._predict_dataloader is None:
        raise RuntimeError("DataLoader for mode 'predict' is not set.")

    # nn.Module starts in training mode; without eval() dropout / batch-norm
    # layers would behave as in training and distort the predictions.
    self._model.eval()

    sample_offset = 0
    for batch in tqdm(self._predict_dataloader, desc="Predicting"):
        # Move input images to the configured device (CPU/GPU)
        inputs = batch["img"].to(self._device)
        with torch.amp.autocast(  # type: ignore
            self._device.type,
            enabled=self._dataset_setup.common.use_amp
        ):
            with torch.no_grad():
                raw_output = self.__run_inference(inputs, mode="predict")
                # Only predictions are needed; no ground truth is passed
                preds, _ = self.__post_process_predictions(raw_output)
        # Save masks; the running sample count keeps file names unique
        self.__save_prediction_masks(batch, preds, sample_offset)
        sample_offset += inputs.shape[0]
def run(self) -> None:
    """
    Execute the workflow implied by the configuration.

    Training mode runs ``train()``. Otherwise ``evaluate()`` is preferred when
    a test dataloader exists, and ``predict()`` is used when only a prediction
    dataloader exists.

    Raises:
        RuntimeError: In non-training mode when neither a test nor a predict
            dataloader has been created.
    """
    if self._dataset_setup.is_training:
        action = self.train
    elif self._test_dataloader is not None:
        action = self.evaluate
    elif self._predict_dataloader is not None:
        action = self.predict
    else:
        raise RuntimeError(
            "Neither test nor predict DataLoader is set for nontraining mode."
        )
    action()
def load_from_checkpoint(self, checkpoint_path: str) -> None:
    """
    Load model weights from a checkpoint file into the current model.

    Accepts a dict holding the weights under ``'state_dict'`` (the layout
    written by ``save_checkpoint``) or a bare state dict. Loading is
    non-strict, but missing and unexpected keys are reported as warnings
    instead of being ignored silently.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=True)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    incompatible = self._model.load_state_dict(state_dict, strict=False)

    # strict=False skips mismatched keys; surface them so a wrong or only
    # partially matching checkpoint does not go unnoticed.
    missing = list(getattr(incompatible, "missing_keys", []) or [])
    unexpected = list(getattr(incompatible, "unexpected_keys", []) or [])
    if missing:
        logger.warning(
            f"Checkpoint '{checkpoint_path}': {len(missing)} model key(s) not found "
            f"in checkpoint (first ones: {missing[:10]})"
        )
    if unexpected:
        logger.warning(
            f"Checkpoint '{checkpoint_path}': {len(unexpected)} unexpected key(s) ignored "
            f"(first ones: {unexpected[:10]})"
        )
    logger.info(f"Loaded model weights from '{checkpoint_path}'")
def save_checkpoint(self, checkpoint_path: str) -> None:
    """
    Write the current model weights to ``checkpoint_path``.

    The file holds a dict with a single ``'state_dict'`` entry, which is the
    layout that ``load_from_checkpoint`` reads back.

    Args:
        checkpoint_path (str): Destination file path.
    """
    torch.save({'state_dict': self._model.state_dict()}, checkpoint_path)
def __parse_config(self, config: Config) -> None:
    """
    Instantiate the components described by ``config`` and log a summary.

    Builds the model from ``ModelRegistry``. In training mode it also loads
    ``training.pretrained_weights`` when set. It then builds the optional
    criterion, optimizer and scheduler from their registries and checks that
    model and criterion agree on ``num_classes``. Finally it stores
    ``config.dataset_config`` as ``self._dataset_setup`` and logs the setup.

    Args:
        config (Config): Configuration with model, criterion, optimizer,
            scheduler and dataset sections.

    Raises:
        ValueError: If the criterion declares a ``num_classes`` different from
            the model's.
    """
    model_cfg = config.model
    criterion_cfg = config.criterion
    optimizer_cfg = config.optimizer
    scheduler_cfg = config.scheduler
    dataset_cfg = config.dataset_config

    # Dump the raw configuration sections
    logger.info("========== Parsed Configuration ==========")
    logger.info("Model Config:\n%s", pformat(model_cfg.dump(), indent=2))
    optional_sections = (
        ("Criterion", criterion_cfg),
        ("Optimizer", optimizer_cfg),
        ("Scheduler", scheduler_cfg),
    )
    for title, section in optional_sections:
        if section:
            logger.info(title + " Config:\n%s", pformat(section.dump(), indent=2))
    logger.info("Dataset Config:\n%s", pformat(dataset_cfg.model_dump(), indent=2))
    logger.info("==========================================")

    # Model, plus pretrained weights in training mode.
    # NOTE(review): testing.pretrained_weights is only logged below, never
    # loaded here; confirm the caller loads it for inference runs.
    model_class = ModelRegistry.get_model_class(model_cfg.name)
    self._model = model_class(model_cfg.params)
    if dataset_cfg.is_training and dataset_cfg.training.pretrained_weights:
        self.load_from_checkpoint(dataset_cfg.training.pretrained_weights)

    # Optional loss criterion
    if criterion_cfg is not None:
        criterion_class = CriterionRegistry.get_criterion_class(criterion_cfg.name)
        self._criterion = criterion_class(params=criterion_cfg.params)
    else:
        self._criterion = None

    # The model and criterion must agree on the number of classes
    if hasattr(self._criterion, "num_classes"):
        model_classes = self._model.num_classes
        criterion_classes = getattr(self._criterion, "num_classes")
        if model_classes != criterion_classes:
            raise ValueError(
                f"Number of classes mismatch: model.num_classes={model_classes} "
                f"but criterion.num_classes={criterion_classes}"
            )

    # Optional optimizer over the model parameters
    if optimizer_cfg is not None:
        optimizer_class = OptimizerRegistry.get_optimizer_class(optimizer_cfg.name)
        self._optimizer = optimizer_class(
            model_params=self._model.parameters(),
            optim_params=optimizer_cfg.params
        )
    else:
        self._optimizer = None

    # A scheduler needs both its config and an underlying torch optimizer
    has_torch_optimizer = self._optimizer is not None and self._optimizer.optim is not None
    if scheduler_cfg is not None and has_torch_optimizer:
        scheduler_class = SchedulerRegistry.get_scheduler_class(scheduler_cfg.name)
        self._scheduler = scheduler_class(
            optimizer=self._optimizer.optim,
            params=scheduler_cfg.params
        )
    else:
        self._scheduler = None

    def describe(section: Any, component: Any) -> str:
        # Component name when it was built, otherwise a placeholder.
        return f"{section.name}" if component else "Not specified"

    logger.info("========== Model Components Initialization ==========")
    logger.info("├─ Model: " + describe(model_cfg, self._model))
    logger.info("├─ Criterion: " + describe(criterion_cfg, self._criterion))
    logger.info("├─ Optimizer: " + describe(optimizer_cfg, self._optimizer))
    logger.info("└─ Scheduler: " + describe(scheduler_cfg, self._scheduler))
    logger.info("=====================================================")

    # Keep the dataset configuration for the rest of the class
    self._dataset_setup = dataset_cfg

    def yes_no(flag: Any) -> str:
        return 'yes' if flag else 'no'

    common = dataset_cfg.common
    logger.info("========== Dataset Setup ==========")
    logger.info("[COMMON]")
    logger.info(f"├─ Seed: {common.seed}")
    logger.info(f"├─ Device: {common.device}")
    logger.info(f"├─ Use AMP: {yes_no(common.use_amp)}")
    logger.info(f"└─ Predictions output dir: {common.predictions_dir}")

    if dataset_cfg.is_training:
        training = dataset_cfg.training
        logger.info("[MODE] Training")
        logger.info(f"├─ Batch size: {training.batch_size}")
        logger.info(f"├─ Epochs: {training.num_epochs}")
        logger.info(f"├─ Validation frequency: {training.val_freq}")
        logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}")
        if training.is_split:
            logger.info("├─ Using pre-split directories:")
            logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}")
            logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}")
            logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}")
        else:
            logger.info("├─ Using unified dataset with splits:")
            logger.info(f"│ ├─ All data dir: {training.split.all_data_dir}")
            logger.info(f"│ └─ Shuffle: {yes_no(training.split.shuffle)}")
        logger.info("└─ Dataset split:")
        logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}")
        logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}")
        logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}")
    else:
        testing = dataset_cfg.testing
        logger.info("[MODE] Inference")
        logger.info(f"├─ Test dir: {testing.test_dir}")
        logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})")
        logger.info(f"├─ Shuffle: {yes_no(testing.shuffle)}")
        logger.info(f"├─ Use ensemble: {yes_no(testing.use_ensemble)}")
        logger.info("└─ Pretrained weights:")
        logger.info(f" ├─ Single model: {testing.pretrained_weights}")
        logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
        logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")
    logger.info("===================================")
def __set_seed(self, seed: Optional[int]) -> None:
    """
    Seed Python's ``random``, NumPy and PyTorch for reproducible runs.

    Also forces deterministic cuDNN kernels and disables cuDNN benchmark
    autotuning. With ``None`` nothing is seeded and only a notice is logged.

    Args:
        seed (Optional[int]): Seed value, or None to skip seeding.
    """
    if seed is None:
        logger.info("Seed not set (None provided)")
        return

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    logger.info(f"Random seed set to {seed}")
def __get_dataset(
    self,
    images_dir: str,
    masks_dir: Optional[str],
    transforms: Compose,
    size: Union[int, float],
    offset: int,
    shuffle: bool
) -> Dataset:
    """
    Build a dataset from image (and optional mask) paths.

    A directory argument is expanded to the files it contains. Any other
    string is treated as a glob pattern. Paths are sorted, optionally
    shuffled, and the slice ``[offset, offset + size)`` is kept.

    Shuffling uses a dedicated RNG with a per-instance seed (the configured
    ``common.seed``, or one random draw if that is None). Every call therefore
    applies the same permutation to the same file list, so the train/valid/test
    slices taken from one directory stay disjoint.

    Args:
        images_dir (str): Directory or glob pattern for input images.
        masks_dir (Optional[str]): Directory or glob pattern for masks, or None.
        transforms (Compose): Transformations to apply to each sample.
        size (Union[int, float]): Absolute count or fraction of the dataset.
        offset (int): Number of images to skip from the start.
        shuffle (bool): Whether to shuffle the paths before slicing.

    Returns:
        Dataset: Samples as {"image": path} or {"image": path, "mask": path}.

    Raises:
        FileNotFoundError: If no images are found.
        ValueError: If the mask count does not match the image count.
        ValueError: If the dataset is too small for the requested size or offset.
    """
    def collect(path: str) -> List[str]:
        # glob.glob() on a plain directory path only returns the directory
        # itself, so directories are expanded to the files inside them.
        if os.path.isdir(path):
            return sorted(
                entry for entry in glob.glob(os.path.join(path, "*"))
                if os.path.isfile(entry)
            )
        return sorted(glob.glob(path))

    images = collect(images_dir)
    if not images:
        raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'")

    masks: Optional[List[str]] = None
    if masks_dir is not None:
        masks = collect(masks_dir)
        if len(images) != len(masks):
            raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})")

    # Convert float size (fraction) to absolute count
    size = size if isinstance(size, int) else int(size * len(images))
    if size <= 0:
        raise ValueError(f"Size must be positive, got: {size}")
    if len(images) < size:
        raise ValueError(f"Not enough images ({len(images)}) for requested size ({size})")
    if len(images) < size + offset:
        raise ValueError(f"Offset ({offset}) + size ({size}) exceeds dataset length ({len(images)})")

    if shuffle:
        # Fix the split seed once per instance so every split sees the same
        # permutation. The shared global RNG advanced between calls, which
        # made train/valid/test permutations differ and the slices overlap.
        split_seed = getattr(self, "_split_seed", None)
        if split_seed is None:
            configured_seed = self._dataset_setup.common.seed
            split_seed = configured_seed if configured_seed is not None else random.randrange(2 ** 32)
            self._split_seed = split_seed
        order = list(range(len(images)))
        random.Random(split_seed).shuffle(order)
        images = [images[i] for i in order]
        if masks is not None:
            masks = [masks[i] for i in order]

    # Apply offset and limit by size
    images = images[offset: offset + size]
    if masks is not None:
        masks = masks[offset: offset + size]
        data = [{"image": image, "mask": mask} for image, mask in zip(images, masks)]
    else:
        data = [{"image": image} for image in images]
    return Dataset(data, transforms)
def __print_with_logging(self, results: Dict[str, float], step: int) -> None:
    """
    Print metrics as a table and log them to W&B when a run is active.

    Args:
        results (Dict[str, float]): Metric name -> value.
        step (int): W&B step (the epoch index during training). W&B ignores
            data logged at a step lower than one it has already seen.
    """
    table = tabulate(
        tabular_data=results.items(),
        headers=["Metric", "Value"],
        floatfmt=".4f",
        tablefmt="fancy_grid"
    )
    print(table, "\n")
    # wandb.log() raises if wandb.init() was never called. Skip remote logging
    # rather than crash after the metrics have already been computed.
    if wandb.run is not None:
        wandb.log(results, step=step)
def __run_epoch(
    self,
    mode: Literal["train", "valid", "test"],
    epoch: Optional[int] = None
) -> Dict[str, float]:
    """
    Execute one pass over the train, validation or test dataloader.

    In "train" mode the model is in train mode and parameters are updated,
    through the GradScaler when AMP is enabled and a scaler exists. In
    "valid"/"test" mode the model is in eval mode and gradients are disabled.
    Predictions are post-processed into instance masks and TP/FP/FN counts are
    accumulated at IoU 0.5. In "test" mode the predicted masks are also saved.
    Whenever a criterion is configured, its loss is computed for the batch.

    Args:
        mode (str): One of 'train', 'valid', or 'test'.
        epoch (int, optional): Current epoch number, used only in the progress
            bar description.

    Returns:
        Dict[str, float]: The criterion's loss metrics prefixed with
        ``"{mode}_"``, plus ``"{mode}_f1_score"`` (micro) and ``"{mode}_mAP"``
        (macro) for valid/test.

    Raises:
        RuntimeError: If optimizer or criterion is missing for train/valid, or
            the dataloader for ``mode`` is not set.
    """
    # Ensure required components are available
    # NOTE(review): 'valid' never uses the optimizer, yet this check also
    # requires one for validation passes.
    if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
        raise RuntimeError("Optimizer and loss function must be initialized for train/valid.")
    # Set model mode and choose the appropriate data loader
    if mode == "train":
        self._model.train()
        loader = self._train_dataloader
    else:
        self._model.eval()
        loader = self._valid_dataloader if mode == "valid" else self._test_dataloader
    # Check that the data loader is provided
    if loader is None:
        raise RuntimeError(f"DataLoader for mode '{mode}' is not set.")
    # Per-batch TP/FP/FN arrays, stacked after the loop (valid/test only)
    all_tp, all_fp, all_fn = [], [], []
    # Prepare tqdm description with epoch info if available
    if epoch is not None:
        desc = f"Epoch {epoch}/{self._dataset_setup.training.num_epochs} [{mode}]"
    else:
        desc = f"Epoch ({mode})"
    # Iterate over batches; batch_counter is the running sample index passed
    # to __save_prediction_masks in test mode
    batch_counter = 0
    for batch in tqdm(loader, desc=desc):
        inputs = batch["img"].to(self._device)
        targets = batch["label"].to(self._device)
        # Zero gradients for training (also runs in valid/test whenever an
        # optimizer exists; harmless there because no gradients are computed)
        if self._optimizer is not None:
            self._optimizer.zero_grad()
        # Mixed precision context if enabled
        with torch.amp.autocast(  # type: ignore
            self._device.type,
            enabled=self._dataset_setup.common.use_amp
        ):
            # Only compute gradients in training phase
            with torch.set_grad_enabled(mode == "train"):
                # Forward pass (sliding-window inference outside training)
                raw_output = self.__run_inference(inputs, mode)
                if self._criterion is not None:
                    # Convert instance label masks into the flow targets the
                    # criterion is trained against
                    flow_targets = self.__compute_flows_from_masks(targets)
                    # Compute loss for this batch
                    batch_loss = self._criterion(raw_output, flow_targets)  # type: ignore
                # Post-process and collect F1 statistics during validation and testing
                if mode in ("valid", "test"):
                    preds, labels_post = self.__post_process_predictions(
                        raw_output, targets
                    )
                    # Per-sample, per-class TP/FP/FN at IoU 0.5
                    tp, fp, fn = self.__compute_stats(
                        predicted_masks=preds,
                        ground_truth_masks=labels_post,  # type: ignore
                        iou_threshold=0.5
                    )
                    all_tp.append(tp)
                    all_fp.append(fp)
                    all_fn.append(fn)
                    # Test predictions are also written out
                    if mode == "test":
                        self.__save_prediction_masks(
                            batch, preds, batch_counter
                        )
        # Backpropagation and optimizer step in training (outside autocast)
        if mode == "train":
            if self._dataset_setup.common.use_amp and self._scaler is not None:
                # Scaled backward pass, then step through the GradScaler.
                # NOTE(review): explicit unscale_ only matters before gradient
                # clipping; scaler.step() unscales on its own otherwise.
                self._scaler.scale(batch_loss).backward()  # type: ignore
                self._scaler.unscale_(self._optimizer.optim)  # type: ignore
                self._scaler.step(self._optimizer.optim)  # type: ignore
                self._scaler.update()
            else:
                batch_loss.backward()  # type: ignore
                if self._optimizer is not None:
                    self._optimizer.step()
        batch_counter += inputs.shape[0]
    if self._criterion is not None:
        # Collect loss metrics reported by the criterion (presumably
        # accumulated over the batches of this pass)
        epoch_metrics = {f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()}
        # Reset internal loss metrics accumulator
        self._criterion.reset_metrics()
    else:
        epoch_metrics = {}
    # Include F1 and mAP for validation and testing
    # NOTE(review): np.vstack raises on an empty list, i.e. when the loader
    # yields no batches.
    if mode in ("valid", "test"):
        # Concatenating by batch: shape (num_batches*B, C)
        tp_array = np.vstack(all_tp)
        fp_array = np.vstack(all_fp)
        fn_array = np.vstack(all_fn)
        epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(  # type: ignore
            tp_array, fp_array, fn_array, reduction="micro"
        )
        epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(  # type: ignore
            tp_array, fp_array, fn_array, reduction="macro"
        )
    return epoch_metrics
def __run_inference(
    self,
    inputs: torch.Tensor,
    mode: Literal["train", "valid", "test", "predict"] = "train",
    *,
    roi_size: int = 512,
    sw_batch_size: int = 4,
    overlap: float = 0.5
) -> torch.Tensor:
    """
    Run the model forward pass for the given stage.

    In "train" mode the batch goes directly through the model. In every other
    mode MONAI's sliding-window inference is used, with Gaussian blending and
    constant padding.

    Args:
        inputs (torch.Tensor): Input tensor of shape (B, C, H, W).
        mode (Literal[...]): One of "train", "valid", "test", "predict".
        roi_size (int): Sliding-window size per spatial dimension (non-train modes).
        sw_batch_size (int): Number of windows per model call (non-train modes).
        overlap (float): Fractional overlap between neighbouring windows.

    Returns:
        torch.Tensor: Model outputs tensor.
    """
    if mode == "train":
        # Direct forward pass during training
        return self._model(inputs)  # type: ignore

    return sliding_window_inference(  # type: ignore
        inputs,
        roi_size=roi_size,
        sw_batch_size=sw_batch_size,
        predictor=self._model,
        padding_mode="constant",
        mode="gaussian",
        overlap=overlap,
    )
def __post_process_predictions(
    self,
    raw_outputs: torch.Tensor,
    ground_truth: Optional[torch.Tensor] = None
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Post-process raw network outputs into per-class instance masks.

    The output channels are split into gradient flows (the first
    ``2 * num_classes`` channels) and class logits (the last ``num_classes``
    channels). The logits go through ``__sigmoid``, and each sample is then
    segmented by ``__segment_instances``.

    Args:
        raw_outputs (torch.Tensor): Raw model outputs of shape (B, C_out, H, W).
        ground_truth (Optional[torch.Tensor]): Ground truth masks of shape
            (B, C, H, W); only converted to NumPy.

    Returns:
        Tuple[np.ndarray, Optional[np.ndarray]]:
            - instance_masks: uint16 instance masks of shape (B, num_classes, H, W).
            - labels_np: Ground truth as a NumPy array, or None if
              ground_truth was not provided.
    """
    # detach() keeps the NumPy conversion valid even if the tensor tracks gradients
    outputs_np = raw_outputs.detach().cpu().numpy()
    num_classes = self._model.num_classes

    # Split channels: gradient flows first, class logits last
    gradflow = outputs_np[:, :2 * num_classes]
    logits = outputs_np[:, -num_classes:]
    probabilities = self.__sigmoid(logits)

    batch_size, _, height, width = probabilities.shape
    instance_masks = np.zeros((batch_size, num_classes, height, width), dtype=np.uint16)
    for idx in range(batch_size):
        # NOTE(review): prob_threshold=0.0 is applied to sigmoid outputs,
        # which are always > 0; confirm __segment_instances expects
        # probabilities (not logits) with this threshold.
        instance_masks[idx] = self.__segment_instances(
            probability_map=probabilities[idx],
            flow=gradflow[idx],
            prob_threshold=0.0,
            flow_threshold=0.4,
            min_object_size=15
        )

    labels_np = ground_truth.detach().cpu().numpy() if ground_truth is not None else None
    return instance_masks, labels_np
def __compute_stats(
    self,
    predicted_masks: np.ndarray,
    ground_truth_masks: np.ndarray,
    iou_threshold: float = 0.5
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Count true positives, false positives and false negatives for a batch.

    Thin wrapper around ``compute_batch_segmentation_tp_fp_fn`` that matches
    predicted instances to ground-truth instances at ``iou_threshold``, with
    ``remove_boundary_objects=True`` (presumably excluding objects that touch
    the image border — confirm in core.utils).

    Args:
        predicted_masks (np.ndarray): Predicted instance masks, shape (B, C, H, W).
        ground_truth_masks (np.ndarray): Ground-truth instance masks, shape (B, C, H, W).
        iou_threshold (float): IoU needed for a prediction to count as a match.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: ``(tp, fp, fn)``, each of
        shape (B, C).
    """
    counts = compute_batch_segmentation_tp_fp_fn(
        batch_ground_truth=ground_truth_masks,
        batch_prediction=predicted_masks,
        iou_threshold=iou_threshold,
        remove_boundary_objects=True,
    )
    return counts["tp"], counts["fp"], counts["fn"]
def __compute_f1_metric(
    self,
    true_positives: np.ndarray,
    false_positives: np.ndarray,
    false_negatives: np.ndarray,
    reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
) -> Union[float, np.ndarray]:
    """
    Turn per-sample, per-class TP/FP/FN counts into an F1-score.

    Args:
        true_positives: TP counts, shape (batch_size, num_classes).
        false_positives: FP counts, same shape as ``true_positives``.
        false_negatives: FN counts, same shape as ``true_positives``.
        reduction: Aggregation scheme:
            - 'none': one F1 per (sample, class) -> array (batch_size, num_classes)
            - 'micro': pool every count, then score once
            - 'imagewise': pool counts over classes per sample, score each
              sample, then average the sample scores
            - 'macro': pool counts over the batch per class, score each class,
              then average the class scores
            - 'weighted': like 'macro' but weighted by class support (TP + FN);
              falls back to the unweighted class mean when total support is 0

    Returns:
        A scalar for 'micro', 'imagewise', 'macro' and 'weighted'; an
        np.ndarray of shape (batch_size, num_classes) for 'none'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """
    batch_size, num_classes = true_positives.shape

    def f1_only(tp, fp, fn):
        # compute_f1_score returns (precision, recall, f1); only F1 is kept.
        _, _, f1 = compute_f1_score(tp, fp, fn)
        return f1

    if reduction == "none":
        scores = np.zeros((batch_size, num_classes), dtype=float)
        for sample_idx, class_idx in np.ndindex(batch_size, num_classes):
            scores[sample_idx, class_idx] = f1_only(
                int(true_positives[sample_idx, class_idx]),
                int(false_positives[sample_idx, class_idx]),
                int(false_negatives[sample_idx, class_idx]),
            )
        return scores

    if reduction == "micro":
        return f1_only(
            int(true_positives.sum()),
            int(false_positives.sum()),
            int(false_negatives.sum()),
        )

    if reduction == "imagewise":
        image_scores = np.fromiter(
            (
                f1_only(
                    int(true_positives[sample_idx].sum()),
                    int(false_positives[sample_idx].sum()),
                    int(false_negatives[sample_idx].sum()),
                )
                for sample_idx in range(batch_size)
            ),
            dtype=float,
            count=batch_size,
        )
        return float(image_scores.mean())

    # The remaining modes work on counts pooled over the batch, one per class.
    class_tp = true_positives.sum(axis=0).astype(int)
    class_fp = false_positives.sum(axis=0).astype(int)
    class_fn = false_negatives.sum(axis=0).astype(int)

    if reduction in ("macro", "weighted"):
        class_scores = np.fromiter(
            (f1_only(class_tp[k], class_fp[k], class_fn[k]) for k in range(num_classes)),
            dtype=float,
            count=num_classes,
        )
        if reduction == "macro":
            return float(class_scores.mean())
        support = (class_tp + class_fn).astype(float)
        total_support = support.sum()
        if total_support == 0:
            # No positive instances at all: use the unweighted class mean.
            return float(class_scores.mean())
        return float((class_scores * (support / total_support)).sum())

    raise ValueError(f"Unknown reduction mode: {reduction}")
def __compute_average_precision_metric(
    self,
    true_positives: np.ndarray,
    false_positives: np.ndarray,
    false_negatives: np.ndarray,
    reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
) -> Union[float, np.ndarray]:
    """
    Turn per-sample, per-class TP/FP/FN counts into an Average Precision score.

    The per-count score comes from ``compute_average_precision_score``,
    documented here as AP = TP / (TP + FP + FN).

    Args:
        true_positives: TP counts, shape (batch_size, num_classes).
        false_positives: FP counts, same shape as ``true_positives``.
        false_negatives: FN counts, same shape as ``true_positives``.
        reduction: Aggregation scheme:
            - 'none': one AP per (sample, class) -> array (batch_size, num_classes)
            - 'micro': pool every count, then score once
            - 'imagewise': pool counts over classes per sample, score each
              sample, then average the sample scores
            - 'macro': pool counts over the batch per class, score each class,
              then average the class scores
            - 'weighted': like 'macro' but weighted by class support (TP + FN);
              falls back to the unweighted class mean when total support is 0

    Returns:
        A scalar for 'micro', 'imagewise', 'macro' and 'weighted'; an
        np.ndarray of shape (batch_size, num_classes) for 'none'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """
    batch_size, num_classes = true_positives.shape
    ap_score = compute_average_precision_score

    if reduction == "none":
        scores = np.zeros((batch_size, num_classes), dtype=float)
        for sample_idx, class_idx in np.ndindex(batch_size, num_classes):
            scores[sample_idx, class_idx] = ap_score(
                int(true_positives[sample_idx, class_idx]),
                int(false_positives[sample_idx, class_idx]),
                int(false_negatives[sample_idx, class_idx]),
            )
        return scores

    if reduction == "micro":
        return ap_score(
            int(true_positives.sum()),
            int(false_positives.sum()),
            int(false_negatives.sum()),
        )

    if reduction == "imagewise":
        image_scores = np.fromiter(
            (
                ap_score(
                    int(true_positives[sample_idx].sum()),
                    int(false_positives[sample_idx].sum()),
                    int(false_negatives[sample_idx].sum()),
                )
                for sample_idx in range(batch_size)
            ),
            dtype=float,
            count=batch_size,
        )
        return float(image_scores.mean())

    # The remaining modes work on counts pooled over the batch, one per class.
    class_tp = true_positives.sum(axis=0).astype(int)
    class_fp = false_positives.sum(axis=0).astype(int)
    class_fn = false_negatives.sum(axis=0).astype(int)

    if reduction in ("macro", "weighted"):
        class_scores = np.fromiter(
            (ap_score(class_tp[k], class_fp[k], class_fn[k]) for k in range(num_classes)),
            dtype=float,
            count=num_classes,
        )
        if reduction == "macro":
            return float(class_scores.mean())
        support = (class_tp + class_fn).astype(float)
        total_support = support.sum()
        if total_support == 0:
            # No positive instances at all: use the unweighted class mean.
            return float(class_scores.mean())
        return float((class_scores * (support / total_support)).sum())

    raise ValueError(f"Unknown reduction mode: {reduction}")
@staticmethod
def __sigmoid(z: np.ndarray) -> np.ndarray:
    """
    Numerically stable logistic sigmoid for numpy arrays.

    Evaluates ``sigmoid(z) = exp(-log(1 + exp(-z)))`` through
    ``np.logaddexp(0, -z)``, which never overflows. Very negative inputs
    underflow cleanly to 0. In contrast, ``1 / (1 + np.exp(-z))`` overflows
    and emits a RuntimeWarning. Floating dtypes such as float32 are preserved.

    Args:
        z (np.ndarray): Input array (or scalar).

    Returns:
        np.ndarray: Element-wise sigmoid of ``z``, with values in [0, 1].
    """
    return np.exp(-np.logaddexp(0, -z))
def __save_prediction_masks(
    self,
    sample: Dict[str, Any],
    predicted_mask: Union[np.ndarray, torch.Tensor],
    start_index: int = 0,
) -> None:
    """
    Save multi-channel predicted masks as TIFFs and matching PNG visualizations.

    Masks are written to ``predictions_dir`` as ``<base>_chNN.tif`` (uint16,
    zlib-compressed). Plots go to ``predictions_dir/plots`` as
    ``<base>_chNN.png``. ``<base>`` is the stem of the source filename from the
    metadata, or ``prediction_<start_index:04d>`` when no filename is available.

    Batched input (B, C, H, W) is split into one call per item. Per-item
    metadata is taken from one of two forms:
    - a list of per-item dicts;
    - a collated dict whose ``filename_or_obj`` is a list, in which case the
      idx-th filename is selected.

    Args:
        sample (Dict[str, Any]): Sample containing 'image', an optional 'mask'
            and an optional 'image_meta_dict'.
        predicted_mask (np.ndarray or torch.Tensor): Array of shape (H, W),
            (C, H, W) or (B, C, H, W).
        start_index (int): Index used for fallback names; offset per batch item.
    """
    base_output_dir = self._dataset_setup.common.predictions_dir
    masks_dir = base_output_dir
    plots_dir = os.path.join(base_output_dir, "plots")
    os.makedirs(masks_dir, exist_ok=True)
    os.makedirs(plots_dir, exist_ok=True)

    def to_numpy(x: Any) -> Any:
        # Tensors are detached and copied to host; None and ndarrays pass through.
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    def meta_for_item(meta: Any, idx: int) -> Any:
        # A list of per-item metadata dicts: pick the idx-th entry.
        if isinstance(meta, list):
            return meta[idx]
        # torch's default collate turns a list of dicts into a dict of lists,
        # so a batched 'filename_or_obj' is a list of per-item names.
        if isinstance(meta, dict):
            filenames = meta.get("filename_or_obj")
            if isinstance(filenames, (list, tuple)):
                item_meta = dict(meta)
                item_meta["filename_or_obj"] = filenames[idx]
                return item_meta
        return meta

    image_array = to_numpy(sample.get("image"))
    mask_array = to_numpy(sample.get("mask"))
    image_meta = sample.get("image_meta_dict")
    pred_array = to_numpy(predicted_mask)

    # A bare (H, W) prediction is treated as a single channel.
    if pred_array.ndim == 2:
        pred_array = pred_array[np.newaxis]

    # Batched prediction: split into per-item samples and recurse.
    if pred_array.ndim == 4:
        for idx in range(pred_array.shape[0]):
            item_sample = dict(sample)
            if image_array is not None and image_array.ndim == 4:
                item_sample["image"] = image_array[idx]
            if mask_array is not None and mask_array.ndim == 4:
                item_sample["mask"] = mask_array[idx]
            item_sample["image_meta_dict"] = meta_for_item(image_meta, idx)
            self.__save_prediction_masks(
                item_sample,
                pred_array[idx],
                start_index=start_index + idx
            )
        return

    # Drop a leftover singleton batch axis (e.g. batch of one with a (C, H, W) prediction).
    if image_array is not None and image_array.ndim == 4 and image_array.shape[0] == 1:
        image_array = image_array[0]
    if mask_array is not None and mask_array.ndim == 4 and mask_array.shape[0] == 1:
        mask_array = mask_array[0]

    # Determine the base filename, falling back to an index-based name.
    filename = image_meta.get("filename_or_obj") if isinstance(image_meta, dict) else None
    if isinstance(filename, (list, tuple)):
        # Still collated: only unambiguous for a single-item batch.
        filename = filename[0] if len(filename) == 1 else None
    if filename:
        base_name = os.path.splitext(os.path.basename(str(filename)))[0]
    else:
        base_name = f"prediction_{start_index:04d}"

    if image_array is None:
        logger.warning(f"No 'image' in sample; skipping plots for '{base_name}'.")

    # pred_array is now (C, H, W): one TIFF (and plot) per channel.
    for channel_idx, channel_mask in enumerate(pred_array):
        stem = f"{base_name}_ch{channel_idx:02d}"
        mask_path = os.path.join(masks_dir, f"{stem}.tif")
        plot_path = os.path.join(plots_dir, f"{stem}.png")

        # Save mask TIFF (16-bit)
        tiff.imwrite(mask_path, channel_mask.astype(np.uint16), compression="zlib")

        if image_array is None:
            continue

        # Matching ground-truth channel, only if the mask actually has it.
        true_mask = None
        if (
            mask_array is not None
            and mask_array.ndim == 3
            and channel_idx < mask_array.shape[0]
        ):
            true_mask = mask_array[channel_idx]

        self.__plot_mask(
            file_path=plot_path,
            image_data=image_array,
            predicted_mask=channel_mask,
            true_mask=true_mask,
        )
def __plot_mask(
    self,
    file_path: str,
    image_data: np.ndarray,
    predicted_mask: np.ndarray,
    true_mask: Optional[np.ndarray] = None,
) -> None:
    """
    Create and save a grid visualization of a prediction (and optional ground truth).

    The grid is 1x3 (image / mask / contours) when ``true_mask`` is None.
    Otherwise it is 2x3, with the ground truth drawn in the second row.

    Args:
        file_path: Destination path of the saved figure.
        image_data: Image as (C, H, W) or (H, W). Channel-first input is moved
            to channel-last for matplotlib. Single-channel input is reduced to
            (H, W) so it is drawn with a grayscale colormap.
        predicted_mask: Predicted label mask (H, W).
        true_mask: Optional ground-truth label mask (H, W).
    """
    img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]

    rows = [(predicted_mask, 'red', ('Original Image', 'Predicted Mask', 'Predicted Contours'))]
    if true_mask is not None:
        rows.append((true_mask, 'blue', ('Original Image', 'True Mask', 'True Contours')))

    n_rows = len(rows)
    fig, axs = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    # Always release the figure, even if plotting or saving fails; otherwise
    # pyplot keeps it alive and memory grows across many predictions.
    try:
        fig.subplots_adjust(wspace=0.02, hspace=0.1 if n_rows > 1 else 0)
        for row_axes, (mask, contour_color, titles) in zip(axs, rows):
            self.__plot_panels(row_axes, img, mask, contour_color, titles)
        fig.savefig(file_path, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)
def __plot_panels(
    self,
    axes,
    img: np.ndarray,
    mask: np.ndarray,
    contour_color: str,
    titles: Tuple[str, ...]
):
    """
    Plot a row of three panels: the image, the label mask, and mask boundaries over the image.

    Args:
        axes: Sequence of three matplotlib Axes.
        img: Image array, (H, W) or (H, W, C). 2D images use a gray colormap
            in both image panels.
        mask: Label mask (H, W).
        contour_color: Color of the boundary lines in the overlay panel.
        titles: (title_img, title_mask, title_contours).
    """
    ax_img, ax_mask, ax_overlay = axes
    image_cmap = 'gray' if img.ndim == 2 else None

    # Panel 1: original image
    ax_img.imshow(img, cmap=image_cmap)
    ax_img.set_title(titles[0])
    ax_img.axis('off')

    # Boundaries are computed once and reused by panels 2 and 3.
    boundaries = find_boundaries(mask, mode='thick')
    # contour() on an all-False array has no level to draw; skip it.
    has_boundaries = bool(boundaries.any())

    # Panel 2: mask, one colormap entry per distinct label value, black boundaries.
    # np.unique is evaluated once rather than inside the comprehension.
    num_labels = len(np.unique(mask))
    base_cmap = plt.get_cmap("gist_ncar")
    label_cmap = mcolors.ListedColormap(
        [base_cmap(i / num_labels) for i in range(num_labels)]
    )
    ax_mask.imshow(mask, cmap=label_cmap)
    if has_boundaries:
        ax_mask.contour(boundaries, colors='black', linewidths=0.5)
    ax_mask.set_title(titles[1])
    ax_mask.axis('off')

    # Panel 3: original image with boundaries drawn in `contour_color`.
    ax_overlay.imshow(img, cmap=image_cmap)
    if has_boundaries:
        ax_overlay.contour(boundaries, colors=contour_color, linewidths=0.5)
    ax_overlay.set_title(titles[2])
    ax_overlay.axis('off')
def __compute_flows_from_masks(
    self,
    true_masks: Tensor
) -> np.ndarray:
    """
    Convert segmentation masks to flow fields for training.

    Args:
        true_masks: Tensor of shape (batch, C, H, W) or (batch, H, W) with
            integer instance labels.

    Returns:
        float32 array of shape (batch, 3*C, H, W). For each image it
        concatenates:
        - the flow vectors (2*C channels; per label channel, Y flow then X flow);
        - the renumbered labels (C channels, ids renumbered consecutively from 1).
    """
    # int32 instead of int16: label ids >= 32768 (e.g. from uint16 masks)
    # would otherwise wrap to negative values and corrupt the instances.
    label_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int32)

    # Ensure each label has a channel dimension: (batch, H, W) -> (batch, 1, H, W)
    if label_masks.ndim == 3:
        label_masks = label_masks[:, np.newaxis, :, :]
    batch_size = label_masks.shape[0]

    # Renumber labels per image so ids are consecutive.
    renumbered: np.ndarray = np.stack([
        fastremap.renumber(label_masks[i], in_place=True)[0]
        for i in range(batch_size)
    ])

    # Compute vector flows per image
    flow_vectors = np.stack([
        self.__compute_flow_from_mask(renumbered[i])
        for i in range(batch_size)
    ])

    return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32)
def __compute_flow_from_mask(
    self,
    mask: np.ndarray
) -> np.ndarray:
    """
    Compute normalized flow vectors for one image's labeled mask.

    Masks with no labels, or with at most one labeled pixel, short-circuit to
    an all-zero flow (and log a warning). Otherwise the work goes to the GPU
    routine on CUDA/MPS devices and to the CPU routine everywhere else.

    Args:
        mask: Instance-labeled mask of shape (C, H, W).

    Returns:
        Flow array of shape (2 * C, H, W).
    """
    is_trivial = mask.max() == 0 or np.count_nonzero(mask) <= 1
    if is_trivial:
        logger.warning("Empty mask!")
        num_channels, height, width = mask.shape
        return np.zeros((2 * num_channels, height, width), dtype=np.float32)

    on_accelerator = self._device.type in ("cuda", "mps")
    flow_fn = self.__mask_to_flow_gpu if on_accelerator else self.__mask_to_flow_cpu
    return flow_fn(mask)
def __mask_to_flow_gpu(self, mask: np.ndarray) -> np.ndarray:
    """
    Convert labeled masks to flows using diffusion from each object's center, on the device.

    Each channel is processed independently. For every object, diffusion
    starts at the object pixel closest to its centroid. Flows are the
    normalized spatial gradients of the resulting density.

    Args:
        mask (np.ndarray): Integer labels of shape (C, H, W); 0 is background,
            positive values are instance ids.

    Returns:
        np.ndarray: float32 array of shape (2*C, H, W). Channels 2c and 2c+1
        hold the unit (dy, dx) flow for input channel c. Pixels outside any
        object (and channels without objects) stay zero.
    """
    channels, height, width = mask.shape
    flows = np.zeros((2 * channels, height, width), np.float32)
    padded_shape = (height + 2, width + 2)

    # 3x3 neighbourhood offsets (dy, dx). The order matters:
    # __propagate_centers_gpu treats index 0 as the pixel itself and
    # 1-4 as up/down/left/right.
    offsets = torch.tensor([
        [ 0,  0],  # center
        [-1,  0],  # up
        [ 1,  0],  # down
        [ 0, -1],  # left
        [ 0,  1],  # right
        [-1, -1],  # up-left
        [-1,  1],  # up-right
        [ 1, -1],  # down-left
        [ 1,  1],  # down-right
    ], dtype=torch.int32, device=self._device)  # (9, 2)

    for channel in range(channels):
        channel_mask = mask[channel]

        # find_objects entry i is the bounding box of label i + 1 (None if
        # that label is absent); store the real label id for the center helper.
        slices = find_objects(channel_mask)
        slices_arr = np.array([
            [label_idx + 1, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
            for label_idx, sl in enumerate(slices) if sl is not None
        ], dtype=int)
        if slices_arr.size == 0:
            # No objects in this channel: its flow stays zero.
            continue

        # Pad the channel with a 1-pixel border so every neighbour index is valid.
        masks_padded = torch.from_numpy(channel_mask.astype(np.int64)).to(self._device)
        masks_padded = F.pad(masks_padded, (1, 1, 1, 1))

        # Coordinates of all labelled pixels in the padded channel.
        y, x = torch.nonzero(masks_padded, as_tuple=True)
        y = y.int(); x = x.int()

        coords = torch.stack((y, x), dim=1)                        # (N, 2)
        neighbors = offsets[:, None, :] + coords[None, :, :]       # (9, N, 2)
        neighbors = neighbors.permute(2, 0, 1)                     # (2, 9, N): y/x first

        # A neighbour participates only if it carries the same label as the pixel.
        center_labels = masks_padded[y, x][None, :]                # (1, N)
        neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N)
        is_neighbor = neighbor_labels == center_labels             # (9, N)

        # Diffusion seeds and bounding-box extents per object.
        centers, ext = self.__get_mask_centers_and_extents(channel_mask, slices_arr)
        meds_p = torch.from_numpy(centers).to(self._device).long() + 1  # (M, 2); +1 for padding

        n_iter = 2 * int(ext.max())

        mu = self.__propagate_centers_gpu(
            neighbor_indices=neighbors,
            center_indices=meds_p.T,
            valid_neighbor_mask=is_neighbor,
            output_shape=padded_shape,
            num_iterations=n_iter
        )

        # Normalize to unit vectors; the epsilon keeps zero vectors at zero.
        mu = mu.astype(np.float64)
        mu /= np.sqrt((mu ** 2).sum(axis=0)) + 1e-60

        # Undo the padding offset and write this channel's (dy, dx).
        ys_np = y.cpu().numpy() - 1
        xs_np = x.cpu().numpy() - 1
        flows[2 * channel, ys_np, xs_np] = mu[0]
        flows[2 * channel + 1, ys_np, xs_np] = mu[1]

    return flows
@staticmethod
@njit(nogil=True)
def __get_mask_centers_and_extents(
    label_map: np.ndarray,
    slices_arr: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute a representative center pixel and a size-based extent for each labeled region.

    Compiled by numba (``@njit``, GIL released via ``nogil=True``).

    Args:
        label_map (np.ndarray): 2D integer label array.
        slices_arr (np.ndarray): 2D array of shape (K, 5). Each row is
            (label_id, row_start, row_stop, col_start, col_stop): the label
            value to select and its bounding box in ``label_map`` (stops
            exclusive).

    Returns:
        centers (np.ndarray): int32 array of shape (K, 2). For each row of
            ``slices_arr`` it holds the (row, col) of the region pixel closest
            to the region centroid, in global ``label_map`` coordinates.
        extents (np.ndarray): int32 array of shape (K,) holding
            bbox_height + bbox_width + 2 for each region.

    Note:
        ``label_id`` is compared directly against ``label_map`` values, so it
        must be the actual label value (not a 0-based index). Each bounding box
        must contain at least one pixel with that value; otherwise the argmin
        below runs on an empty array and fails.
    """
    num_regions = slices_arr.shape[0]
    centers = np.zeros((num_regions, 2), dtype=np.int32)
    extents = np.zeros(num_regions, dtype=np.int32)
    # NOTE: @njit here has no parallel=True, so prange executes like range.
    for idx in prange(num_regions):
        # Unpack slice info
        label_id = slices_arr[idx, 0]
        row_start = slices_arr[idx, 1]
        row_stop = slices_arr[idx, 2]
        col_start = slices_arr[idx, 3]
        col_stop = slices_arr[idx, 4]
        # Extract binary submask for this label (True where pixel == label_id)
        submask = (label_map[row_start:row_stop, col_start:col_stop] == label_id)
        # Get local (bbox-relative) coordinates of all pixels in the region
        ys, xs = np.nonzero(submask)
        # Compute the floating-point centroid within the submask
        y_mean = ys.mean()
        x_mean = xs.mean()
        # Pick the region pixel closest to the centroid (the centroid itself
        # may fall outside a non-convex region)
        dist_sq = (ys - y_mean) ** 2 + (xs - x_mean) ** 2
        closest_idx = dist_sq.argmin()
        # Convert to global coordinates
        center_row = ys[closest_idx] + row_start
        center_col = xs[closest_idx] + col_start
        centers[idx, 0] = center_row
        centers[idx, 1] = center_col
        # Compute extent as height + width + 2 (to include one-pixel border)
        height = row_stop - row_start
        width = col_stop - col_start
        extents[idx] = height + width + 2
    return centers, extents
def __propagate_centers_gpu(
    self,
    neighbor_indices: torch.Tensor,
    center_indices: torch.Tensor,
    valid_neighbor_mask: torch.Tensor,
    output_shape: Tuple[int, int],
    num_iterations: int = 200
) -> np.ndarray:
    """
    Propagate center points across a mask using device-side diffusion.

    Each iteration:
    1. adds 1 at every center;
    2. replaces each pixel by the mean of its 9-neighbourhood, with neighbours
       of a different label counted as 0.
    After the loop, central differences give the (dy, dx) flow.

    Args:
        neighbor_indices (torch.Tensor): (2, 9, N) row/column indices of the 9
            neighbours of each of the N pixels, ordered
            [center, up, down, left, right, up-left, up-right, down-left, down-right].
        center_indices (torch.Tensor): (2, M) row/column indices of the diffusion sources.
        valid_neighbor_mask (torch.Tensor): (9, N) boolean, True where the
            neighbour belongs to the same object.
        output_shape (Tuple[int, int]): 2D shape (H, W) of the diffusion grid.
        num_iterations (int, optional): Number of diffusion iterations. Defaults to 200.

    Returns:
        np.ndarray: Array of shape (2, N) with unnormalized (dy, dx) flows.
        The shape holds even when N == 1.
    """
    # float32 for very large grids (memory) and on MPS (no float64 support);
    # float64 otherwise for accuracy.
    total_elems = int(np.prod(output_shape))
    use_single_precision = total_elems > 4e7 or self._device.type == "mps"
    dtype = torch.float if use_single_precision else torch.double
    diffusion_tensor = torch.zeros(output_shape, dtype=dtype, device=self._device)

    center_rows, center_cols = center_indices
    neigh_rows, neigh_cols = neighbor_indices  # each of shape (9, N)

    for _ in range(num_iterations):
        # Add source at each mask center
        diffusion_tensor[center_rows, center_cols] += 1
        # Sample all 9 neighbours, zero those from other labels, and write the
        # mean back to the pixel itself (neighbour index 0).
        neighbor_vals = diffusion_tensor[neigh_rows, neigh_cols]  # (9, N)
        neighbor_vals *= valid_neighbor_mask
        diffusion_tensor[neigh_rows[0], neigh_cols[0]] = neighbor_vals.mean(dim=0)

    # Central differences: samples ordered [down, up, right, left].
    grad_samples = diffusion_tensor[
        neigh_rows[[2, 1, 4, 3]],
        neigh_cols[[2, 1, 4, 3]]
    ]  # (4, N)
    dy = grad_samples[0] - grad_samples[1]
    dx = grad_samples[2] - grad_samples[3]

    # Stack along a new leading axis; no squeeze, so N == 1 stays (2, 1).
    return torch.stack((dy, dx), dim=0).cpu().numpy()
def __mask_to_flow_cpu(self, mask: np.ndarray) -> np.ndarray:
    """
    Convert labeled masks to flow vectors by simulating diffusion from mask centers.

    Each channel is processed independently. For every object, the diffusion
    seed is the pixel closest to its centroid. Diffusion runs on the object's
    bounding box padded by one pixel. The flow is the central-difference
    gradient (dy, dx) of the resulting density, normalized to unit length.

    Args:
        mask (np.ndarray): Integer labels of shape (C, H, W); 0 is background
            and positive integers are instance ids.

    Returns:
        np.ndarray: float32 array of shape (2*C, H, W). Channels 2c and 2c+1
        hold the (dy, dx) flow for input channel c (zero outside objects).
    """
    channels, height, width = mask.shape
    flows = np.zeros((2 * channels, height, width), np.float32)

    for channel in range(channels):
        channel_mask = mask[channel]
        flow_field = np.zeros((2, height, width), dtype=np.float64)

        # find_objects entry i is the bounding box of label i + 1 (None if absent).
        for label_idx, slc in enumerate(find_objects(channel_mask)):
            if slc is None:
                continue
            row_slice, col_slice = slc

            # Patch = bounding box plus a 1-pixel border on every side.
            patch_height = (row_slice.stop - row_slice.start) + 2
            patch_width = (col_slice.stop - col_slice.start) + 2

            # Local coordinates of this object's pixels, shifted +1 for the border.
            local_rows, local_cols = np.nonzero(
                channel_mask[row_slice, col_slice] == (label_idx + 1)
            )
            local_rows = local_rows.astype(np.int32) + 1
            local_cols = local_cols.astype(np.int32) + 1

            # Seed = object pixel closest to the centroid.
            centroid_row = local_rows.mean()
            centroid_col = local_cols.mean()
            distances = (local_cols - centroid_col) ** 2 + (local_rows - centroid_row) ** 2
            seed_index = distances.argmin()
            # The njit kernel is compiled for int32 scalars; pass them explicitly.
            center_row = np.int32(local_rows[seed_index])
            center_col = np.int32(local_cols[seed_index])
            total_iter = np.int32(2 * (patch_height + patch_width))

            diffusion_map = np.zeros(patch_height * patch_width, dtype=np.float64)
            diffusion_map = self.__diffuse_from_center(
                diffusion_map,
                local_rows,
                local_cols,
                center_row,
                center_col,
                np.int32(patch_width),
                total_iter
            )

            # Central differences on the flattened patch.
            dy = (
                diffusion_map[(local_rows + 1) * patch_width + local_cols] -
                diffusion_map[(local_rows - 1) * patch_width + local_cols]
            )
            dx = (
                diffusion_map[local_rows * patch_width + (local_cols + 1)] -
                diffusion_map[local_rows * patch_width + (local_cols - 1)]
            )

            # Map patch coordinates back to image coordinates (undo the +1 border).
            global_rows = row_slice.start + local_rows - 1
            global_cols = col_slice.start + local_cols - 1
            flow_field[0, global_rows, global_cols] = dy
            flow_field[1, global_rows, global_cols] = dx

        # Normalize each [dy, dx] to unit length; the epsilon keeps zeros at zero.
        magnitudes = np.sqrt((flow_field ** 2).sum(axis=0)) + 1e-60
        flow_field /= magnitudes
        flows[2 * channel: 2 * channel + 2] = flow_field

    return flows
@staticmethod
@njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True)
def __diffuse_from_center(
    diffusion_map: np.ndarray,
    row_coords: np.ndarray,
    col_coords: np.ndarray,
    center_row: int,
    center_col: int,
    patch_width: int,
    num_iterations: int
) -> np.ndarray:
    """
    Diffuse "particles" from a seed pixel across a flattened local mask patch.

    Each iteration first adds 1.0 at the seed. Then every mask pixel is set to
    the mean of its 3x3 neighbourhood (itself plus 8 neighbours). The
    right-hand side is fully gathered before the assignment, so all pixels
    update from the same previous state. Only mask pixels (``base_idx``) are
    written. Non-mask cells are read but never modified, so they act as fixed
    boundaries; presumably the caller zero-initializes them.

    Compiled eagerly by numba for the exact signature in the decorator: a
    float64 1-D map, int32 1-D coordinate arrays and int32 scalars. Arguments
    of other types may fail dispatch. No bounds checking is done: every mask
    pixel must have its full 3x3 neighbourhood inside the patch (i.e. the
    patch needs a 1-pixel border).

    Args:
        diffusion_map (np.ndarray): Flat array of length patch_height * patch_width.
        row_coords (np.ndarray): 1D array of row indices for mask pixels (local coords).
        col_coords (np.ndarray): 1D array of column indices for mask pixels (local coords).
        center_row (int): Row index of the seed point in local patch coords.
        center_col (int): Column index of the seed point in local patch coords.
        patch_width (int): Width (number of columns) in the local patch.
        num_iterations (int): Number of diffusion iterations to perform.

    Returns:
        np.ndarray: ``diffusion_map`` itself, modified in place and also returned.
    """
    # Compute linear (row-major) indices for each mask pixel and its neighbors
    base_idx = row_coords * patch_width + col_coords
    up = (row_coords - 1) * patch_width + col_coords
    down = (row_coords + 1) * patch_width + col_coords
    left = row_coords * patch_width + (col_coords - 1)
    right = row_coords * patch_width + (col_coords + 1)
    up_left = (row_coords - 1) * patch_width + (col_coords - 1)
    up_right = (row_coords - 1) * patch_width + (col_coords + 1)
    down_left = (row_coords + 1) * patch_width + (col_coords - 1)
    down_right = (row_coords + 1) * patch_width + (col_coords + 1)
    for _ in range(num_iterations):
        # Inject one particle at the seed location
        diffusion_map[center_row * patch_width + center_col] += 1.0
        # Update each mask pixel as the average over itself and its 8 neighbors
        diffusion_map[base_idx] = (
            diffusion_map[base_idx] +
            diffusion_map[up] + diffusion_map[down] +
            diffusion_map[left] + diffusion_map[right] +
            diffusion_map[up_left] + diffusion_map[up_right] +
            diffusion_map[down_left] + diffusion_map[down_right]
        ) * (1.0 / 9.0)
    return diffusion_map
def __segment_instances(
    self,
    probability_map: np.ndarray,
    flow: np.ndarray,
    prob_threshold: float = 0.0,
    flow_threshold: float = 0.4,
    num_iters: int = 200,
    min_object_size: int = 15
) -> np.ndarray:
    """
    Generate instance segmentation masks from probability and flow fields.

    Steps for each channel:
    1. Threshold the probability map.
    2. Follow the (scaled) flows from every foreground pixel.
    3. Cluster the end points into instances.
    4. Optionally drop instances whose flows are inconsistent.
    5. Optionally fill holes and prune small instances.

    Args:
        probability_map: 3D array (channels, height, width) of cell probabilities.
        flow: 3D array (2*channels, height, width) of forward flow vectors.
        prob_threshold: Pixels with probability above this are foreground. (Default 0.0)
        flow_threshold: Error threshold for removing inconsistent flow masks;
            <= 0 disables the check. (Default 0.4)
        num_iters: Number of iterations for flow-following. (Default 200)
        min_object_size: Minimum instance size for hole filling / pruning;
            <= 0 disables that step (the instances are still kept). (Default 15)

    Returns:
        3D uint16 array of instance labels, same shape as ``probability_map``.
    """
    # Create a binary mask of likely cell locations
    probability_mask = probability_map > prob_threshold

    # If no cells exceed the threshold, return an empty mask
    if not np.any(probability_mask):
        logger.warning("No cell pixels found.")
        return np.zeros_like(probability_map, dtype=np.uint16)

    # Zero-initialized output: channels skipped below simply stay empty.
    labeled_instances = np.zeros_like(probability_map, dtype=np.uint16)
    spatial_shape = probability_map.shape[1:]

    for channel_index in range(probability_mask.shape[0]):
        channel_mask = probability_mask[channel_index]
        if not channel_mask.any():
            # Nothing to trace in this channel.
            continue

        # Two flow components (dy, dx) per channel.
        channel_flow_vectors = flow[2 * channel_index: 2 * channel_index + 2]
        nonzero_coords = np.stack(np.nonzero(channel_mask))

        # Follow flows only inside the foreground; /5 scales the step size.
        flow_coordinates = self.__follow_flows(
            flow_field=channel_flow_vectors * channel_mask / 5.0,
            initial_coords=nonzero_coords,
            num_iters=num_iters
        )
        if flow_coordinates is None:
            continue

        if torch.is_tensor(flow_coordinates):
            flow_coordinates = flow_coordinates.int()
        else:
            flow_coordinates = torch.from_numpy(
                flow_coordinates).to(self._device, dtype=torch.int32)

        # Obtain preliminary instance masks by clustering the end points
        channel_instances_mask = self.__get_mask(
            pixel_positions=flow_coordinates,
            valid_indices=nonzero_coords,
            original_shape=spatial_shape
        )

        # Filter out bad flow-derived instances if requested
        if channel_instances_mask.max() > 0 and flow_threshold > 0:
            channel_instances_mask = self.__remove_inconsistent_flow_masks(
                mask=channel_instances_mask,
                flow_network=channel_flow_vectors,
                error_threshold=flow_threshold
            )

        # Fill holes and remove objects below the minimum size, if enabled
        if min_object_size > 0:
            channel_instances_mask = self.__fill_holes_and_prune_small_masks(
                channel_instances_mask, minimum_size=min_object_size
            )

        labeled_instances[channel_index] = channel_instances_mask

    return labeled_instances
def __follow_flows(
    self,
    flow_field: np.ndarray,
    initial_coords: np.ndarray,
    num_iters: int = 200
) -> Union[np.ndarray, torch.Tensor]:
    """
    Trace pixel positions through a flow field via iterative interpolation.

    Each step bilinearly samples the flow at the current sub-pixel positions,
    moves every point by that amount, and clamps it to the image bounds.

    Args:
        flow_field (np.ndarray): Array of shape (2, H, W); channel 0 moves the
            y coordinate and channel 1 moves the x coordinate.
        initial_coords (np.ndarray): Array of shape (2, num_points) with
            starting (y, x) positions.
        num_iters (int): Number of integration steps.

    Returns:
        Final (y, x) positions with shape (2, num_points), in one of two forms:
        - a float32 torch.Tensor on CUDA/MPS devices;
        - a float32 np.ndarray otherwise, or when there are no points.
    """
    dims = 2
    height, width = flow_field.shape[1:]
    num_points = initial_coords.shape[1]

    if num_points == 0:
        # Nothing to trace; also avoids grid_sample on an empty sampling grid.
        return initial_coords.astype(np.float32)

    if self._device.type in ("cuda", "mps"):
        # grid_sample expects (x, y) in the last dim: store points and flow
        # channels in reversed order.
        pts = torch.zeros((1, 1, num_points, dims),
                          dtype=torch.float32, device=self._device)
        flow_vol = torch.zeros((1, dims, height, width),
                               dtype=torch.float32, device=self._device)
        for i in range(dims):
            pts[0, 0, :, dims - i - 1] = (
                torch.from_numpy(initial_coords[i]).to(self._device, torch.float32)
            )
            flow_vol[0, dims - i - 1] = (
                torch.from_numpy(flow_field[i]).to(self._device, torch.float32)
            )

        # Map pixel coordinates/displacements to grid_sample's [-1, 1] space
        # using (size - 1) as the scale.
        # NOTE(review): (size - 1) scaling is the align_corners=True convention,
        # but grid_sample below uses align_corners=False. Sampling is therefore
        # shifted by up to half a pixel (most at the borders). Kept as-is so
        # results stay unchanged — confirm before switching.
        max_indices = torch.tensor([width - 1, height - 1],
                                   dtype=torch.float32, device=self._device)
        max_idx_pt = max_indices.view(1, 1, 1, dims)
        max_idx_flow = max_indices.view(1, dims, 1, 1)
        flow_vol = (flow_vol * 2) / max_idx_flow
        pts = (pts / max_idx_pt) * 2 - 1

        for _ in range(num_iters):
            sampled = torch.nn.functional.grid_sample(
                flow_vol, pts, align_corners=False
            )
            # Step every point and clamp to the valid range.
            for i in range(dims):
                pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0)

        # Back to pixel coordinates.
        pts = (pts + 1) * 0.5 * max_idx_pt
        # (1, 1, N, 2) in (x, y) order -> (2, N) in (y, x) order.
        return pts[0, 0][:, [1, 0]].T

    # CPU path: __map_coordinates is numba-compiled only for float32 and int16
    # volumes, so cast the flow to float32 up front.
    flow_field = flow_field.astype(np.float32, copy=False)
    current_pos = initial_coords.astype(np.float32)  # astype returns a copy
    temp_delta = np.zeros_like(current_pos)
    for _ in range(num_iters):
        # Interpolate flow at current positions (written into temp_delta)
        self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta)
        # Update positions and clamp to image bounds
        current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1)
        current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1)
    return current_pos
@staticmethod
@njit([
    "(int16[:,:,:], float32[:], float32[:], float32[:,:])",
    "(float32[:,:,:], float32[:], float32[:], float32[:,:])"
], cache=True)
def __map_coordinates(
    image_data: np.ndarray,
    y_coords: np.ndarray,
    x_coords: np.ndarray,
    output: np.ndarray
) -> None:
    """
    Bilinearly sample a multi-channel volume at arbitrary (y, x) points, writing into ``output``.

    Compiled eagerly by numba for exactly two signatures: an int16 or float32
    3-D volume, float32 1-D coordinate arrays, and a float32 2-D output.
    Other dtypes fail dispatch. The result is cached on disk (``cache=True``).

    Integer base indices come from ``astype(np.int32)``, which truncates toward
    zero. This equals floor only for non-negative coordinates, so callers are
    expected to pass coordinates >= 0. Base indices and their +1 neighbours
    are clamped to the volume. The interpolation weights use the unclamped
    fractional parts.

    Args:
        image_data (np.ndarray): Input volume with shape (C, H, W).
        y_coords (np.ndarray): Array of y positions (num_points).
        x_coords (np.ndarray): Array of x positions (num_points).
        output (np.ndarray): Output array of shape (C, num_points) to fill.

    Returns:
        None. Results written directly into `output`.
    """
    channels, height, width = image_data.shape
    # Compute integer (truncated) and fractional parts for coords
    y_floor = y_coords.astype(np.int32)
    x_floor = x_coords.astype(np.int32)
    y_frac = y_coords - y_floor
    x_frac = x_coords - x_floor
    # Loop over each sample point
    for idx in range(y_floor.shape[0]):
        # Clamp base indices (and their +1 neighbours) to the valid range
        y0 = min(max(y_floor[idx], 0), height - 1)
        x0 = min(max(x_floor[idx], 0), width - 1)
        y1 = min(y0 + 1, height - 1)
        x1 = min(x0 + 1, width - 1)
        wy = y_frac[idx]
        wx = x_frac[idx]
        # Interpolate per channel (values cast to float32 so int16 input works)
        for c in range(channels):
            v00 = np.float32(image_data[c, y0, x0])
            v10 = np.float32(image_data[c, y0, x1])
            v01 = np.float32(image_data[c, y1, x0])
            v11 = np.float32(image_data[c, y1, x1])
            # Bilinear interpolation formula
            output[c, idx] = (
                v00 * (1 - wy) * (1 - wx) +
                v10 * (1 - wy) * wx +
                v01 * wy * (1 - wx) +
                v11 * wy * wx
            )
def __get_mask(
    self,
    pixel_positions: torch.Tensor,
    valid_indices: np.ndarray,
    original_shape: Tuple[int, ...],
    pad_radius: int = 20,
    max_size_fraction: float = 0.4
) -> np.ndarray:
    """
    Generate labeled masks by clustering pixel trajectories via histogram peaks and region growing.

    Steps:
        1. Clamp the final pixel positions into the image and shift them into a
           zero-padded frame.
        2. Build a dense 2D histogram counting how many pixels ended in each bin.
        3. Take bins that equal their 5x5 local maximum and hold more than 10
           hits as seeds. Seeds are sorted by ascending count, so stronger
           peaks are painted last and win overlaps.
        4. Grow each seed inside an 11x11 histogram patch for 5 steps of 3x3
           dilation, keeping only bins with more than 2 hits.
        5. Paint the grown regions into a padded label volume and read back the
           label at every pixel's final position.
        6. Scatter the labels onto the original image grid, remove labels
           covering more than `max_size_fraction` of the image, and renumber
           the labels.

    Args:
        pixel_positions (torch.Tensor): Tensor of shape `[2, N_pixels]` with the
            final position of every tracked pixel, one row per axis of
            `original_shape`. Values are cast to int64 (floats are truncated),
            then clamped into the image.
        valid_indices (np.ndarray): Indices of the tracked pixels in the original
            image. Either an integer array of shape `[2, N_pixels]` or a tuple of
            two 1D index arrays (as returned by `np.nonzero`).
        original_shape (tuple of ints): Spatial shape (H, W) of the original image.
        pad_radius (int): Zero padding added on each side of the histogram.
            Values below 5 (half the 11x11 growth patch) are raised to 5
            internally so patches around seeds stay inside the histogram.
            Defaults to 20.
        max_size_fraction (float): Labels with more than
            `max_size_fraction * total_pixels` pixels are removed. Defaults to 0.4.

    Returns:
        np.ndarray: Mask of shape `original_shape` with 0 for background and
        labels 1..M. The dtype is uint16, or uint32 when there are 2**16 or
        more seeds.

    Raises:
        ValueError: If `pixel_positions` is not `[ndim, N]`, if `original_shape`
            is not 2D, or if `pad_radius` is negative.
    """
    # Validate inputs
    ndim = len(original_shape)
    if pixel_positions.ndim != 2 or pixel_positions.size(0) != ndim:
        msg = f"pixel_positions must be shape [{ndim}, N], got {tuple(pixel_positions.shape)}"
        logger.error(msg)
        raise ValueError(msg)
    if ndim != 2:
        # Pooling, patch extraction and seed placement below are 2D-only.
        msg = f"Only 2D masks are supported, got original_shape={tuple(original_shape)}"
        logger.error(msg)
        raise ValueError(msg)
    if pad_radius < 0:
        msg = f"pad_radius must be non-negative, got {pad_radius}"
        logger.error(msg)
        raise ValueError(msg)

    device = pixel_positions.device
    half = 5  # half-width of the 11x11 region-growing patch
    # A patch centred on a seed must fit inside the histogram.
    pad = max(int(pad_radius), half)

    # Step 1: clamp into the image, then shift into the padded frame.
    # Out-of-place ops only, so the caller's tensor is never modified.
    positions = pixel_positions.to(torch.int64)
    padded_positions = torch.stack([
        torch.clamp(positions[axis], min=0, max=int(original_shape[axis]) - 1)
        for axis in range(ndim)
    ]) + pad
    hist_shape = tuple(int(s) + 2 * pad for s in original_shape)

    # Step 2: sparse COO tensor; to_dense() sums duplicate coordinates into counts.
    try:
        counts_sparse = torch.sparse_coo_tensor(
            padded_positions,
            torch.ones(padded_positions.shape[1], dtype=torch.int32, device=device),
            size=hist_shape
        )
        histogram = counts_sparse.to_dense()
    except Exception as e:
        logger.error("Failed to build dense histogram: %s", e)
        raise
    del counts_sparse

    # Step 3: peaks via 5x5 max-pooling. Pool on a float copy with an explicit
    # (batch, channel) layout: integer max-pooling support is not guaranteed
    # on every PyTorch version/device.
    hist_float = histogram.to(torch.float32)
    pooled = F.max_pool2d(hist_float[None, None], kernel_size=5, stride=1, padding=2)[0, 0]
    is_seed = (hist_float == pooled) & (histogram > 10)
    del hist_float, pooled
    seed_positions = torch.nonzero(is_seed)
    del is_seed
    if seed_positions.numel() == 0:
        logger.warning("No seeds found: returning empty mask")
        return np.zeros(original_shape, dtype=np.uint16)

    # Ascending count order: later (stronger) seeds overwrite earlier ones in step 5.
    seed_counts = histogram[seed_positions[:, 0], seed_positions[:, 1]]
    seed_positions = seed_positions[torch.argsort(seed_counts)]
    seeds = seed_positions.tolist()
    num_seeds = len(seeds)

    # Step 4: region growing inside 11x11 patches around each seed.
    keep = torch.stack([
        histogram[y - half:y + half + 1, x - half:x + half + 1]
        for y, x in seeds
    ]) > 2
    del histogram
    grown = torch.zeros(keep.shape, dtype=torch.float32, device=device)
    grown[:, half, half] = 1
    for _ in range(5):
        grown = F.max_pool2d(
            grown.unsqueeze(1), kernel_size=3, stride=1, padding=1
        ).squeeze(1)
        # Multiply rather than `&`: bitwise ops are undefined for float tensors.
        grown = grown * keep
    del keep

    # Step 5: paint grown regions (in seed order) into a padded label volume.
    label_dtype = torch.int32 if num_seeds < 2**16 else torch.int64
    label_volume = torch.zeros(hist_shape, dtype=label_dtype, device=device)
    for label, ((y, x), region) in enumerate(zip(seeds, grown), start=1):
        local = torch.nonzero(region)
        label_volume[local[:, 0] + (y - half), local[:, 1] + (x - half)] = label
    del grown
    # Label found at each pixel's final (padded) position.
    pixel_labels = label_volume[padded_positions[0], padded_positions[1]].cpu().numpy()
    del label_volume

    # Step 6: scatter back to the original grid. tuple() makes both a
    # [2, N] index array and an np.nonzero-style tuple index per axis.
    mask_dtype = np.uint16 if num_seeds < 2**16 else np.uint32
    mask_final = np.zeros(original_shape, dtype=mask_dtype)
    mask_final[tuple(valid_indices)] = pixel_labels

    # Remove oversized masks (background is never a candidate).
    labels, counts = fastremap.unique(mask_final, return_counts=True)
    size_limit = np.prod(original_shape) * max_size_fraction
    oversized = labels[(counts > size_limit) & (labels != 0)]
    if oversized.size > 0:
        mask_final = fastremap.mask(mask_final, oversized)
    # Seeds that lost all their pixels leave gaps; renumber to 1..M.
    fastremap.renumber(mask_final, in_place=True)
    return mask_final
def __remove_inconsistent_flow_masks(
    self,
    mask: np.ndarray,
    flow_network: np.ndarray,
    error_threshold: float = 0.4
) -> np.ndarray:
    """
    Remove labeled masks whose flows disagree with the network-predicted flows.

    Quality control step: flows are recomputed from `mask` and compared with
    `flow_network` via `__compute_flow_error`. Every label whose error exceeds
    `error_threshold` is set to 0.

    Args:
        mask (np.ndarray): Integer mask array with shape [H, W].
            Values: 0 = no mask; 1,2,... = mask labels.
        flow_network (np.ndarray): Network-predicted flows with shape [2, H, W].
        error_threshold (float): Maximum allowed flow error per mask label.
            Defaults to 0.4.

    Returns:
        np.ndarray: The same `mask` object, modified in place, with
        inconsistent labels set to 0.

    Raises:
        MemoryError: If the mask has more than 10000*10000 pixels, the device
            is CUDA, and the estimated requirement exceeds free GPU memory.
    """
    num_pixels = mask.size
    if num_pixels > 10000 * 10000 and self._device.type == 'cuda':
        # Clear unused GPU cache before measuring free memory
        torch.cuda.empty_cache()
        # torch.device('cuda') has index None, so fall back to the current device.
        device_index = (
            self._device.index
            if self._device.index is not None
            else torch.cuda.current_device()
        )
        if hasattr(torch.cuda, 'mem_get_info'):
            free_mem, _ = torch.cuda.mem_get_info(device_index)
        else:
            # Older PyTorch without mem_get_info: approximate free memory.
            total_mem = torch.cuda.get_device_properties(device_index).total_memory
            free_mem = total_mem - torch.cuda.memory_allocated(device_index)
        # NOTE(review): counts a single float32 per pixel; the flow
        # computation probably needs more than this — confirm.
        required_bytes = num_pixels * np.dtype(np.float32).itemsize
        if required_bytes > free_mem:
            logger.error(
                'Image too large for GPU memory in flow QC step (required: %d B, available: %d B)',
                required_bytes, free_mem
            )
            raise MemoryError('Insufficient GPU memory for flow QC computation')

    # flow_errors[i] is the error of label i + 1
    flow_errors, _ = self.__compute_flow_error(mask, flow_network)
    bad_labels = np.nonzero(flow_errors > error_threshold)[0] + 1
    if bad_labels.size > 0:
        mask[np.isin(mask, bad_labels)] = 0
    return mask
def __compute_flow_error(
    self,
    mask: np.ndarray,
    flow_network: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Measure, per mask label, how far the network flows are from mask-derived flows.

    For every flow axis, the squared difference between the mask-derived flow
    and the network flow divided by 5 is averaged over each label's pixels.
    These per-axis averages are summed into one error value per label.

    Args:
        mask (np.ndarray): Integer label image (0 = background). Its shape must
            equal `flow_network.shape[1:]`.
        flow_network (np.ndarray): Network-predicted flows with the flow axis first.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - flow_errors: float array of length `mask.max()`; entry i belongs
              to label i + 1.
            - computed_flows: the flows returned by `__compute_flow_from_mask`
              for this mask.

    Raises:
        ValueError: If `mask.shape` differs from `flow_network.shape[1:]`.
    """
    spatial_shape = flow_network.shape[1:]
    if spatial_shape != mask.shape:
        logger.error(
            'Shape mismatch: network flow shape %s vs mask shape %s',
            spatial_shape, mask.shape
        )
        raise ValueError('Network flow and mask shapes must match')

    # Flows implied by the labels themselves (given a leading singleton axis).
    mask_flows = self.__compute_flow_from_mask(mask[None, ...])

    num_labels = int(mask.max())
    label_ids = np.arange(1, num_labels + 1)
    errors = np.zeros(num_labels, dtype=float)
    for axis, axis_flow in enumerate(mask_flows):
        squared_diff = (axis_flow - flow_network[axis] / 5.0) ** 2
        errors += mean(squared_diff, mask, index=label_ids)
    return errors, mask_flows
def __fill_holes_and_prune_small_masks(
    self,
    masks: np.ndarray,
    minimum_size: int = 15
) -> np.ndarray:
    """
    Fill holes in labeled masks and remove masks smaller than a given size.

    Processing order:
        1. If `minimum_size >= 0`, drop foreground labels with fewer than
           `minimum_size` pixels.
        2. Fill internal holes of every remaining label with `fill_voids.fill`,
           writing the results into a new array with sequential labels.
        3. If `minimum_size >= 0`, prune small labels again. A filled region
           may cover pixels of labels written before it, so sizes can drop.
        4. Close any gaps in the label sequence so the output uses 1..N.

    Args:
        masks (np.ndarray): Integer mask array of dimension 2 or 3 (shape [H, W] or [D, H, W]).
            Values: 0 = background; 1,2,... = mask labels.
        minimum_size (int): Minimum number of pixels required to keep a mask.
            Masks smaller than this will be removed.
            Set to -1 to skip size-based pruning. Defaults to 15.

    Returns:
        np.ndarray: New mask array (same dtype) with holes filled, small masks
        removed, and labels numbered 1..N. The input array is not modified.

    Raises:
        ValueError: If `masks` is not a 2D or 3D array.
    """
    # Validate input dimensions
    if masks.ndim not in (2, 3):
        msg = f"Expected 2D or 3D mask array, got {masks.ndim}D."
        logger.error(msg)
        raise ValueError(msg)

    def _drop_small(label_image: np.ndarray) -> np.ndarray:
        # Zero out foreground labels smaller than `minimum_size`; background
        # (label 0) is never a candidate. fastremap.mask returns a copy.
        labels, counts = fastremap.unique(label_image, return_counts=True)
        small = labels[(counts < minimum_size) & (labels != 0)]
        if small.size > 0:
            label_image = fastremap.mask(label_image, small)
            fastremap.renumber(label_image, in_place=True)
        return label_image

    prune = minimum_size >= 0
    if prune:
        masks = _drop_small(masks)

    output_masks = np.zeros_like(masks)
    next_label = 1
    # find_objects returns one bounding-box slice per label (None if absent)
    for original_label, region_slice in enumerate(find_objects(masks), start=1):
        if region_slice is None:
            continue
        region = masks[region_slice] == original_label
        if not region.any():
            continue
        filled = fill_voids.fill(region)
        # output_masks[region_slice] is a view, so this writes into output_masks
        output_masks[region_slice][filled] = next_label
        next_label += 1

    if prune:
        output_masks = _drop_small(output_masks)

    # A filled region can completely cover a label written earlier (e.g. a cell
    # inside another cell's hole), leaving a gap; restore labels 1..N.
    present = fastremap.unique(output_masks)
    present = present[present != 0]
    if present.size > 0 and present[-1] != present.size:
        fastremap.renumber(output_masks, in_place=True)
    return output_masks