"""
This code is adapted from the following sources:

[1] https://github.com/MouseLand/cellpose/blob/main/cellpose/utils.py
[2] https://github.com/MouseLand/cellpose/blob/main/cellpose/dynamics.py
[3] https://github.com/MouseLand/cellpose/blob/main/cellpose/metrics.py
[4] https://github.com/Lee-Gihun/MEDIAR/tree/main/core
[5] https://github.com/Lee-Gihun/MEDIAR/tree/main/core/MEDIAR
"""

import random

import numpy as np
from numba import njit, prange

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader

import fastremap
import fill_voids
# from skimage import morphology
from skimage.segmentation import find_boundaries
from scipy.special import expit
from scipy.ndimage import mean, find_objects

from monai.data.dataset import Dataset
from monai.transforms.compose import Compose
from monai.inferers.utils import sliding_window_inference

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import os
import glob
import csv
import copy
import time
import tifffile as tiff
from itertools import chain

from pprint import pformat
from tabulate import tabulate
from typing import Any, Literal

from tqdm import tqdm
import wandb

from config import Config
from core.models import ModelRegistry
from core.losses import CriterionRegistry
from core.optimizers import OptimizerRegistry
from core.schedulers import SchedulerRegistry
from core.utils import (
    compute_batch_segmentation_tp_fp_fn,
    compute_f1_score,
    compute_average_precision_score
)
from core.data.transforms.load_image import SUPPORTED_IMAGE_FORMATS

from core.logger import get_logger


logger = get_logger()
class CellSegmentator:
    def __init__(self, config: Config) -> None:
        self._device: torch.device = torch.device(config.dataset_config.common.device or "cpu")

        self.__set_seed(config.dataset_config.common.seed)
        self.__parse_config(config)

        self._scaler = (
            torch.amp.GradScaler(self._device.type)  # type: ignore
            if self._dataset_setup.is_training and self._dataset_setup.common.use_amp
            else None
        )

        self._train_dataloader: DataLoader | None = None
        self._valid_dataloader: DataLoader | None = None
        self._test_dataloader: DataLoader | None = None
        self._predict_dataloader: DataLoader | None = None

        self._best_weights = None


    def create_dataloaders(
        self,
        train_transforms: Compose | None = None,
        valid_transforms: Compose | None = None,
        test_transforms: Compose | None = None,
        predict_transforms: Compose | None = None
    ) -> None:
        """
        Creates train, validation, test, and prediction dataloaders based on the
        dataset configuration and the provided transforms.

        Args:
            train_transforms (Compose | None): Transformations for training data.
            valid_transforms (Compose | None): Transformations for validation data.
            test_transforms (Compose | None): Transformations for testing data.
            predict_transforms (Compose | None): Transformations for prediction data.

        Raises:
            ValueError: If required transforms are missing.
            RuntimeError: If critical dataset config values are missing.
        """
        if self._dataset_setup.is_training and train_transforms is None:
            raise ValueError("Training mode requires 'train_transforms' to be provided.")
        elif not self._dataset_setup.is_training and test_transforms is None and predict_transforms is None:
            raise ValueError(
                "In inference mode, at least one of 'test_transforms' or 'predict_transforms' must be provided."
            )

        if self._dataset_setup.is_training:
            # Training mode: handle either pre-split datasets or splitting on the fly
            if self._dataset_setup.training.is_split:
                # Validate presence of validation transforms if validation directory and size are set
                if (
                    self._dataset_setup.training.pre_split.valid_dir and
                    self._dataset_setup.training.valid_size and
                    valid_transforms is None
                ):
                    raise ValueError("Validation transforms must be provided when using pre-split validation data.")

                # Use explicitly split directories
                train_dir = self._dataset_setup.training.pre_split.train_dir
                valid_dir = self._dataset_setup.training.pre_split.valid_dir
                test_dir = self._dataset_setup.training.pre_split.test_dir

                train_offset = self._dataset_setup.training.train_offset
                valid_offset = self._dataset_setup.training.valid_offset
                test_offset = self._dataset_setup.training.test_offset

                shuffle = False
            else:
                # Same validation for split mode with a single full data directory
                if (
                    self._dataset_setup.training.split.all_data_dir and
                    self._dataset_setup.training.valid_size and
                    valid_transforms is None
                ):
                    raise ValueError("Validation transforms must be provided when splitting the dataset.")

                # Automatically split the dataset from one directory
                train_dir = valid_dir = test_dir = self._dataset_setup.training.split.all_data_dir

                number_of_images = len(os.listdir(os.path.join(train_dir, 'images')))
                if number_of_images == 0:
                    raise FileNotFoundError(f"No images found in '{train_dir}/images'")

                # Calculate train/valid sizes
                train_size = (
                    self._dataset_setup.training.train_size
                    if isinstance(self._dataset_setup.training.train_size, int)
                    else int(number_of_images * self._dataset_setup.training.train_size)
                )
                valid_size = (
                    self._dataset_setup.training.valid_size
                    if isinstance(self._dataset_setup.training.valid_size, int)
                    else int(number_of_images * self._dataset_setup.training.valid_size)
                )

                train_offset = self._dataset_setup.training.train_offset
                valid_offset = self._dataset_setup.training.valid_offset + train_size
                test_offset = self._dataset_setup.training.test_offset + train_size + valid_size

                shuffle = True

            # Train dataloader
            train_dataset = self.__get_dataset(
                images_dir=os.path.join(train_dir, 'images'),
                masks_dir=os.path.join(train_dir, 'masks', self._dataset_setup.common.masks_subdir),
                transforms=train_transforms,  # type: ignore
                size=self._dataset_setup.training.train_size,
                offset=train_offset,
                shuffle=shuffle
            )
            self._train_dataloader = DataLoader(
                train_dataset, batch_size=self._dataset_setup.training.batch_size, shuffle=True
            )
            logger.info(f"Loaded training dataset with {len(train_dataset)} samples.")

            # Validation dataloader
            if valid_transforms is not None:
                if not valid_dir or not self._dataset_setup.training.valid_size:
                    raise RuntimeError("Validation directory or size is not properly configured.")
                valid_dataset = self.__get_dataset(
                    images_dir=os.path.join(valid_dir, 'images'),
                    masks_dir=os.path.join(valid_dir, 'masks', self._dataset_setup.common.masks_subdir),
                    transforms=valid_transforms,
                    size=self._dataset_setup.training.valid_size,
                    offset=valid_offset,
                    shuffle=shuffle
                )
                self._valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded validation dataset with {len(valid_dataset)} samples.")

            # Test dataloader
            if test_transforms is not None:
                if not test_dir or not self._dataset_setup.training.test_size:
                    raise RuntimeError("Test directory or size is not properly configured.")
                test_dataset = self.__get_dataset(
                    images_dir=os.path.join(test_dir, 'images'),
                    masks_dir=os.path.join(test_dir, 'masks', self._dataset_setup.common.masks_subdir),
                    transforms=test_transforms,
                    size=self._dataset_setup.training.test_size,
                    offset=test_offset,
                    shuffle=shuffle
                )
                self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded test dataset with {len(test_dataset)} samples.")

            # Prediction dataloader
            if predict_transforms is not None:
                if not test_dir or not self._dataset_setup.training.test_size:
                    raise RuntimeError("Prediction directory or size is not properly configured.")
                predict_dataset = self.__get_dataset(
                    images_dir=os.path.join(test_dir, 'images'),
                    masks_dir=None,
                    transforms=predict_transforms,
                    size=self._dataset_setup.training.test_size,
                    offset=test_offset,
                    shuffle=shuffle
                )
                self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.")

        else:
            # Inference mode (no training)
            test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images')
            test_masks = os.path.join(
                self._dataset_setup.testing.test_dir, 'masks', self._dataset_setup.common.masks_subdir
            )

            if test_transforms is not None:
                test_dataset = self.__get_dataset(
                    images_dir=test_images,
                    masks_dir=test_masks,
                    transforms=test_transforms,
                    size=self._dataset_setup.testing.test_size,
                    offset=self._dataset_setup.testing.test_offset,
                    shuffle=self._dataset_setup.testing.shuffle
                )
                self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded test dataset with {len(test_dataset)} samples.")

            if predict_transforms is not None:
                predict_dataset = self.__get_dataset(
                    images_dir=test_images,
                    masks_dir=None,
                    transforms=predict_transforms,
                    size=self._dataset_setup.testing.test_size,
                    offset=self._dataset_setup.testing.test_offset,
                    shuffle=self._dataset_setup.testing.shuffle
                )
                self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.")
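    # Usage sketch for the entry points above (hedged: `Config.load` and the
    # concrete transform lists are hypothetical placeholders; the real config
    # loader and transform pipelines live elsewhere in this repository):
    #
    #     from monai.transforms.compose import Compose
    #     from monai.transforms import LoadImaged, EnsureChannelFirstd, ScaleIntensityd
    #
    #     config = Config.load("configs/train.yaml")        # hypothetical loader
    #     segmentator = CellSegmentator(config)
    #     pipeline = Compose([
    #         LoadImaged(keys=["image", "mask"]),
    #         EnsureChannelFirstd(keys=["image", "mask"]),
    #         ScaleIntensityd(keys=["image"]),
    #     ])
    #     segmentator.create_dataloaders(train_transforms=pipeline,
    #                                    valid_transforms=pipeline)
    #     segmentator.run(save_results=True)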
    def print_data_info(
        self,
        loader_type: Literal["train", "valid", "test", "predict"],
        index: int | None = None
    ) -> None:
        """
        Prints statistics for a single sample from the specified dataloader.

        Args:
            loader_type: One of "train", "valid", "test", "predict".
            index: The sample index; if None, a random index is selected.
        """
        # Retrieve the dataloader attribute, e.g., self._train_dataloader
        loader: DataLoader | None = getattr(self, f"_{loader_type}_dataloader", None)
        if loader is None:
            logger.error(f"Dataloader '{loader_type}' is not initialized.")
            return

        dataset = loader.dataset
        total = len(dataset)  # type: ignore
        if total == 0:
            logger.error(f"Dataset for '{loader_type}' is empty.")
            return

        # Choose index
        idx = index if index is not None else random.randint(0, total - 1)
        if not (0 <= idx < total):
            logger.error(f"Index {idx} is out of range [0, {total}).")
            return

        # Fetch the sample and apply transforms
        sample = dataset[idx]
        # Expecting a dict with {'image': ..., 'mask': ...} or {'image': ...}
        img = sample["image"]
        # Convert tensor to numpy if needed
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()
        img = np.asarray(img)

        # Compute image statistics
        img_min, img_max = img.min(), img.max()
        img_mean, img_std = float(img.mean()), float(img.std())
        img_shape = img.shape

        # Prepare log lines
        lines = [
            "=" * 40,
            f"Dataloader: '{loader_type}', sample index: {idx} / {total - 1}",
            f"Image — shape: {img_shape}, min: {img_min:.4f}, max: {img_max:.4f}, "
            f"mean: {img_mean:.4f}, std: {img_std:.4f}"
        ]

        # For 'predict', no mask is available
        if loader_type != "predict":
            mask = sample.get("mask", None)
            if mask is not None:
                # Convert tensor to numpy if needed
                if isinstance(mask, torch.Tensor):
                    mask = mask.cpu().numpy()
                mask = np.asarray(mask)
                m_min, m_max = mask.min(), mask.max()
                m_mean, m_std = float(mask.mean()), float(mask.std())
                m_shape = mask.shape
                lines.append(
                    f"Mask — shape: {m_shape}, min: {m_min:.4f}, "
                    f"max: {m_max:.4f}, mean: {m_mean:.4f}, std: {m_std:.4f}"
                )
            else:
                lines.append("Mask — not available for this sample.")

        lines.append("=" * 40)

        # Output via logger
        for line in lines:
            logger.info(line)
    def train(self, save_results: bool = True, only_masks: bool = False) -> None:
        """
        Train the model over multiple epochs, including validation and a final test pass.

        Args:
            save_results (bool): If True, the predicted masks and test metrics will be saved.
            only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
                without visualization overlays.
        """
        # Ensure training is enabled in the dataset setup
        if not self._dataset_setup.is_training:
            raise RuntimeError("Dataset setup indicates training is disabled.")

        # Determine the device name for logging
        if self._device.type == "cpu":
            device_name = "cpu"
        else:
            # torch.device always exposes `.index`, but it may be None (e.g. for "cuda"),
            # so fall back to the current CUDA device in that case
            idx = self._device.index if self._device.index is not None else torch.cuda.current_device()
            device_name = torch.cuda.get_device_name(idx)

        logger.info(f"\n{'=' * 50}\n")
        logger.info(f"Training starts on device: {device_name}")
        logger.info(f"\n{'=' * 50}")

        best_f1_score = 0.0

        for epoch in range(1, self._dataset_setup.training.num_epochs + 1):
            train_metrics = self.__run_epoch("train", epoch)
            self.__print_with_logging(train_metrics, epoch)

            # Step the scheduler after each training epoch
            if self._scheduler is not None:
                self._scheduler.step()

            # Periodic validation
            if epoch % self._dataset_setup.training.val_freq == 0:
                if self._valid_dataloader is not None:
                    valid_metrics = self.__run_epoch("valid", epoch)
                    self.__print_with_logging(valid_metrics, epoch)

                    # Update the best model on improved F1
                    f1 = valid_metrics.get("valid_f1_score", 0.0)
                    if f1 > best_f1_score:
                        best_f1_score = f1
                        # Deep copy the weights to avoid reference issues
                        self._best_weights = copy.deepcopy(self._model.state_dict())
                        logger.info(f"Updated best model weights with F1 score: {f1:.4f}")

        # Restore the best model weights if available
        if self._best_weights is not None:
            self._model.load_state_dict(self._best_weights)

        if self._test_dataloader is not None:
            test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
            self.__print_with_logging(test_metrics, 0)

            save_path = self._dataset_setup.common.predictions_dir
            os.makedirs(save_path, exist_ok=True)
            self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
    def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None:
        """
        Run a full test epoch and display/log the resulting metrics.

        Args:
            save_results (bool): If True, the predicted masks and test metrics will be saved.
            only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
                without visualization overlays.
        """
        test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
        self.__print_with_logging(test_metrics, 0)

        save_path = self._dataset_setup.common.predictions_dir
        os.makedirs(save_path, exist_ok=True)
        self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
    def predict(self, only_masks: bool = False) -> None:
        """
        Run inference on the predict set and save the resulting instance masks.

        Args:
            only_masks (bool): If True, only raw predicted masks are saved,
                without visualization overlays.
        """
        # Ensure the predict DataLoader has been set
        if self._predict_dataloader is None:
            raise RuntimeError("DataLoader for mode 'predict' is not set.")

        batch_counter = 0
        for batch in tqdm(self._predict_dataloader, desc="Predicting"):
            # Move input images to the configured device (CPU/GPU)
            inputs = batch["image"].to(self._device)

            # Use automatic mixed precision if enabled in the dataset setup
            with torch.amp.autocast(  # type: ignore
                self._device.type,
                enabled=self._dataset_setup.common.use_amp
            ):
                # Disable gradient computation for inference
                with torch.no_grad():
                    # Run the model's forward pass in 'predict' mode
                    raw_output = self.__run_inference(inputs, mode="predict")

                    # Convert logits/probabilities to discrete instance masks;
                    # ground truth is not passed here, only predictions are needed
                    preds, _ = self.__post_process_predictions(raw_output)

                    # Save the predicted masks, using batch_counter to index files
                    self.__save_prediction_masks(
                        sample=batch,
                        predicted_mask=preds,
                        start_index=batch_counter,
                        only_masks=only_masks
                    )

            # Increment the counter by batch size for unique file naming
            batch_counter += inputs.shape[0]
    def run(self, save_results: bool = True, only_masks: bool = False) -> None:
        """
        Orchestrate the full workflow and report execution time:

        - If training is enabled in the dataset setup, start training.
        - Otherwise, if a test DataLoader is provided, run evaluation.
        - Else, if a prediction DataLoader is provided, run inference/prediction.
        - If neither loader is available in non-training mode, raise an error.

        Args:
            save_results (bool): If True, the predicted masks and test metrics will be saved.
            only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
                without visualization overlays.
        """
        start_time = time.time()

        logger.info(
            f"Masks saving: {'enabled' if save_results else 'disabled'}; "
            f"Additional visualizations: "
            f"{'enabled' if save_results and not only_masks else 'disabled'}"
        )

        # 1) TRAINING PATH
        if self._dataset_setup.is_training:
            # Launch the full training loop (with validation, scheduler steps, etc.)
            self.train(save_results=save_results, only_masks=only_masks)
        else:
            # 2) NON-TRAINING PATH (TEST or PREDICT)
            if self._test_dataloader is not None:
                # Run a single evaluation epoch on the test set and log metrics
                self.evaluate(save_results=save_results, only_masks=only_masks)
            elif self._predict_dataloader is not None:
                # Run inference on the predict set and save outputs
                self.predict(only_masks=only_masks)
            else:
                # 3) ERROR: no appropriate loader found
                raise RuntimeError(
                    "Neither test nor predict DataLoader is set for non-training mode."
                )

        elapsed = time.time() - start_time
        if elapsed < 60:
            logger.info(f"Total execution time: {elapsed:.2f} seconds")
        elif elapsed < 3600:
            minutes = int(elapsed // 60)
            seconds = elapsed % 60
            logger.info(f"Total execution time: {minutes} min {seconds:.2f} sec")
        else:
            hours = int(elapsed // 3600)
            minutes = int((elapsed % 3600) // 60)
            seconds = elapsed % 60
            logger.info(f"Total execution time: {hours} h {minutes} min {seconds:.2f} sec")
    def load_from_checkpoint(self, checkpoint_path: str) -> None:
        """
        Loads model weights from a specified checkpoint into the current model,
        but only for parameters whose shapes match. Parameters with mismatched
        shapes (e.g., classification heads with different output sizes) remain
        at their initialized values.

        Args:
            checkpoint_path (str): Path to the checkpoint file containing the model weights.
        """
        # Load the checkpoint (state_dict) from file onto CPU
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        # Extract a nested state_dict if present
        state_dict = checkpoint.get("state_dict", checkpoint)

        # Get the current model's parameter dictionary
        model_dict = self._model.state_dict()

        # Filter pretrained parameters to those matching in both name and shape
        pretrained_dict = {
            k: v for k, v in state_dict.items()
            if k in model_dict and v.size() == model_dict[k].size()
        }

        # Log how many parameters are loaded, skipped, or missing
        skipped = [k for k in state_dict if k not in pretrained_dict]
        missing = [k for k in model_dict if k not in pretrained_dict]
        logger.info(
            f"Loaded {len(pretrained_dict)} parameters;"
            f" skipped {len(skipped)} params from checkpoint;"
            f" {len(missing)} params remain uninitialized in model."
        )

        # Update the model's state_dict and load it
        model_dict.update(pretrained_dict)
        self._model.load_state_dict(model_dict)
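    # A minimal sketch of the shape-filtered loading performed above, on toy
    # modules (illustrative only; these models are not part of this repository):
    #
    #     import torch.nn as nn
    #
    #     src = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))   # checkpoint source
    #     dst = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 5))   # new head size
    #     ckpt = src.state_dict()
    #     current = dst.state_dict()
    #     compatible = {k: v for k, v in ckpt.items()
    #                   if k in current and v.size() == current[k].size()}
    #     current.update(compatible)      # only the matching '0.*' tensors copy over
    #     dst.load_state_dict(current)    # '1.*' keeps its fresh initialization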
    def save_checkpoint(self, checkpoint_path: str) -> None:
        """
        Saves the current model weights to a checkpoint file.

        Args:
            checkpoint_path (str): Path where the checkpoint file will be saved.
        """
        # Write the best weights if available, otherwise the current weights.
        # `or "."` guards against a bare filename with no directory component.
        os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
        state_dict = (
            self._model.state_dict()
            if self._best_weights is None
            else self._best_weights
        )
        torch.save(state_dict, checkpoint_path)
    def __parse_config(self, config: Config) -> None:
        """
        Parses the given configuration object to initialize model, criterion,
        optimizer, scheduler, and dataset setup.

        Args:
            config (Config): Configuration object with model, optimizer,
                scheduler, criterion, and dataset setup information.
        """
        model = config.model
        criterion = config.criterion
        optimizer = config.optimizer
        scheduler = config.scheduler

        logger.info("========== Parsed Configuration ==========")
        logger.info("Model Config:\n%s", pformat(model.dump(), indent=2))
        if criterion:
            logger.info("Criterion Config:\n%s", pformat(criterion.dump(), indent=2))
        if optimizer:
            logger.info("Optimizer Config:\n%s", pformat(optimizer.dump(), indent=2))
        if scheduler:
            logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2))
        logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2))
        logger.info("Wandb Config:\n%s", pformat(config.wandb_config.model_dump(), indent=2))
        logger.info("==========================================")

        # Initialize the model using the model registry
        self._model = ModelRegistry.get_model_class(model.name)(model.params)

        # Load model weights from a checkpoint if one is specified
        pretrained_weights = config.dataset_config.common.pretrained_weights
        if pretrained_weights:
            self.load_from_checkpoint(pretrained_weights)
            logger.info(f"Loaded pre-trained weights from: {pretrained_weights}")

        self._model = self._model.to(self._device)

        # Initialize the loss criterion if specified
        self._criterion = (
            CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params)
            if criterion is not None
            else None
        )

        # Ensure the model and criterion agree on the number of classes
        if hasattr(self._criterion, "num_classes"):
            nc_model = self._model.num_classes
            nc_crit = getattr(self._criterion, "num_classes")
            if nc_model != nc_crit:
                raise ValueError(
                    f"Number of classes mismatch: model.num_classes={nc_model} "
                    f"but criterion.num_classes={nc_crit}"
                )

        # Initialize the optimizer if specified
        self._optimizer = (
            OptimizerRegistry.get_optimizer_class(optimizer.name)(
                model_params=self._model.parameters(),
                optim_params=optimizer.params
            )
            if optimizer is not None
            else None
        )

        # Initialize the scheduler only if both scheduler and optimizer are defined
        self._scheduler = (
            SchedulerRegistry.get_scheduler_class(scheduler.name)(
                optimizer=self._optimizer.optim,
                params=scheduler.params
            )
            if scheduler is not None and self._optimizer is not None and self._optimizer.optim is not None
            else None
        )

        logger.info("========== Model Components Initialization ==========")
        logger.info("├─ Model:     " + (f"{model.name}" if self._model else "Not specified"))
        logger.info("├─ Criterion: " + (f"{criterion.name}" if self._criterion else "Not specified"))  # type: ignore
        logger.info("├─ Optimizer: " + (f"{optimizer.name}" if self._optimizer else "Not specified"))  # type: ignore
        logger.info("└─ Scheduler: " + (f"{scheduler.name}" if self._scheduler else "Not specified"))  # type: ignore
        logger.info("=====================================================")

        # Save the dataset config
        self._dataset_setup = config.dataset_config
        common = config.dataset_config.common

        logger.info("========== Dataset Setup ==========")
        logger.info("[COMMON]")
        logger.info(f"├─ Seed: {common.seed}")
        logger.info(f"├─ Device: {common.device}")
        logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
        logger.info(f"├─ ROI size: {common.roi_size}")
        logger.info(f"├─ Remove boundary objects: {'yes' if common.remove_boundary_objects else 'no'}")
        logger.info(f"├─ Masks subdirectory: {common.masks_subdir}")
        logger.info(f"├─ Pretrained weights: {common.pretrained_weights or 'None'}")
        logger.info(f"└─ Predictions output dir: {common.predictions_dir}")

        if config.dataset_config.is_training:
            training = config.dataset_config.training
            logger.info("[MODE] Training")
            logger.info(f"├─ Batch size: {training.batch_size}")
            logger.info(f"├─ Epochs: {training.num_epochs}")
            logger.info(f"├─ Validation frequency: {training.val_freq}")

            if training.is_split:
                logger.info( "├─ Using pre-split directories:")
                logger.info(f"│   ├─ Train dir: {training.pre_split.train_dir}")
                logger.info(f"│   ├─ Valid dir: {training.pre_split.valid_dir}")
                logger.info(f"│   └─ Test dir:  {training.pre_split.test_dir}")
            else:
                logger.info( "├─ Using unified dataset with splits:")
                logger.info(f"│   ├─ All data dir: {training.split.all_data_dir}")
                logger.info(f"│   └─ Shuffle: {'yes' if training.split.shuffle else 'no'}")

            logger.info( "└─ Dataset split:")
            logger.info(f"    ├─ Train size: {training.train_size}, offset: {training.train_offset}")
            logger.info(f"    ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}")
            logger.info(f"    └─ Test size:  {training.test_size}, offset: {training.test_offset}")

        else:
            testing = config.dataset_config.testing
            logger.info("[MODE] Inference")
            logger.info(f"├─ Test dir: {testing.test_dir}")
            logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})")
            logger.info(f"└─ Shuffle: {'yes' if testing.shuffle else 'no'}")

        self._wandb_config = config.wandb_config
        if self._wandb_config.use_wandb:
            logger.info("[W&B]")
            logger.info(f"├─ Project: {self._wandb_config.project}")
            if self._wandb_config.group:
                logger.info(f"├─ Group: {self._wandb_config.group}")
            if self._wandb_config.entity:
                logger.info(f"├─ Entity: {self._wandb_config.entity}")
            if self._wandb_config.name:
                logger.info(f"├─ Run name: {self._wandb_config.name}")
            if self._wandb_config.tags:
                logger.info(f"├─ Tags: {', '.join(self._wandb_config.tags)}")
            if self._wandb_config.notes:
                logger.info(f"├─ Notes: {self._wandb_config.notes}")
            logger.info(f"└─ Save code: {'yes' if self._wandb_config.save_code else 'no'}")
        else:
            logger.info("[W&B] Logging disabled")

        logger.info("===================================")
    def __set_seed(self, seed: int | None) -> None:
        """
        Sets the random seed for reproducibility across Python, NumPy, and PyTorch.

        Args:
            seed (int | None): Seed value. If None, no seeding is performed.
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)  # also seeds all CUDA devices
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            logger.info(f"Random seed set to {seed}")
        else:
            logger.info("Seed not set (None provided)")
    def __get_dataset(
        self,
        images_dir: str,
        masks_dir: str | None,
        transforms: Compose,
        size: int | float,
        offset: int,
        shuffle: bool
    ) -> Dataset:
        """
        Loads and returns a dataset object from image and optional mask directories.

        Args:
            images_dir (str): Path to directory or glob pattern for input images.
            masks_dir (str | None): Path to directory or glob pattern for masks.
            transforms (Compose): Transformations to apply to each image or pair.
            size (int | float): Either an absolute count or a fraction of the dataset.
            offset (int): Number of images to skip from the start.
            shuffle (bool): Whether to shuffle the dataset before slicing.

        Returns:
            Dataset: A dataset containing image and optional mask paths.

        Raises:
            FileNotFoundError: If no images are found.
            ValueError: If masks are provided but do not match the image count.
            ValueError: If the dataset is too small for the requested size or offset.
        """
        # Collect a sorted list of image paths; chain.from_iterable flattens
        # the per-extension glob results into a single list of paths
        images = sorted(
            chain.from_iterable(
                glob.glob(os.path.join(images_dir, f'*.{ext}'))
                for ext in SUPPORTED_IMAGE_FORMATS
            )
        )
        if not images:
            raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'")

        if masks_dir is not None:
            # Collect and validate a sorted list of mask paths
            masks = sorted(
                chain.from_iterable(
                    glob.glob(os.path.join(masks_dir, f'*.{ext}'))
                    for ext in SUPPORTED_IMAGE_FORMATS
                )
            )
            if len(images) != len(masks):
                raise ValueError(
                    f"Number of masks ({len(masks)}) does not match number of images ({len(images)})"
                )

        # Convert a float size (fraction) to an absolute count
        size = size if isinstance(size, int) else int(size * len(images))

        if size <= 0:
            raise ValueError(f"Size must be positive, got: {size}")

        if len(images) < size:
            raise ValueError(f"Not enough images ({len(images)}) for requested size ({size})")

        if len(images) < size + offset:
            raise ValueError(f"Offset ({offset}) + size ({size}) exceeds dataset length ({len(images)})")

        # Shuffle image-mask pairs if requested
        if shuffle:
            if masks_dir is not None:
                combined = list(zip(images, masks))  # type: ignore
                random.shuffle(combined)
                images, masks = zip(*combined)
            else:
                random.shuffle(images)

        # Apply the offset and limit by size
        images = images[offset: offset + size]
        if masks_dir is not None:
            masks = masks[offset: offset + size]  # type: ignore

        # Prepare the data structure for the Dataset class
        if masks_dir is not None:
            data = [
                {"image": image, "mask": mask}
                for image, mask in zip(images, masks)  # type: ignore
            ]
        else:
            data = [{"image": image} for image in images]

        return Dataset(data, transforms)
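    # Note on the globbing above: a bare `chain(gen)` would treat the generator
    # as a single iterable of *lists*, so `sorted` would compare lists rather
    # than paths; hence `chain.from_iterable`, which flattens them. A quick
    # sketch (illustrative values only):
    #
    #     from itertools import chain
    #     per_ext = (["a.tif", "b.tif"], ["c.png"])    # one list per extension
    #     sorted(chain.from_iterable(per_ext))         # ['a.tif', 'b.tif', 'c.png']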
    def __print_with_logging(self, metrics: dict[str, float | np.ndarray], step: int) -> None:
        """
        Print metrics in a tabular format and log them to W&B.

        Args:
            metrics (dict(str, float | np.ndarray)): Mapping from metric names
                to either a float or an ND numpy array.
            step (int): Epoch index.
        """
        rows: list[tuple[str, str]] = []
        for key, val in metrics.items():
            if isinstance(val, np.ndarray):
                # Convert the array to a string, e.g. '[0.2, 0.8, 0.5]'
                val_str = np.array2string(val, separator=', ')
            else:
                # Format the scalar with 4 decimal places
                val_str = f"{val:.4f}"
            rows.append((key, val_str))

        table = tabulate(
            tabular_data=rows,
            headers=["Metric", "Value"],
            floatfmt=".4f",
            tablefmt="fancy_grid"
        )
        print(table, "\n")

        if self._wandb_config.use_wandb:
            # Keep only scalar values
            scalar_results: dict[str, float] = {}
            for key, val in metrics.items():
                if isinstance(val, np.ndarray):
                    continue
                # Ensure float type
                scalar_results[key] = float(val)
            wandb.log(scalar_results, step=step)
    def __save_metrics_to_csv(
        self,
        metrics: dict[str, float | np.ndarray],
        output_path: str
    ) -> None:
        """
        Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'.

        Args:
            metrics (dict(str, float | np.ndarray)):
                Mapping from metric names to scalar values or numpy arrays.
            output_path (str):
                Path to the output CSV file.
        """
        with open(output_path, mode='w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['Metric', 'Value'])
            for name, value in metrics.items():
                # Convert numpy arrays to a comma-separated string representation
                if isinstance(value, np.ndarray):
                    flat = value.flatten()
                    val_str = ','.join(f"{v}" for v in flat)
                else:
                    val_str = f"{value}"
                writer.writerow([name, val_str])
    def __run_epoch(
        self,
        mode: Literal["train", "valid", "test"],
        epoch: int | None = None,
        save_results: bool = True,
        only_masks: bool = False
    ) -> dict[str, float | np.ndarray]:
        """
        Execute one epoch of training, validation, or testing.

        Args:
            mode (str): One of 'train', 'valid', or 'test'.
            epoch (int | None): Current epoch number for logging.
            save_results (bool): If True, the predicted masks and test metrics will be saved.
            only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
                without visualization overlays.

        Returns:
            dict(str, float | np.ndarray): Metrics for valid/test.
        """
        # Ensure required components are available
        if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
            raise RuntimeError("Optimizer and loss function must be initialized for train/valid.")

        # Set the model mode and choose the appropriate data loader
        if mode == "train":
            self._model.train()
            loader = self._train_dataloader
        else:
            self._model.eval()
            loader = self._valid_dataloader if mode == "valid" else self._test_dataloader

        # Check that the data loader is provided
        if loader is None:
            raise RuntimeError(f"DataLoader for mode '{mode}' is not set.")

        all_tp, all_fp, all_fn = [], [], []

        # Prepare the tqdm description with epoch info if available
        if epoch is not None:
            desc = f"Epoch {epoch}/{self._dataset_setup.training.num_epochs} [{mode}]"
        else:
            desc = f"Epoch ({mode})"

        # Iterate over batches
        batch_counter = 0
        for batch in tqdm(loader, desc=desc):
            inputs = batch["image"].to(self._device)
            targets = batch["mask"].to(self._device)

            # Zero gradients for training
            if self._optimizer is not None:
                self._optimizer.zero_grad()

            # Mixed precision context if enabled
            with torch.amp.autocast(  # type: ignore
                self._device.type,
                enabled=self._dataset_setup.common.use_amp
            ):
                # Only compute gradients in the training phase
                with torch.set_grad_enabled(mode == "train"):
                    # Forward pass
                    raw_output = self.__run_inference(inputs, mode)

                    if self._criterion is not None:
                        # Convert label masks to flow-field training targets
                        flow_targets = self.__compute_flows_from_masks(targets)

                        # Compute the loss for this batch
                        batch_loss = self._criterion(
                            raw_output,
                            torch.from_numpy(flow_targets).to(device=raw_output.device)
                        )

                    # Post-process and compute F1 during validation and testing
                    if mode in ("valid", "test"):
                        preds, labels_post = self.__post_process_predictions(
                            raw_output, targets
                        )

                        # Collect statistics on the batch
                        tp, fp, fn, tp_masks, fp_masks, fn_masks = self.__compute_stats(
                            predicted_masks=preds,
                            ground_truth_masks=labels_post,  # type: ignore
                            iou_threshold=0.5,
                            return_error_masks=(mode == "test") and save_results and not only_masks
                        )
                        all_tp.append(tp)
                        all_fp.append(fp)
                        all_fn.append(fn)

                        if mode == "test" and save_results:
                            masks = (tp_masks, fp_masks, fn_masks) if not only_masks else None
                            self.__save_prediction_masks(
                                sample=batch,
                                predicted_mask=preds,
                                start_index=batch_counter,
                                only_masks=only_masks,
                                masks=masks  # type: ignore
                            )

            # Backpropagation and optimizer step in training
            if mode == "train":
                if self._dataset_setup.common.use_amp and self._scaler is not None:
                    self._scaler.scale(batch_loss).backward()  # type: ignore
                    self._scaler.unscale_(self._optimizer.optim)  # type: ignore
                    self._scaler.step(self._optimizer.optim)  # type: ignore
                    self._scaler.update()
                else:
                    batch_loss.backward()  # type: ignore
                    if self._optimizer is not None:
                        self._optimizer.step()

            batch_counter += inputs.shape[0]

        if self._criterion is not None:
            # Collect loss metrics
            epoch_metrics: dict[str, float | np.ndarray] = {
                f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()
            }
            # Reset the internal loss metrics accumulator
            self._criterion.reset_metrics()
        else:
            epoch_metrics = {}

        # Include F1 and mAP for validation and testing
        if mode in ("valid", "test"):
            # Concatenate across batches: shape (num_batches * B, C)
            tp_array = np.vstack(all_tp)
            fp_array = np.vstack(all_fp)
            fn_array = np.vstack(all_fn)

            epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(
                tp_array, fp_array, fn_array, reduction="micro"
            )
            epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric(
                tp_array, fp_array, fn_array, reduction="imagewise"
            )
            epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(
                tp_array, fp_array, fn_array, reduction="macro"
            )

        return epoch_metrics
    def __run_inference(
        self,
        inputs: torch.Tensor,
        mode: Literal["train", "valid", "test", "predict"] = "train"
    ) -> torch.Tensor:
        """
        Perform model inference for different stages.

        Args:
            inputs (torch.Tensor): Input tensor of shape (B, C, H, W).
            mode (Literal[...]): One of "train", "valid", "test", "predict".

        Returns:
            torch.Tensor: Model outputs tensor.
        """
        if mode != "train":
            # Use sliding window inference for non-training phases
            outputs = sliding_window_inference(
                inputs,
                roi_size=self._dataset_setup.common.roi_size,
                sw_batch_size=4,
                predictor=self._model,
                padding_mode="constant",
                mode="gaussian",
                overlap=0.5,
            )
        else:
            # Direct forward pass during training
            outputs = self._model(inputs)
        return outputs  # type: ignore
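    # Sliding-window behaviour in a nutshell: the input is tiled into ROI-sized
    # patches with the given overlap, each patch goes through the model, and the
    # patch outputs are blended back (Gaussian-weighted here). A toy sketch with
    # an identity "model" (shapes are illustrative assumptions):
    #
    #     from monai.inferers.utils import sliding_window_inference
    #     x = torch.rand(1, 1, 600, 600)                  # larger than the ROI
    #     y = sliding_window_inference(
    #         x, roi_size=(256, 256), sw_batch_size=4,
    #         predictor=lambda t: t, overlap=0.5, mode="gaussian",
    #     )
    #     assert y.shape == x.shape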
    def __post_process_predictions(
        self,
        raw_outputs: torch.Tensor,
        ground_truth: torch.Tensor | None = None
    ) -> tuple[np.ndarray, np.ndarray | None]:
        """
        Post-process raw network outputs to extract instance segmentation masks.

        Args:
            raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W).
            ground_truth (torch.Tensor | None): Ground truth masks of shape (B, C, H, W).

        Returns:
            tuple(np.ndarray, np.ndarray | None):
                - instance_masks: Instance-wise masks array of shape (B, C, H, W).
                - labels_np: Converted ground truth of shape (B, C, H, W) or None if
                  ground_truth was not provided.
        """
        # Move outputs to CPU and convert to numpy
        outputs_np = raw_outputs.cpu().numpy()
        # Split channels: gradient flows first, then class logits
        gradflow = outputs_np[:, :2 * self._model.num_classes]
        logits = outputs_np[:, -self._model.num_classes:]
        # Apply sigmoid to the logits to get probabilities
        probabilities = self.__sigmoid(logits)

        batch_size, _, height, width = probabilities.shape
        # Prepare a container for the instance masks
        instance_masks = np.zeros((batch_size, self._model.num_classes, height, width), dtype=np.uint16)
        for idx in range(batch_size):
            instance_masks[idx] = self.__segment_instances(
                probability_map=probabilities[idx],
                flow=gradflow[idx],
                prob_threshold=0.5,
                flow_threshold=0.4,
                min_object_size=15
            )

        # Convert the ground truth to numpy
        labels_np = ground_truth.cpu().numpy() if ground_truth is not None else None
        return instance_masks, labels_np
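    # Channel layout assumed by the slicing above: for C classes the network
    # emits 3C channels, the first 2C being per-class flow components and the
    # last C being class logits. With C = 1 and raw output of shape (B, 3, H, W):
    #
    #     out[:, 0:2]  # flow field for the single class
    #     out[:, 2:3]  # logits; sigmoid gives the foreground probability map
    #
    # (The exact x/y ordering of the flow components follows the upstream
    # Cellpose convention and is not asserted here.)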
    def __compute_stats(
        self,
        predicted_masks: np.ndarray,
        ground_truth_masks: np.ndarray,
        iou_threshold: float = 0.5,
        return_error_masks: bool = False
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray,
               np.ndarray | None, np.ndarray | None, np.ndarray | None]:
        """
        Compute batch-wise true positives, false positives, and false negatives
        for instance segmentation, using a configurable IoU threshold.

        Args:
            predicted_masks (np.ndarray): Predicted instance masks of shape (B, C, H, W).
            ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W).
            iou_threshold (float): Intersection-over-Union threshold for matching predictions
                to ground truths (default: 0.5).
            return_error_masks (bool): Whether to also return binary error masks.

        Returns:
            tuple(np.ndarray, np.ndarray, np.ndarray,
                  np.ndarray | None, np.ndarray | None, np.ndarray | None):
                - tp: True positives per batch and class, shape (B, C)
                - fp: False positives per batch and class, shape (B, C)
                - fn: False negatives per batch and class, shape (B, C)
                - tp_masks: True-positive mask per batch and class, shape (B, C, H, W)
                - fp_masks: False-positive mask per batch and class, shape (B, C, H, W)
                - fn_masks: False-negative mask per batch and class, shape (B, C, H, W)
        """
        stats = compute_batch_segmentation_tp_fp_fn(
            batch_ground_truth=ground_truth_masks,
            batch_prediction=predicted_masks,
            iou_threshold=iou_threshold,
            return_error_masks=return_error_masks,
            remove_boundary_objects=self._dataset_setup.common.remove_boundary_objects
        )
        tp = stats["tp"]
        fp = stats["fp"]
        fn = stats["fn"]

        tp_mask = stats["tp_mask"] if return_error_masks else None
        fp_mask = stats["fp_mask"] if return_error_masks else None
        fn_mask = stats["fn_mask"] if return_error_masks else None
        return tp, fp, fn, tp_mask, fp_mask, fn_mask
    def __compute_f1_metric(
        self,
        true_positives: np.ndarray,
        false_positives: np.ndarray,
        false_negatives: np.ndarray,
        reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
    ) -> float | np.ndarray:
        """
        Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes.

        Args:
            true_positives: array of TP counts per sample and class.
            false_positives: array of FP counts per sample and class.
            false_negatives: array of FN counts per sample and class.
            reduction:
                - 'none': return F1 for each sample and class -> shape (batch_size, num_classes)
                - 'micro': global F1 over all samples and classes
                - 'macro': average class-wise F1 (classes summed over batch)
                - 'imagewise': F1 per sample (summing over classes), then average over samples
                - 'per_class': F1 per class (summing over batch), return vector of shape (num_classes,)
                - 'weighted': class-wise F1 weighted by support (TP + FN)

        Returns:
            float for reductions 'micro', 'imagewise', 'macro', 'weighted';
            np.ndarray of shape (batch_size, num_classes) for 'none';
            np.ndarray of shape (num_classes,) for 'per_class'.
        """
        batch_size, num_classes = true_positives.shape

        # 1) No reduction: compute F1 for each (sample, class)
        if reduction == "none":
            f1_matrix = np.zeros((batch_size, num_classes), dtype=float)
            for i in range(batch_size):
                for c in range(num_classes):
                    tp_val = int(true_positives[i, c])
                    fp_val = int(false_positives[i, c])
                    fn_val = int(false_negatives[i, c])
                    _, _, f1_val = compute_f1_score(tp_val, fp_val, fn_val)
                    f1_matrix[i, c] = f1_val
            return f1_matrix

        # 2) Micro: sum all TP/FP/FN and compute a single F1
        if reduction == "micro":
            tp_total = int(true_positives.sum())
            fp_total = int(false_positives.sum())
            fn_total = int(false_negatives.sum())
            _, _, f1_global = compute_f1_score(tp_total, fp_total, fn_total)
            return f1_global

        # 3) Imagewise: compute per-sample F1 (summing over classes), then average
        if reduction == "imagewise":
            f1_per_image = np.zeros(batch_size, dtype=float)
            for i in range(batch_size):
                tp_i = int(true_positives[i].sum())
                fp_i = int(false_positives[i].sum())
                fn_i = int(false_negatives[i].sum())
                _, _, f1_i = compute_f1_score(tp_i, fp_i, fn_i)
                f1_per_image[i] = f1_i
            return float(f1_per_image.mean())

        # Aggregate per class across the batch for per_class, macro, weighted
        tp_per_class = true_positives.sum(axis=0).astype(int)  # shape (num_classes,)
        fp_per_class = false_positives.sum(axis=0).astype(int)
        fn_per_class = false_negatives.sum(axis=0).astype(int)

        # 4) Per-class: compute F1 for each class and return the vector
        if reduction == "per_class":
            f1_per_class = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                _, _, f1_per_class[c] = compute_f1_score(
                    tp_per_class[c], fp_per_class[c], fn_per_class[c]
                )
            return f1_per_class

        # 5) Macro: average F1 across classes equally
        if reduction == "macro":
            f1_per_class = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                _, _, f1_c = compute_f1_score(
                    tp_per_class[c], fp_per_class[c], fn_per_class[c]
                )
                f1_per_class[c] = f1_c
            return float(f1_per_class.mean())

        # 6) Weighted: class-wise F1 weighted by support = TP + FN
        if reduction == "weighted":
            f1_per_class = np.zeros(num_classes, dtype=float)
            support = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                tp_c = tp_per_class[c]
                fp_c = fp_per_class[c]
                fn_c = fn_per_class[c]
                _, _, f1_c = compute_f1_score(tp_c, fp_c, fn_c)
                f1_per_class[c] = f1_c
                support[c] = tp_c + fn_c
            total_support = support.sum()
            if total_support == 0:
                # Fall back to unweighted macro if there are no positives
                return float(f1_per_class.mean())
            weights = support / total_support
            return float((f1_per_class * weights).sum())

        raise ValueError(f"Unknown reduction mode: {reduction}")
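    # Worked example for the reductions above (illustrative counts). Two samples,
    # one class: tp = [[8], [0]], fp = [[2], [1]], fn = [[0], [3]].
    #   micro:     F1 over summed counts = 2*8 / (2*8 + 3 + 3) ~= 0.727
    #   imagewise: sample 1: 2*8 / (2*8 + 2 + 0) ~= 0.889; sample 2: 0.0
    #              mean ~= 0.444
    # Micro weights every instance equally; imagewise weights every image equally,
    # so a single empty prediction drags the imagewise score down much harder.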
    def __compute_average_precision_metric(
        self,
        true_positives: np.ndarray,
        false_positives: np.ndarray,
        false_negatives: np.ndarray,
        reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
    ) -> float | np.ndarray:
        """
        Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes.

        AP is defined here as:

            AP = TP / (TP + FP + FN)

        Args:
            true_positives: array of true positives per sample and class.
            false_positives: array of false positives per sample and class.
            false_negatives: array of false negatives per sample and class.
            reduction:
                - 'none': return AP for each sample and class -> shape (batch_size, num_classes)
                - 'micro': global AP over all samples and classes
                - 'macro': average class-wise AP (each class summed over batch)
                - 'imagewise': AP per sample (summing stats over classes), then average over batch
                - 'per_class': AP per class (summing over batch), return vector of shape (num_classes,)
                - 'weighted': class-wise AP weighted by support (TP + FN)

        Returns:
            float for reductions 'micro', 'imagewise', 'macro', 'weighted';
            np.ndarray of shape (batch_size, num_classes) for 'none';
            np.ndarray of shape (num_classes,) for 'per_class'.
        """
        batch_size, num_classes = true_positives.shape

        # 1) No reduction: AP per (sample, class)
        if reduction == "none":
            ap_matrix = np.zeros((batch_size, num_classes), dtype=float)
            for i in range(batch_size):
                for c in range(num_classes):
                    tp_val = int(true_positives[i, c])
                    fp_val = int(false_positives[i, c])
                    fn_val = int(false_negatives[i, c])
                    ap_matrix[i, c] = compute_average_precision_score(tp_val, fp_val, fn_val)
            return ap_matrix

        # 2) Micro: sum all TP/FP/FN and compute one AP
        if reduction == "micro":
            tp_total = int(true_positives.sum())
            fp_total = int(false_positives.sum())
            fn_total = int(false_negatives.sum())
            return compute_average_precision_score(tp_total, fp_total, fn_total)

        # 3) Imagewise: compute per-sample AP (summing over classes), then mean
        if reduction == "imagewise":
            ap_per_image = np.zeros(batch_size, dtype=float)
            for i in range(batch_size):
                tp_i = int(true_positives[i].sum())
                fp_i = int(false_positives[i].sum())
                fn_i = int(false_negatives[i].sum())
                ap_per_image[i] = compute_average_precision_score(tp_i, fp_i, fn_i)
            return float(ap_per_image.mean())

        # For per_class, macro, and weighted: first aggregate per class across the batch
        tp_per_class = true_positives.sum(axis=0).astype(int)  # shape (num_classes,)
        fp_per_class = false_positives.sum(axis=0).astype(int)
        fn_per_class = false_negatives.sum(axis=0).astype(int)

        # 4) Per-class: compute AP for each class and return the vector
        if reduction == "per_class":
            ap_per_class = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                ap_per_class[c] = compute_average_precision_score(
                    tp_per_class[c], fp_per_class[c], fn_per_class[c]
                )
            return ap_per_class

        # 5) Macro: average AP across classes equally
        if reduction == "macro":
            ap_per_class = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                ap_per_class[c] = compute_average_precision_score(
                    tp_per_class[c], fp_per_class[c], fn_per_class[c]
                )
            return float(ap_per_class.mean())

        # 6) Weighted: class-wise AP weighted by support = TP + FN
        if reduction == "weighted":
            ap_per_class = np.zeros(num_classes, dtype=float)
            support = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                tp_c = tp_per_class[c]
                fp_c = fp_per_class[c]
                fn_c = fn_per_class[c]
                ap_per_class[c] = compute_average_precision_score(tp_c, fp_c, fn_c)
                support[c] = tp_c + fn_c
            total_support = support.sum()
            if total_support == 0:
                # Fall back to unweighted macro if there are no positive instances
                return float(ap_per_class.mean())
            weights = support / total_support
            return float((ap_per_class * weights).sum())

        raise ValueError(f"Unknown reduction mode: {reduction}")
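    # Worked example for the AP definition used above, AP = TP / (TP + FP + FN)
    # (illustrative counts): tp=8, fp=2, fn=0 gives AP = 8/10 = 0.8, while
    # tp=0, fp=1, fn=3 gives AP = 0.0. Note this is the matching-based score used
    # in the Cellpose metrics this file adapts, not the area under a PR curve.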
    @staticmethod
    def __sigmoid(z: np.ndarray) -> np.ndarray:
        """
        Numerically stable sigmoid activation for numpy arrays
        (delegates to scipy.special.expit).

        Args:
            z (np.ndarray): Input array.

        Returns:
            np.ndarray: Sigmoid of the input.
        """
        return expit(z)
    def __save_prediction_masks(
        self,
        sample: dict[str, Any],
        predicted_mask: np.ndarray | torch.Tensor,
        start_index: int = 0,
        only_masks: bool = False,
        masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
    ) -> None:
        """
        Save multi-channel predicted masks as TIFFs and
        corresponding visualizations as PNGs in separate folders.

        Args:
            sample (dict(str, Any)): Batch sample from MONAI
                LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
            predicted_mask (np.ndarray | torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
            start_index (int): Starting index for naming when metadata is missing.
            only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations.
            masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None):
                A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None.
        """
        # Base directories (created once per call)
        base_output_dir = self._dataset_setup.common.predictions_dir
        masks_dir = base_output_dir
        plots_dir = os.path.join(base_output_dir, "plots")
        evaluate_dir = os.path.join(plots_dir, "evaluate")
        os.makedirs(masks_dir, exist_ok=True)
        os.makedirs(plots_dir, exist_ok=True)
        os.makedirs(evaluate_dir, exist_ok=True)

        # Convert tensors to numpy if necessary
        def to_numpy(x: np.ndarray | torch.Tensor) -> np.ndarray:
            return x.cpu().numpy() if isinstance(x, torch.Tensor) else x

        pred_array = to_numpy(predicted_mask).astype(np.uint16)

        # Handle the batch dimension
        for idx in range(pred_array.shape[0]):
            batch_sample: dict[str, Any] = {}
            # Copy the per-sample image and metadata
            img = to_numpy(sample["image"])
            if img.ndim == 4:
                batch_sample["image"] = img[idx]
            if "mask" in sample:
                msk = to_numpy(sample["mask"]).astype(np.uint16)
                if msk.ndim == 4:
                    batch_sample["mask"] = msk[idx]

            image_meta = sample.get("image_meta_dict")
            if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
                fname = image_meta["filename_or_obj"][idx]
                batch_sample["image_name"] = fname

            single_masks = (
                (masks[0][idx], masks[1][idx], masks[2][idx]) if masks is not None else None
            )
            self.__save_single_prediction_mask(
                sample=batch_sample,
                pred_array=pred_array[idx],
                start_index=start_index + idx,
                masks_dir=masks_dir,
                plots_dir=plots_dir,
                evaluate_dir=evaluate_dir,
                only_masks=only_masks,
                masks=single_masks,
            )
    def __save_single_prediction_mask(
        self,
        sample: dict[str, Any],
        pred_array: np.ndarray,
        start_index: int,
        masks_dir: str,
        plots_dir: str,
        evaluate_dir: str,
        only_masks: bool = False,
        masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
    ) -> None:
        """
        Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations.
        Assumes output directories already exist.

        Args:
            sample (dict(str, Any)): Dictionary containing 'image', 'mask',
                and optional 'image_meta_dict' for metadata.
            pred_array (np.ndarray): Predicted mask array of shape (C, H, W).
            start_index (int): Base index for generating filenames when metadata is missing.
            masks_dir (str): Directory for saving TIFF masks.
            plots_dir (str): Directory for saving PNG visualizations.
            evaluate_dir (str): Directory for saving PNG visualizations of evaluation results.
            only_masks (bool): If True, saves only TIFF mask files; skips PNG plots.
            masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None): A tuple of
                true-positive, false-positive, and false-negative mask arrays,
                each of shape (C, H, W). Defaults to None.
        """
        if pred_array.ndim == 2:
            pred_array = np.expand_dims(pred_array, axis=0)
        elif pred_array.ndim != 3:
            raise ValueError(
                f"Unsupported predicted_mask dimensions: {pred_array.ndim}. "
                "Expected 2D (H, W) or 3D (C, H, W)."
            )

        # Handle the image array if present
        image_array: np.ndarray = sample["image"]
        if image_array.ndim == 2:
            image_array = np.expand_dims(image_array, axis=0)
        elif image_array.ndim != 3:
            raise ValueError(
                f"Unsupported image dimensions: {image_array.ndim}. "
                "Expected 2D (H, W) or 3D (C, H, W)."
            )

        true_mask_array: np.ndarray | None = sample.get("mask")
        if isinstance(true_mask_array, np.ndarray):
            if true_mask_array.ndim == 2:
                true_mask_array = np.expand_dims(true_mask_array, axis=0)
            elif true_mask_array.ndim != 3:
                raise ValueError(
                    f"Unsupported true_mask_array dimensions: {true_mask_array.ndim}. "
                    "Expected 2D (H, W) or 3D (C, H, W)."
                )

        # Determine the filename base
        image_meta = sample.get("image_name")
        if isinstance(image_meta, (str, os.PathLike)):
            base_name = os.path.splitext(os.path.basename(image_meta))[0]
        else:
            base_name = f"prediction_{start_index:04d}"

        # Save the main mask TIFF
        mask_path = os.path.join(masks_dir, f"{base_name}_mask.tif")
        tiff.imwrite(mask_path, pred_array.astype(np.uint16), compression="zlib")

        if only_masks:
            return

        # Save channel-wise plots
        num_channels = pred_array.shape[0]
        for ch in range(num_channels):
            true_ch = true_mask_array[ch] if true_mask_array is not None else None

            self.__plot_mask(
                file_path=os.path.join(plots_dir, f"{base_name}_ch{ch}.png"),
                image_data=image_array,
                predicted_mask=pred_array[ch],
                true_mask=true_ch,
            )

            if masks is not None and true_ch is not None:
                self.__save_mask_comparison_visuals(
                    gt=true_ch,
                    pred=pred_array[ch],
                    tp_mask=masks[0][ch],
                    fp_mask=masks[1][ch],
                    fn_mask=masks[2][ch],
                    file_path=os.path.join(evaluate_dir, f"{base_name}_ch{ch}.png")
                )
|
|
def __plot_mask(
|
|
|
self,
|
|
|
file_path: str,
|
|
|
image_data: np.ndarray,
|
|
|
predicted_mask: np.ndarray,
|
|
|
true_mask: np.ndarray | None = None,
|
|
|
) -> None:
|
|
|
"""
|
|
|
Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided.
|
|
|
|
|
|
Args:
|
|
|
file_path (str): Path where the visualization image will be saved.
|
|
|
image_data (np.ndarray): The original input image array, expected shape (C, H, W).
|
|
|
predicted_mask (np.ndarray): The predicted mask array of shape (H, W).
|
|
|
true_mask (np.ndarray | None): The ground-truth mask array.
|
|
|
If provided, an additional row with true mask and overlap visualization
|
|
|
will be added to the plot. Default is None.
|
|
|
|
|
|
"""
|
|
|
img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data
|
|
|
# Collapse a trailing singleton channel so imshow treats it as grayscale
|
|
|
img = img[..., 0] if (img.ndim == 3 and img.shape[-1] == 1) else img
|
|
|
|
|
|
if true_mask is None:
|
|
|
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
|
|
|
plt.subplots_adjust(wspace=0.02, hspace=0)
|
|
|
self.__plot_panels(axs, img, predicted_mask, 'red',
|
|
|
('Original Image','Predicted Mask','Predicted Contours'))
|
|
|
else:
|
|
|
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
|
|
|
plt.subplots_adjust(wspace=0.02, hspace=0.02)
|
|
|
# row 0: predicted
|
|
|
self.__plot_panels(axs[0], img, predicted_mask, 'red',
|
|
|
('Original Image','Predicted Mask','Predicted Contours'))
|
|
|
# row 1: true
|
|
|
self.__plot_panels(axs[1], img, true_mask, 'blue',
|
|
|
('Original Image','True Mask','True Contours'))
|
|
|
fig.savefig(file_path, bbox_inches='tight', dpi=300)
|
|
|
plt.close(fig)
|
|
|
|
|
|
|
|
|
def __plot_panels(
|
|
|
self,
|
|
|
axes,
|
|
|
img: np.ndarray,
|
|
|
mask: np.ndarray,
|
|
|
contour_color: str,
|
|
|
titles: tuple[str, ...]
|
|
|
) -> None:
|
|
|
"""
|
|
|
Plot a row of three panels: original image, mask, and mask boundaries on image.
|
|
|
|
|
|
Args:
|
|
|
axes: list/array of three Axis objects.
|
|
|
img: Image array (H, W or H, W, C).
|
|
|
mask: Label mask (H, W).
|
|
|
contour_color: Color for boundaries.
|
|
|
titles: (title_img, title_mask, title_contours).
|
|
|
"""
|
|
|
# Panel 1: Original image
|
|
|
ax0, ax1, ax2 = axes
|
|
|
ax0.imshow(img, cmap='gray' if img.ndim == 2 else None)
|
|
|
ax0.set_title(titles[0])
|
|
|
ax0.axis('off')
|
|
|
|
|
|
# Compute boundaries once
|
|
|
boundaries = find_boundaries(mask, mode='thick')
|
|
|
|
|
|
# Panel 2: Mask with black boundaries
|
|
|
cmap = plt.get_cmap("gist_ncar")
|
|
|
cmap = mcolors.ListedColormap([
|
|
|
cmap(i/len(np.unique(mask))) for i in range(len(np.unique(mask)))
|
|
|
])
|
|
|
ax1.imshow(mask, cmap=cmap)
|
|
|
ax1.contour(boundaries, colors='black', linewidths=0.5)
|
|
|
ax1.set_title(titles[1])
|
|
|
ax1.axis('off')
|
|
|
|
|
|
# Panel 3: Original image with contour overlay in the given color
|
|
|
ax2.imshow(img, cmap='gray' if img.ndim == 2 else None)
|
|
|
# Draw boundaries as contour lines
|
|
|
ax2.contour(boundaries, colors=contour_color, linewidths=0.5)
|
|
|
ax2.set_title(titles[2])
|
|
|
ax2.axis('off')
|
|
|
|
|
|
|
|
|
def __save_mask_comparison_visuals(
|
|
|
self,
|
|
|
gt: np.ndarray,
|
|
|
pred: np.ndarray,
|
|
|
tp_mask: np.ndarray,
|
|
|
fp_mask: np.ndarray,
|
|
|
fn_mask: np.ndarray,
|
|
|
file_path: str
|
|
|
) -> None:
|
|
|
"""
|
|
|
Creates and saves a 1x3 subplot figure showing:
|
|
|
1) True mask with boundaries
|
|
|
2) Predicted mask without boundaries
|
|
|
3) Overlay mask combining FP (R), TP (G), FN (B)
|
|
|
|
|
|
Args:
|
|
|
gt (np.ndarray): Ground truth mask (H, W).
|
|
|
pred (np.ndarray): Predicted mask (H, W).
|
|
|
tp_mask (np.ndarray): True positive mask (H, W).
|
|
|
fp_mask (np.ndarray): False positive mask (H, W).
|
|
|
fn_mask (np.ndarray): False negative mask (H, W).
|
|
|
file_path (str): Path where the visualization image will be saved.
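|
Example (conceptual): a pixel belonging only to the TP mask renders pure green
(0, 255, 0), an FP-only pixel renders pure red, an FN-only pixel pure blue,
and unmatched background stays black.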
|
|
|
"""
|
|
|
# Prepare overlay mask
|
|
|
overlap_mask = np.zeros((*gt.shape[:2], 3), dtype=np.uint8)
|
|
|
overlap_mask[..., 0] = np.where(fp_mask, 255, 0)
|
|
|
overlap_mask[..., 1] = np.where(tp_mask, 255, 0)
|
|
|
overlap_mask[..., 2] = np.where(fn_mask, 255, 0)
|
|
|
|
|
|
# Set up figure
|
|
|
fig, axes = plt.subplots(1, 3, figsize=(15, 5),
|
|
|
gridspec_kw={'width_ratios': [1, 1, 1]})
|
|
|
plt.subplots_adjust(wspace=0.02, hspace=0.0,
|
|
|
left=0.05, right=0.95, top=0.95, bottom=0.05)
|
|
|
|
|
|
# Colormap for instances (floor of 1 avoids an empty colormap for blank masks)
|
|
|
num_instances = max(int(np.max(gt)), int(np.max(pred)), 1)
|
|
|
cmap = plt.get_cmap("gist_ncar")
|
|
|
colors = [cmap(i / num_instances) for i in range(num_instances)]
|
|
|
cmap = mcolors.ListedColormap(colors)
|
|
|
|
|
|
# Plot true mask
|
|
|
axes[0].imshow(gt, cmap=cmap)
|
|
|
axes[0].contour(find_boundaries(gt, mode="thick"), colors="black", linewidths=0.5)
|
|
|
axes[0].set_title("True Mask")
|
|
|
axes[0].axis("off")
|
|
|
|
|
|
# Plot predicted mask
|
|
|
axes[1].imshow(pred, cmap=cmap)
|
|
|
axes[1].contour(find_boundaries(pred, mode="thick"), colors="black", linewidths=0.5)
|
|
|
axes[1].set_title("Predicted Mask")
|
|
|
axes[1].axis("off")
|
|
|
|
|
|
# Plot overlay
|
|
|
axes[2].imshow(overlap_mask)
|
|
|
axes[2].set_title("Overlay Mask (R-FP; G-TP; B-FN)")
|
|
|
axes[2].axis("off")
|
|
|
|
|
|
# Save
|
|
|
plt.savefig(file_path, bbox_inches="tight", dpi=300)
|
|
|
plt.close()
|
|
|
|
|
|
|
|
|
def __compute_flows_from_masks(
|
|
|
self,
|
|
|
true_masks: Tensor
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Convert segmentation masks to flow fields for training.
|
|
|
|
|
|
Args:
|
|
|
true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks.
|
|
|
|
|
|
Returns:
|
|
|
numpy array of concatenated [flow_vectors, renumbered_true_masks] per image.
|
|
|
renumbered_true_masks holds the relabeled instance masks; flow_vectors[0] is the Y flow, flow_vectors[1] is the X flow.
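|
Example (illustrative sketch, shapes only; assumes single-channel masks):
    >>> labels = torch.randint(0, 5, (2, 1, 64, 64), dtype=torch.int16)
    >>> flows = self.__compute_flows_from_masks(labels)
    >>> flows.shape   # (batch, 2*C + C, H, W)
    (2, 3, 64, 64)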
|
|
|
"""
|
|
|
# Move to CPU numpy
|
|
|
_true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16)
|
|
|
|
|
|
|
|
|
# Ensure each label has a channel dimension
|
|
|
if _true_masks.ndim == 3:
|
|
|
# shape (batch, H, W) -> (batch, 1, H, W)
|
|
|
_true_masks = _true_masks[:, np.newaxis, :, :]
|
|
|
|
|
|
batch_size, *_ = _true_masks.shape
|
|
|
|
|
|
# Renumber labels to ensure uniqueness
|
|
|
renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0]
|
|
|
for i in range(batch_size)])
|
|
|
# Compute vector flows per image
|
|
|
flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i])
|
|
|
for i in range(batch_size)])
|
|
|
|
|
|
return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32)
|
|
|
|
|
|
|
|
|
def __compute_flow_from_mask(
|
|
|
self,
|
|
|
mask: np.ndarray
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Compute normalized flow vectors from a labeled mask.
|
|
|
|
|
|
Args:
|
|
|
mask: 3D array of instance-labeled mask of shape (C, H, W).
|
|
|
|
|
|
Returns:
|
|
|
flow: Array of shape (2 * C, H, W).
|
|
|
"""
|
|
|
if mask.max() == 0 or np.count_nonzero(mask) <= 1:
|
|
|
# No flow to compute
|
|
|
logger.warning("Empty mask!")
|
|
|
C, H, W = mask.shape
|
|
|
return np.zeros((2*C, H, W), dtype=np.float32)
|
|
|
|
|
|
# Delegate to GPU or CPU routine
|
|
|
if self._device.type == "cuda" or self._device.type == "mps":
|
|
|
return self.__mask_to_flow_gpu(mask)
|
|
|
else:
|
|
|
return self.__mask_to_flow_cpu(mask)
|
|
|
|
|
|
|
|
|
def __mask_to_flow_gpu(self, mask: np.ndarray) -> np.ndarray:
|
|
|
"""Convert masks to flows using diffusion from center pixel.
|
|
|
|
|
|
Center of masks where diffusion starts is defined by pixel closest to median within the mask.
|
|
|
|
|
|
Args:
|
|
|
mask (np.ndarray): Labeled masks of shape (C, H, W).
|
|
|
|
|
|
Returns:
|
|
|
np.ndarray: A 3D array where for each channel the flows for each pixel
|
|
|
are represented along the X and Y axes.
|
|
|
"""
|
|
|
|
|
|
channels, height, width = mask.shape
|
|
|
flows = np.zeros((2*channels, height, width), np.float32)
|
|
|
|
|
|
for channel in range(channels):
|
|
|
mask_channel = mask[channel]
|
|
|
|
|
|
if mask_channel.max() > 0:
|
|
|
padded_height, padded_width = height + 2, width + 2
|
|
|
|
|
|
# Pad the mask with a 1-pixel border
|
|
|
masks_padded = torch.from_numpy(mask_channel.astype(np.int64)).to(self._device)
|
|
|
masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
|
|
|
|
|
|
# Get coordinates of all non-zero pixels in the padded mask
|
|
|
y, x = torch.nonzero(masks_padded, as_tuple=True)
|
|
|
y = y.int()
|
|
|
x = x.int() # ensure integer type
|
|
|
|
|
|
# Generate 8-connected neighbors (including center) via broadcasted offsets
|
|
|
offsets = torch.tensor([
|
|
|
[ 0, 0], # center
|
|
|
[-1, 0], # up
|
|
|
[ 1, 0], # down
|
|
|
[ 0, -1], # left
|
|
|
[ 0, 1], # right
|
|
|
[-1, -1], # up-left
|
|
|
[-1, 1], # up-right
|
|
|
[ 1, -1], # down-left
|
|
|
[ 1, 1], # down-right
|
|
|
], dtype=torch.int32, device=self._device) # (9, 2)
|
|
|
|
|
|
# coords: (N, 2)
|
|
|
coords = torch.stack((y, x), dim=1)
|
|
|
|
|
|
# neighbors: (9, N, 2)
|
|
|
neighbors = offsets[:, None, :] + coords[None, :, :]
|
|
|
|
|
|
# transpose into (2, 9, N) for the GPU kernel
|
|
|
neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index
|
|
|
|
|
|
# Build connectivity mask: True where neighbor label == center label
|
|
|
center_labels = masks_padded[y, x][None, :] # (1, N)
|
|
|
neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N)
|
|
|
is_neighbor = neighbor_labels == center_labels # (9, N)
|
|
|
|
|
|
# Compute object slices and pack into array for get_centers
|
|
|
slices = find_objects(mask_channel)
|
|
|
slices_arr = np.array([
|
|
|
[i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
|
|
|
for i, sl in enumerate(slices, start=1) if sl is not None
|
|
|
], dtype=np.int16)
|
|
|
|
|
|
# Compute centers (pixel indices) and extents via the provided helper
|
|
|
centers, ext = self.__get_mask_centers_and_extents(
|
|
|
mask_channel, slices_arr
|
|
|
)
|
|
|
# Move centers to GPU and shift by +1 for padding
|
|
|
# (M, 2); +1 for padding
|
|
|
meds_p = torch.from_numpy(centers).to(self._device).long() + 1
|
|
|
|
|
|
# Determine number of diffusion iterations
|
|
|
n_iter = 2 * ext.max()
|
|
|
|
|
|
# Run the GPU diffusion kernel
|
|
|
mu = self.__propagate_centers_gpu(
|
|
|
neighbor_indices=neighbors,
|
|
|
center_indices=meds_p.T,
|
|
|
valid_neighbor_mask=is_neighbor,
|
|
|
output_shape=(padded_height, padded_width),
|
|
|
num_iterations=n_iter
|
|
|
)
|
|
|
|
|
|
# Cast to float64 and normalize flow vectors
|
|
|
mu = mu.astype(np.float64)
|
|
|
mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60
|
|
|
|
|
|
# Remove the padding and write into final output
|
|
|
flow_output = np.zeros((2, height, width), dtype=np.float32)
|
|
|
ys_np = y.cpu().numpy() - 1
|
|
|
xs_np = x.cpu().numpy() - 1
|
|
|
flow_output[:, ys_np, xs_np] = mu.reshape(2, -1)
|
|
|
flows[2*channel: 2*channel + 2] = flow_output
|
|
|
|
|
|
return flows
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
@njit(nogil=True, parallel=True)
|
|
|
def __get_mask_centers_and_extents(
|
|
|
label_map: np.ndarray,
|
|
|
slices_arr: np.ndarray
|
|
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
|
"""
|
|
|
Compute the centroids and extents of labeled regions in a 2D mask array.
|
|
|
|
|
|
Args:
|
|
|
label_map (np.ndarray): 2D array where each connected region has a unique integer label (1…K).
|
|
|
slices_arr (np.ndarray): Array of shape (K, 5), where each row is
|
|
|
(label_id, row_start, row_stop, col_start, col_stop).
|
|
|
|
|
|
Returns:
|
|
|
centers (np.ndarray): Integer array of shape (K, 2) with (row, col) center for each label.
|
|
|
extents (np.ndarray): Integer array of shape (K,) giving the sum of height and width + 2 for each region.
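|
Example (worked): a single label-1 rectangle covering rows 2..4 and cols 5..8
gives slices_arr = [[1, 2, 5, 5, 9]]; its centroid is (3.0, 6.5), the region
pixel nearest to it is (3, 6), and the extent is 3 + 4 + 2 = 9.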
|
|
|
"""
|
|
|
num_regions = slices_arr.shape[0]
|
|
|
centers = np.zeros((num_regions, 2), dtype=np.int32)
|
|
|
extents = np.zeros(num_regions, dtype=np.int32)
|
|
|
|
|
|
for idx in prange(num_regions):
|
|
|
# Unpack slice info
|
|
|
label_id = slices_arr[idx, 0]
|
|
|
row_start = slices_arr[idx, 1]
|
|
|
row_stop = slices_arr[idx, 2]
|
|
|
col_start = slices_arr[idx, 3]
|
|
|
col_stop = slices_arr[idx, 4]
|
|
|
|
|
|
# Extract binary submask for this label
|
|
|
submask = (label_map[row_start:row_stop, col_start:col_stop] == label_id)
|
|
|
|
|
|
# Get local coordinates of all pixels in the region
|
|
|
ys, xs = np.nonzero(submask)
|
|
|
|
|
|
# Compute the floating-point centroid within the submask
|
|
|
y_mean = ys.mean()
|
|
|
x_mean = xs.mean()
|
|
|
|
|
|
# Find the pixel closest to the centroid by minimizing squared distance
|
|
|
dist_sq = (ys - y_mean) ** 2 + (xs - x_mean) ** 2
|
|
|
closest_idx = dist_sq.argmin()
|
|
|
|
|
|
# Convert to global coordinates
|
|
|
center_row = ys[closest_idx] + row_start
|
|
|
center_col = xs[closest_idx] + col_start
|
|
|
centers[idx, 0] = center_row
|
|
|
centers[idx, 1] = center_col
|
|
|
|
|
|
# Compute extent as height + width + 2 (to include one-pixel border)
|
|
|
height = row_stop - row_start
|
|
|
width = col_stop - col_start
|
|
|
extents[idx] = height + width + 2
|
|
|
|
|
|
return centers, extents
|
|
|
|
|
|
|
|
|
def __propagate_centers_gpu(
|
|
|
self,
|
|
|
neighbor_indices: torch.Tensor,
|
|
|
center_indices: torch.Tensor,
|
|
|
valid_neighbor_mask: torch.Tensor,
|
|
|
output_shape: tuple[int, int],
|
|
|
num_iterations: int = 200
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Propagates center points across a mask using GPU-based diffusion.
|
|
|
|
|
|
Args:
|
|
|
neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel.
|
|
|
center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers.
|
|
|
valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid.
|
|
|
output_shape (tuple(int, int)): Desired 2D shape of the diffusion tensor, e.g., (H, W).
|
|
|
num_iterations (int, optional): Number of diffusion iterations. Defaults to 200.
|
|
|
|
|
|
Returns:
|
|
|
np.ndarray: Array of shape (2, N) with the computed flows.
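|
Example (conceptual): the diffused density peaks at each center, so
dy = D[down] - D[up] and dx = D[right] - D[left] point toward that center;
after per-pixel normalization these become the unit flow vectors.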
|
|
|
"""
|
|
|
# Determine total number of elements and choose dtype accordingly
|
|
|
total_elems = torch.prod(torch.tensor(output_shape))
|
|
|
if total_elems > 4e7 or self._device.type == "mps":
|
|
|
diffusion_tensor = torch.zeros(output_shape, dtype=torch.float, device=self._device)
|
|
|
else:
|
|
|
diffusion_tensor = torch.zeros(output_shape, dtype=torch.double, device=self._device)
|
|
|
|
|
|
# Unpack center row and column indices
|
|
|
center_rows, center_cols = center_indices
|
|
|
|
|
|
# Unpack neighbor row and column indices for 9 neighbors per pixel
|
|
|
# Order: [0: center, 1: up, 2: down, 3: left, 4: right,
|
|
|
# 5: up-left, 6: up-right, 7: down-left, 8: down-right]
|
|
|
neigh_rows, neigh_cols = neighbor_indices # each of shape (9, N)
|
|
|
|
|
|
# Perform diffusion iterations
|
|
|
for _ in range(num_iterations):
|
|
|
# Add source at each mask center
|
|
|
diffusion_tensor[center_rows, center_cols] += 1
|
|
|
|
|
|
# Sample neighbor values for each pixel
|
|
|
neighbor_vals = diffusion_tensor[neigh_rows, neigh_cols] # shape (9, N)
|
|
|
|
|
|
# Zero out invalid neighbors
|
|
|
neighbor_vals *= valid_neighbor_mask
|
|
|
|
|
|
# Update the first neighbor (index 0) with the average of valid neighbor values
|
|
|
diffusion_tensor[neigh_rows[0], neigh_cols[0]] = neighbor_vals.mean(dim=0)
|
|
|
|
|
|
# Compute spatial gradients for 2D flow: dy and dx
|
|
|
# Using neighbor indices: up = 1, down = 2, left = 3, right = 4
|
|
|
grad_samples = diffusion_tensor[
|
|
|
neigh_rows[[2, 1, 4, 3]], # indices [down, up, right, left]
|
|
|
neigh_cols[[2, 1, 4, 3]]
|
|
|
] # shape (4, N)
|
|
|
|
|
|
dy = grad_samples[0] - grad_samples[1]
|
|
|
dx = grad_samples[2] - grad_samples[3]
|
|
|
|
|
|
# Stack and convert to numpy flow field with shape (2, N)
|
|
|
flow_field = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=0)
|
|
|
|
|
|
return flow_field
|
|
|
|
|
|
|
|
|
def __mask_to_flow_cpu(self, mask: np.ndarray) -> np.ndarray:
|
|
|
"""
|
|
|
Convert labeled masks to flow vectors by simulating diffusion from mask centers.
|
|
|
|
|
|
Each mask's center is chosen as the pixel closest to its geometric centroid.
|
|
|
A diffusion process is run on a padded local patch, and flows are derived
|
|
|
as gradients (dy, dx) of the resulting density map.
|
|
|
|
|
|
Args:
|
|
|
mask (np.ndarray): 3D integer array of labels `(C x H x W)`,
|
|
|
where 0 = background and positive integers = mask IDs.
|
|
|
|
|
|
Returns:
|
|
|
flow_field (np.ndarray): Array of shape `(2*C, H, W)` containing
|
|
|
flow components [dy, dx] normalized per pixel.
|
|
|
"""
|
|
|
channels, height, width = mask.shape
|
|
|
flows = np.zeros((2*channels, height, width), np.float32)
|
|
|
|
|
|
for channel in range(channels):
|
|
|
# Initialize flow_field with two channels: dy and dx
|
|
|
flow_field = np.zeros((2, height, width), dtype=np.float64)
|
|
|
|
|
|
mask_channel = mask[channel]
|
|
|
|
|
|
# Find bounding box for each labeled mask
|
|
|
mask_slices = find_objects(mask_channel)
|
|
|
# centers: List[Tuple[int, int]] = []
|
|
|
|
|
|
# Iterate over mask labels (sequential; this method is not numba-compiled)
|
|
|
for label_idx in range(len(mask_slices)):
|
|
|
slc = mask_slices[label_idx]
|
|
|
if slc is None:
|
|
|
continue
|
|
|
|
|
|
# Extract row and column slice for this mask
|
|
|
row_slice, col_slice = slc
|
|
|
# Add 1-pixel border around the patch
|
|
|
patch_height = (row_slice.stop - row_slice.start) + 2
|
|
|
patch_width = (col_slice.stop - col_slice.start) + 2
|
|
|
|
|
|
# Get local coordinates of mask pixels within the patch
|
|
|
local_rows, local_cols = np.nonzero(
|
|
|
mask_channel[row_slice, col_slice] == (label_idx + 1)
|
|
|
)
|
|
|
# Shift coords by +1 for the border padding
|
|
|
local_rows = local_rows.astype(np.int32) + 1
|
|
|
local_cols = local_cols.astype(np.int32) + 1
|
|
|
|
|
|
# Compute centroid and find nearest pixel as diffusion seed
|
|
|
centroid_row = local_rows.mean()
|
|
|
centroid_col = local_cols.mean()
|
|
|
distances = (local_cols - centroid_col) ** 2 + (local_rows - centroid_row) ** 2
|
|
|
seed_index = distances.argmin()
|
|
|
center_row = int(local_rows[seed_index])
|
|
|
center_col = int(local_cols[seed_index])
|
|
|
|
|
|
# Determine number of iterations
|
|
|
total_iter = 2 * (patch_height + patch_width)
|
|
|
|
|
|
# Initialize flat diffusion map for the local patch
|
|
|
diffusion_map = np.zeros(patch_height * patch_width, dtype=np.float64)
|
|
|
# Run diffusion from the seed center
|
|
|
diffusion_map = self.__diffuse_from_center(
|
|
|
diffusion_map,
|
|
|
local_rows,
|
|
|
local_cols,
|
|
|
center_row,
|
|
|
center_col,
|
|
|
patch_width,
|
|
|
total_iter
|
|
|
)
|
|
|
|
|
|
# Compute flow as finite differences (gradient) on the diffusion map
|
|
|
dy = (
|
|
|
diffusion_map[(local_rows + 1) * patch_width + local_cols] -
|
|
|
diffusion_map[(local_rows - 1) * patch_width + local_cols]
|
|
|
)
|
|
|
dx = (
|
|
|
diffusion_map[local_rows * patch_width + (local_cols + 1)] -
|
|
|
diffusion_map[local_rows * patch_width + (local_cols - 1)]
|
|
|
)
|
|
|
|
|
|
# Write flows back into the global flow_field array
|
|
|
flow_field[0,
|
|
|
row_slice.start + local_rows - 1,
|
|
|
col_slice.start + local_cols - 1] = dy
|
|
|
flow_field[1,
|
|
|
row_slice.start + local_rows - 1,
|
|
|
col_slice.start + local_cols - 1] = dx
|
|
|
|
|
|
# Store center location in original image coordinates
|
|
|
# centers.append(
|
|
|
# (row_slice.start + center_row - 1,
|
|
|
# col_slice.start + center_col - 1)
|
|
|
# )
|
|
|
|
|
|
# Normalize each vector [dy,dx] by its magnitude
|
|
|
magnitudes = np.sqrt((flow_field**2).sum(axis=0)) + 1e-60
|
|
|
flow_field /= magnitudes
|
|
|
|
|
|
flows[2*channel: 2*channel + 2] = flow_field
|
|
|
|
|
|
return flows
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
@njit("(float64[:], int32[:], int32[:], int32, int32, int32, int32)", nogil=True)
|
|
|
def __diffuse_from_center(
|
|
|
diffusion_map: np.ndarray,
|
|
|
row_coords: np.ndarray,
|
|
|
col_coords: np.ndarray,
|
|
|
center_row: int,
|
|
|
center_col: int,
|
|
|
patch_width: int,
|
|
|
num_iterations: int
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Perform diffusion of particles from a seed pixel across a local mask patch.
|
|
|
|
|
|
At each iteration, one particle is added at the seed, and each mask pixel's
|
|
|
value is updated to the average of itself and its 8-connected neighbors.
|
|
|
|
|
|
Args:
|
|
|
diffusion_map (np.ndarray): Flat array of length patch_height * patch_width.
|
|
|
row_coords (np.ndarray): 1D array of row indices for mask pixels (local coords).
|
|
|
col_coords (np.ndarray): 1D array of column indices for mask pixels (local coords).
|
|
|
center_row (int): Row index of the seed point in local patch coords.
|
|
|
center_col (int): Column index of the seed point in local patch coords.
|
|
|
patch_width (int): Width (number of columns) in the local patch.
|
|
|
num_iterations (int): Number of diffusion iterations to perform.
|
|
|
|
|
|
Returns:
|
|
|
np.ndarray: Updated diffusion_map after performing diffusion.
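|
Example (worked, single-pixel region): the seed first receives one particle,
and the averaging step then leaves it with (1 + 8 * 0) / 9 ~= 0.111 after the
first iteration, since all eight neighbors lie outside the region and stay 0.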
|
|
|
"""
|
|
|
# Compute linear indices for each mask pixel and its neighbors
|
|
|
base_idx = row_coords * patch_width + col_coords
|
|
|
up = (row_coords - 1) * patch_width + col_coords
|
|
|
down = (row_coords + 1) * patch_width + col_coords
|
|
|
left = row_coords * patch_width + (col_coords - 1)
|
|
|
right = row_coords * patch_width + (col_coords + 1)
|
|
|
up_left = (row_coords - 1) * patch_width + (col_coords - 1)
|
|
|
up_right = (row_coords - 1) * patch_width + (col_coords + 1)
|
|
|
down_left = (row_coords + 1) * patch_width + (col_coords - 1)
|
|
|
down_right = (row_coords + 1) * patch_width + (col_coords + 1)
|
|
|
|
|
|
for _ in range(num_iterations):
|
|
|
# Inject one particle at the seed location
|
|
|
diffusion_map[center_row * patch_width + center_col] += 1.0
|
|
|
|
|
|
# Update each mask pixel as the average over itself and neighbors
|
|
|
diffusion_map[base_idx] = (
|
|
|
diffusion_map[base_idx] +
|
|
|
diffusion_map[up] + diffusion_map[down] +
|
|
|
diffusion_map[left] + diffusion_map[right] +
|
|
|
diffusion_map[up_left] + diffusion_map[up_right] +
|
|
|
diffusion_map[down_left] + diffusion_map[down_right]
|
|
|
) * (1.0 / 9.0)
|
|
|
|
|
|
return diffusion_map
|
|
|
|
|
|
|
|
|
def __segment_instances(
|
|
|
self,
|
|
|
probability_map: np.ndarray,
|
|
|
flow: np.ndarray,
|
|
|
prob_threshold: float = 0.0,
|
|
|
flow_threshold: float = 0.4,
|
|
|
num_iters: int = 200,
|
|
|
min_object_size: int = 15
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Generate instance segmentation masks from probability and flow fields.
|
|
|
|
|
|
Args:
|
|
|
probability_map: 3D array `(C, H, W)` of cell probabilities.
|
|
|
flow: 3D array `(2*C, H, W)` of forward flow vectors.
|
|
|
prob_threshold: threshold to binarize probability_map. (Default 0.0)
|
|
|
flow_threshold: threshold for filtering bad flow masks. (Default 0.4)
|
|
|
num_iters: number of iterations for flow-following. (Default 200)
|
|
|
min_object_size: minimum area (in pixels) an instance must have to be kept. (Default 15)
|
|
|
|
|
|
Returns:
|
|
|
3D array of uint16 instance labels for each channel.
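|
Example (illustrative sketch; `net_output` is a hypothetical (3, H, W) array
holding one probability channel followed by two flow channels):
    >>> labels = self.__segment_instances(
    ...     probability_map=net_output[:1], flow=net_output[1:3]
    ... )
    >>> labels.shape   # (1, H, W), dtype uint16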
|
|
|
"""
|
|
|
# Create a binary mask of likely cell locations
|
|
|
probability_mask = probability_map > prob_threshold
|
|
|
|
|
|
# If no cells exceed the threshold, return an empty mask
|
|
|
if not np.any(probability_mask):
|
|
|
logger.warning("No cell pixels found.")
|
|
|
return np.zeros_like(probability_map, dtype=np.uint16)
|
|
|
|
|
|
# Prepare output array for instance labels
|
|
|
labeled_instances = np.zeros_like(probability_map, dtype=np.uint16)
|
|
|
|
|
|
# Process each channel independently
|
|
|
for channel_index in range(probability_mask.shape[0]):
|
|
|
# Extract flow vectors for this channel (two components per channel)
|
|
|
channel_flow_vectors = flow[2 * channel_index : 2 * channel_index + 2]
|
|
|
# Extract binary mask for this channel
|
|
|
channel_mask = probability_mask[channel_index]
|
|
|
|
|
|
if not channel_mask.sum():
|
|
|
continue
|
|
|
|
|
|
nonzero_coords = np.stack(np.nonzero(channel_mask))
|
|
|
|
|
|
# Follow the flow vectors to generate coordinate mappings
|
|
|
flow_coordinates = self.__follow_flows(
|
|
|
flow_field=channel_flow_vectors * channel_mask / 5.0,
|
|
|
initial_coords=nonzero_coords,
|
|
|
num_iters=num_iters
|
|
|
)
|
|
|
|
|
|
if not torch.is_tensor(flow_coordinates):
|
|
|
flow_coordinates = torch.from_numpy(
|
|
|
flow_coordinates).to(self._device, dtype=torch.int32)
|
|
|
else:
|
|
|
flow_coordinates = flow_coordinates.int()
|
|
|
|
|
|
# Obtain preliminary instance masks by clustering the coordinates
|
|
|
channel_instances_mask = self.__get_mask(
|
|
|
pixel_positions=flow_coordinates,
|
|
|
valid_indices=nonzero_coords,
|
|
|
original_shape=probability_map.shape[1:]
|
|
|
)
|
|
|
|
|
|
# Filter out bad flow-derived instances if requested
|
|
|
if channel_instances_mask.max() > 0 and flow_threshold > 0:
|
|
|
channel_instances_mask = self.__remove_inconsistent_flow_masks(
|
|
|
mask=channel_instances_mask,
|
|
|
flow_network=channel_flow_vectors,
|
|
|
error_threshold=flow_threshold
|
|
|
)
|
|
|
|
|
|
# Remove small objects or holes below the minimum size
|
|
|
if min_object_size > 0:
|
|
|
# channel_instances_mask = morphology.remove_small_holes(
|
|
|
# channel_instances_mask, area_threshold=min_object_size
|
|
|
# )
|
|
|
# channel_instances_mask = morphology.remove_small_objects(
|
|
|
# channel_instances_mask, min_size=min_object_size
|
|
|
# )
|
|
|
channel_instances_mask = self.__fill_holes_and_prune_small_masks(
|
|
|
channel_instances_mask, minimum_size=min_object_size
|
|
|
)
|
|
|
|
|
|
labeled_instances[channel_index] = channel_instances_mask
|
|
|
|
|
|
return labeled_instances
|
|
|
|
|
|
|
|
|
def __follow_flows(
|
|
|
self,
|
|
|
flow_field: np.ndarray,
|
|
|
initial_coords: np.ndarray,
|
|
|
num_iters: int = 200
|
|
|
) -> np.ndarray | torch.Tensor:
|
|
|
"""
|
|
|
Trace pixel positions through a flow field via iterative interpolation.
|
|
|
|
|
|
Args:
|
|
|
flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors.
|
|
|
initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions.
|
|
|
num_iters (int): Number of integration steps.
|
|
|
|
|
|
Returns:
|
|
|
(np.ndarray | torch.Tensor): Final (y, x) positions of each point.
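|
Example (conceptual): under a constant flow of +1 in y with num_iters=3, a
point starting at (y=10, x=7) ends near (13, 7); positions are clamped to the
image bounds at every step.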
|
|
|
"""
|
|
|
dims = 2
|
|
|
# Extract spatial dimensions
|
|
|
height, width = flow_field.shape[1:]
|
|
|
|
|
|
# Choose GPU/MPS path if available
|
|
|
if self._device.type in ("cuda", "mps"):
|
|
|
# Prepare point tensor: shape [1, 1, num_points, 2]
|
|
|
pts = torch.zeros((1, 1, initial_coords.shape[1], dims),
|
|
|
dtype=torch.float32, device=self._device)
|
|
|
# Prepare flow volume: shape [1, 2, height, width]
|
|
|
flow_vol = torch.zeros((1, dims, height, width),
|
|
|
dtype=torch.float32, device=self._device)
|
|
|
|
|
|
# Load initial positions and flow into tensors (flip order for grid_sample)
|
|
|
# dim 0 = x
|
|
|
# dim 1 = y
|
|
|
for i in range(dims):
|
|
|
pts[0, 0, :, dims - i - 1] = (
|
|
|
torch.from_numpy(initial_coords[i])
|
|
|
.to(self._device, torch.float32)
|
|
|
)
|
|
|
flow_vol[0, dims - i - 1] = (
|
|
|
torch.from_numpy(flow_field[i])
|
|
|
.to(self._device, torch.float32)
|
|
|
)
|
|
|
|
|
|
# Prepare normalization factors for x and y (max index)
|
|
|
max_indices = torch.tensor([width - 1, height - 1],
|
|
|
dtype=torch.float32, device=self._device)
|
|
|
# Reshape for broadcasting to point tensor dims
|
|
|
max_idx_pt = max_indices.view(1, 1, 1, dims)
|
|
|
# Reshape for broadcasting to flow volume dims
|
|
|
max_idx_flow = max_indices.view(1, dims, 1, 1)
|
|
|
|
|
|
# Normalize flow values to [-1, 1] range
|
|
|
flow_vol = (flow_vol * 2) / max_idx_flow
|
|
|
# Normalize points to [-1, 1]
|
|
|
pts = (pts / max_idx_pt) * 2 - 1
|
|
|
|
|
|
# Iterate: sample flow and update points
|
|
|
for _ in range(num_iters):
|
|
|
sampled = torch.nn.functional.grid_sample(
|
|
|
flow_vol, pts, align_corners=False
|
|
|
)
|
|
|
# Update each coordinate and clamp to valid range
|
|
|
for i in range(dims):
|
|
|
pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0)
|
|
|
|
|
|
# Denormalize back to original pixel coordinates
|
|
|
pts = (pts + 1) * 0.5 * max_idx_pt
|
|
|
# Swap channels back to (y, x) and flatten
|
|
|
final_pts = pts[..., [1, 0]].squeeze()
|
|
|
# Convert from (num_points, 2) to (2, num_points)
|
|
|
return final_pts.T if final_pts.ndim > 1 else final_pts.unsqueeze(0).T
|
|
|
|
|
|
# CPU fallback using numpy and scipy
|
|
|
current_pos = initial_coords.copy().astype(np.float32)
|
|
|
temp_delta = np.zeros_like(current_pos, dtype=np.float32)
|
|
|
|
|
|
for _ in range(num_iters):
|
|
|
# Interpolate flow at current positions
|
|
|
self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta)
|
|
|
# Update positions and clamp to image bounds
|
|
|
current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1)
|
|
|
current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1)
|
|
|
|
|
|
return current_pos
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
@njit([
|
|
|
"(int16[:,:,:], float32[:], float32[:], float32[:,:])",
|
|
|
"(float32[:,:,:], float32[:], float32[:], float32[:,:])"
|
|
|
], cache=True)
|
|
|
def __map_coordinates(
|
|
|
image_data: np.ndarray,
|
|
|
y_coords: np.ndarray,
|
|
|
x_coords: np.ndarray,
|
|
|
output: np.ndarray
|
|
|
) -> None:
|
|
|
"""
|
|
|
Perform in-place bilinear interpolation on an image volume.
|
|
|
|
|
|
Args:
|
|
|
image_data (np.ndarray): Input volume with shape (C, H, W).
|
|
|
y_coords (np.ndarray): Array of new y positions (num_points).
|
|
|
x_coords (np.ndarray): Array of new x positions (num_points).
|
|
|
output (np.ndarray): Output array of shape (C, num_points) to fill.
|
|
|
|
|
|
Returns:
|
|
|
None. Results written directly into `output`.
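|
Example (worked): with corner values v00=0, v10=1, v01=2, v11=3 and fractional
offsets wy = wx = 0.5, the interpolated value is
0*0.25 + 1*0.25 + 2*0.25 + 3*0.25 = 1.5.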
|
|
|
"""
|
|
|
channels, height, width = image_data.shape
|
|
|
# Compute integer (floor) and fractional parts for coords
|
|
|
y_floor = y_coords.astype(np.int32)
|
|
|
x_floor = x_coords.astype(np.int32)
|
|
|
y_frac = y_coords - y_floor
|
|
|
x_frac = x_coords - x_floor
|
|
|
|
|
|
# Loop over each sample point
|
|
|
for idx in range(y_floor.shape[0]):
|
|
|
# Clamp base indices to valid range
|
|
|
y0 = min(max(y_floor[idx], 0), height - 1)
|
|
|
x0 = min(max(x_floor[idx], 0), width - 1)
|
|
|
y1 = min(y0 + 1, height - 1)
|
|
|
x1 = min(x0 + 1, width - 1)
|
|
|
|
|
|
wy = y_frac[idx]
|
|
|
wx = x_frac[idx]
|
|
|
|
|
|
# Interpolate per channel
|
|
|
for c in range(channels):
|
|
|
v00 = np.float32(image_data[c, y0, x0])
|
|
|
v10 = np.float32(image_data[c, y0, x1])
|
|
|
v01 = np.float32(image_data[c, y1, x0])
|
|
|
v11 = np.float32(image_data[c, y1, x1])
|
|
|
# Bilinear interpolation formula
|
|
|
output[c, idx] = (
|
|
|
v00 * (1 - wy) * (1 - wx) +
|
|
|
v10 * (1 - wy) * wx +
|
|
|
v01 * wy * (1 - wx) +
|
|
|
v11 * wy * wx
|
|
|
)
|
|
|
|
|
|
|
|
|
def __get_mask(
|
|
|
self,
|
|
|
pixel_positions: torch.Tensor,
|
|
|
valid_indices: np.ndarray,
|
|
|
original_shape: tuple[int, ...],
|
|
|
pad_radius: int = 20,
|
|
|
max_size_fraction: float = 0.4
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Generate labeled masks by clustering pixel trajectories via histogram peaks and region growing.
|
|
|
|
|
|
This function executes the following steps:
|
|
|
1. Pads and clamps pixel final positions to avoid border effects.
|
|
|
2. Builds a dense histogram of pixel counts over spatial bins.
|
|
|
3. Identifies local maxima in the histogram as seed points.
|
|
|
4. Extracts local patches around each seed and grows regions by iteratively adding neighbors
|
|
|
that exceed an intensity threshold.
|
|
|
5. Maps grown patches back to original image indices.
|
|
|
6. Removes any masks that exceed a maximum size fraction of the image.
|
|
|
|
|
|
Args:
|
|
|
pixel_positions (torch.Tensor): Tensor of shape `[2, N_pixels]`, dtype=int, containing
|
|
|
final pixel coordinates after dynamics for each dimension.
|
|
|
valid_indices (np.ndarray): Integer array of shape `[2, N_pixels]`
|
|
|
giving indices of pixels in the original image grid.
|
|
|
original_shape (tuple of ints): Spatial dimensions of the original image, e.g. (H, W).
|
|
|
pad_radius (int): Number of zero-padding pixels added on each side of the histogram.
|
|
|
Defaults to 20.
|
|
|
max_size_fraction (float): If any mask has a pixel count > max_size_fraction * total_pixels,
|
|
|
it will be removed. Defaults to 0.4.
|
|
|
|
|
|
Returns:
|
|
|
np.ndarray: Integer mask array of shape `original_shape` with labels 0 (background) and 1..M.
|
|
|
|
|
|
Raises:
|
|
|
ValueError: If input dimensions are inconsistent or pixel_positions shape is invalid.
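|
Example (illustrative sketch, shapes only; with this few pixels no histogram
peak clears the seed threshold, so an all-zero mask is returned):
    >>> pos = torch.tensor([[5, 5, 6], [7, 7, 8]])
    >>> idx = np.array([[5, 5, 6], [7, 7, 8]])
    >>> self.__get_mask(pos, idx, original_shape=(32, 32)).shape
    (32, 32)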
|
|
|
"""
|
|
|
# Validate inputs
|
|
|
ndim = len(original_shape)
|
|
|
if pixel_positions.ndim != 2 or pixel_positions.size(0) != ndim:
|
|
|
msg = f"pixel_positions must be shape [{ndim}, N], got {tuple(pixel_positions.shape)}"
|
|
|
logger.error(msg)
|
|
|
raise ValueError(msg)
|
|
|
if pad_radius < 0:
|
|
|
msg = f"pad_radius must be non-negative, got {pad_radius}"
|
|
|
logger.error(msg)
|
|
|
raise ValueError(msg)
|
|
|
|
|
|
# Step 1: Pad and clamp pixel positions
|
|
|
padded_positions = pixel_positions.clone().to(torch.int64) + pad_radius
|
|
|
for dim in range(ndim):
|
|
|
max_val = original_shape[dim] + pad_radius - 1
|
|
|
padded_positions[dim] = torch.clamp(padded_positions[dim], min=0, max=max_val)
|
|
|
|
|
|
# Build histogram dimensions
|
|
|
hist_shape = tuple(s + 2 * pad_radius for s in original_shape)
|
|
|
|
|
|
# Step 2: Create sparse tensor and densify to get per-pixel counts
|
|
|
try:
|
|
|
counts_sparse = torch.sparse_coo_tensor(
|
|
|
padded_positions,
|
|
|
torch.ones(padded_positions.shape[1], dtype=torch.int32, device=pixel_positions.device),
|
|
|
size=hist_shape
|
|
|
)
|
|
|
histogram = counts_sparse.to_dense()
|
|
|
except Exception as e:
|
|
|
logger.error("Failed to build dense histogram: %s", e)
|
|
|
raise
|
|
|
|
|
|
# Step 3: Find peaks via 5x5 max-pooling
|
|
|
# k = 5
|
|
|
# pooled = F.max_pool2d(
|
|
|
# histogram.float().unsqueeze(0).unsqueeze(1),
|
|
|
# kernel_size=k,
|
|
|
# stride=1,
|
|
|
# padding=k // 2
|
|
|
# ).squeeze()
|
|
|
pooled = self.__max_pool_nd(
|
|
|
histogram.unsqueeze(0),
|
|
|
kernel_size=5
|
|
|
).squeeze()
|
|
|
|
|
|
# Seeds are positions where histogram equals local max and count > threshold
|
|
|
seed_positions = torch.nonzero((histogram - pooled == 0) * (histogram > 10))
|
|
|
if seed_positions.shape[0] == 0:
|
|
|
logger.warning("No seeds found: returning empty mask")
|
|
|
return np.zeros(original_shape, dtype=np.uint16)
|
|
|
|
|
|
# Sort seeds by ascending count to process small peaks first
|
|
|
seed_counts = histogram[tuple(seed_positions.T)]
|
|
|
order = torch.argsort(seed_counts)
|
|
|
seed_positions = seed_positions[order]
|
|
|
del pooled, counts_sparse
|
|
|
|
|
|
# Step 4: Extract local patches and perform region growing
|
|
|
num_seeds = seed_positions.shape[0]
|
|
|
# Tensor to hold local patches
|
|
|
patches = torch.zeros((num_seeds, 11, 11), device=pixel_positions.device)
|
|
|
for idx in range(num_seeds):
|
|
|
coords = seed_positions[idx]
|
|
|
slices = tuple(slice(c - 5, c + 6) for c in coords)
|
|
|
patches[idx] = histogram[slices]
|
|
|
del histogram
|
|
|
|
|
|
# Initialize seed mask (center pixel of each patch)
|
|
|
seed_masks = torch.zeros_like(patches, device=pixel_positions.device)
|
|
|
seed_masks[:, 5, 5] = 1
|
|
|
# Iterative dilation and thresholding
|
|
|
for _ in range(5):
|
|
|
# seed_masks = F.max_pool2d(
|
|
|
# seed_masks.float().unsqueeze(0),
|
|
|
# kernel_size=3,
|
|
|
# stride=1,
|
|
|
# padding=1
|
|
|
# ).squeeze(0).int()
|
|
|
seed_masks = self.__max_pool_nd(seed_masks, kernel_size=3)
|
|
|
seed_masks *= (patches > 2)
|
|
|
# Compute final mask coordinates
|
|
|
final_coords = []
|
|
|
for idx in range(num_seeds):
|
|
|
coords_local = torch.nonzero(seed_masks[idx])
|
|
|
# Shift back to global positions
|
|
|
coords_global = coords_local + seed_positions[idx] - 5
|
|
|
final_coords.append(tuple(coords_global.T))
|
|
|
|
|
|
# Step 5: Paint masks into padded volume
|
|
|
dtype = torch.int32 if num_seeds < 2**16 else torch.int64
|
|
|
mask_padded = torch.zeros(hist_shape, dtype=dtype, device=pixel_positions.device)
|
|
|
for label_idx, coords in enumerate(final_coords, start=1):
|
|
|
mask_padded[coords] = label_idx
|
|
|
|
|
|
# Extract only the padded positions that correspond to original pixels
|
|
|
mask_values = mask_padded[tuple(padded_positions)]
|
|
|
mask_values = mask_values.cpu().numpy()
|
|
|
|
|
|
# Step 6: Map to original image and remove oversized masks
|
|
|
mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32)
|
|
|
mask_final[tuple(valid_indices)] = mask_values
|
|
|
|
|
|
# Prune masks that are too large
|
|
|
labels, counts = fastremap.unique(mask_final, return_counts=True)
|
|
|
total_pixels = np.prod(original_shape)
|
|
|
oversized = labels[counts > (total_pixels * max_size_fraction)]
|
|
|
if oversized.size > 0:
|
|
|
mask_final = fastremap.mask(mask_final, oversized)
|
|
|
fastremap.renumber(mask_final, in_place=True)
|
|
|
|
|
|
return mask_final
|
|
|
|
|
|
|
|
|
def __max_pool1d(
|
|
|
self,
|
|
|
input_tensor: Tensor,
|
|
|
kernel_size: int = 5,
|
|
|
axis: int = 1,
|
|
|
output_tensor: Tensor | None = None
|
|
|
) -> Tensor:
|
|
|
"""
|
|
|
Memory-efficient 1D max pooling along a specified axis using in-place updates.
|
|
|
Requires:
|
|
|
- stride = 1
|
|
|
- padding = kernel_size // 2
|
|
|
- odd kernel_size >= 3
|
|
|
|
|
|
Args:
|
|
|
input_tensor (Tensor): Source tensor for pooling.
|
|
|
kernel_size (int): Size of the pooling window (must be odd and >= 3).
|
|
|
axis (int): Axis along which to compute 1D max pooling.
|
|
|
output_tensor (Tensor | None): Tensor to store the result.
|
|
|
If None, a clone of input_tensor is used.
|
|
|
|
|
|
Returns:
|
|
|
Tensor: The pooled tensor, same shape as input_tensor.
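|
Example (worked, kernel_size=3, axis=1):
    >>> t = torch.tensor([[1., 5., 2., 0.]])
    >>> self.__max_pool1d(t, kernel_size=3, axis=1)
    tensor([[5., 5., 5., 2.]])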
|
|
|
"""
|
|
|
# Initialize or copy data into the output tensor
|
|
|
if output_tensor is None:
|
|
|
output = input_tensor.clone()
|
|
|
else:
|
|
|
output = output_tensor
|
|
|
output.copy_(input_tensor)
|
|
|
|
|
|
# Number of elements along the chosen axis and half-window size
|
|
|
dimension_size = input_tensor.shape[axis]
|
|
|
half_window = kernel_size // 2
|
|
|
|
|
|
# Slide window offsets from -half_window to +half_window
|
|
|
for offset in range(-half_window, half_window + 1):
|
|
|
# Compute slice indices depending on axis
|
|
|
if axis == 1:
|
|
|
target_slice = output[:, max(-offset, 0): min(dimension_size - offset, dimension_size)]
|
|
|
source_slice = input_tensor[:, max(offset, 0): min(dimension_size + offset, dimension_size)]
|
|
|
elif axis == 2:
|
|
|
target_slice = output[:, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
|
|
|
source_slice = input_tensor[:, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
|
|
|
elif axis == 3:
|
|
|
target_slice = output[:, :, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
|
|
|
source_slice = input_tensor[:, :, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
|
|
|
else:
|
|
|
raise ValueError(f"Unsupported axis {axis} for 1D pooling")
|
|
|
|
|
|
# In-place element-wise maximum
|
|
|
torch.maximum(target_slice, source_slice, out=target_slice)
|
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
def __max_pool_nd(
|
|
|
self,
|
|
|
input_tensor: Tensor,
|
|
|
kernel_size: int = 5
|
|
|
) -> Tensor:
|
|
|
"""
|
|
|
Memory-efficient N-dimensional max pooling for 2D or 3D spatial data.
|
|
|
Applies 1D max pooling sequentially over each spatial axis.
|
|
|
|
|
|
Args:
|
|
|
input_tensor (Tensor): Input tensor with shape
|
|
|
(batch_size, dim1, dim2, ..., dimN).
|
|
|
kernel_size (int): Size of the pooling window (must be odd and >= 3).
|
|
|
|
|
|
Returns:
|
|
|
Tensor: The pooled tensor, same shape as input_tensor.
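|
Note: the maximum over a square (or cubic) window is separable, so pooling
each spatial axis in turn with the same 1D kernel equals a full 2D (or 3D)
max pool with that window size, at a fraction of the peak memory.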
|
|
|
"""
|
|
|
# Determine number of spatial dimensions (excluding batch axis)
|
|
|
num_spatial_dims = input_tensor.ndim - 1
|
|
|
|
|
|
# First pass: pool along axis=1
|
|
|
pooled = self.__max_pool1d(input_tensor, kernel_size=kernel_size, axis=1)
|
|
|
# Second pass: pool along axis=2
|
|
|
pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=2)
|
|
|
|
|
|
# If 3D data, apply a third pass along axis=3
|
|
|
if num_spatial_dims == 3:
|
|
|
pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=3)
|
|
|
elif num_spatial_dims != 2:
|
|
|
raise ValueError("max_pool_nd only supports 2D or 3D spatial data")
|
|
|
|
|
|
return pooled
|
|
|
|
|
|
|
|
|
def __remove_inconsistent_flow_masks(
|
|
|
self,
|
|
|
mask: np.ndarray,
|
|
|
flow_network: np.ndarray,
|
|
|
error_threshold: float = 0.4
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Remove labeled masks that have inconsistent optical flows compared to network-predicted flows.
|
|
|
|
|
|
This performs a quality control step by computing flows from the provided masks
|
|
|
and comparing them to the flows predicted by the network. Masks with a mean squared
|
|
|
flow error above `error_threshold` are discarded (set to 0).
|
|
|
|
|
|
Args:
|
|
|
mask (np.ndarray): Integer mask array with shape [H, W].
|
|
|
Values: 0 = no mask; 1,2,... = mask labels.
|
|
|
flow_network (np.ndarray): Float array of network-predicted flows with shape
|
|
|
[2, H, W].
|
|
|
error_threshold (float): Maximum allowed mean squared flow error per mask label.
|
|
|
Defaults to 0.4.
|
|
|
|
|
|
Returns:
|
|
|
np.ndarray: The input mask with inconsistent masks removed (labels set to 0).
|
|
|
|
|
|
Raises:
|
|
|
MemoryError: If the mask size exceeds available GPU memory.
|
|
|
"""
|
|
|
# If mask is very large and running on CUDA, check memory
|
|
|
num_pixels = mask.size
|
|
|
if num_pixels > 10000 * 10000 and self._device.type == 'cuda':
|
|
|
# Clear unused GPU cache
|
|
|
torch.cuda.empty_cache()
|
|
|
# Determine PyTorch version
|
|
|
major, minor = map(int, torch.__version__.split('.')[:2])
|
|
|
# Determine current CUDA device index
|
|
|
device_index = (
|
|
|
self._device.index
|
|
|
if self._device.index is not None
|
|
|
else torch.cuda.current_device()
|
|
|
)
|
|
|
# Get free and total memory
|
|
|
if major == 1 and minor < 10:
|
|
|
total_mem = torch.cuda.get_device_properties(device_index).total_memory
|
|
|
used_mem = torch.cuda.memory_allocated(device_index)
|
|
|
free_mem = total_mem - used_mem
|
|
|
else:
|
|
|
free_mem, total_mem = torch.cuda.mem_get_info(device_index)
|
|
|
# Estimate required memory for mask-based flow computation
|
|
|
# Assume float32 per pixel
|
|
|
required_bytes = num_pixels * np.dtype(np.float32).itemsize
|
|
|
if required_bytes > free_mem:
|
|
|
logger.error(
|
|
|
'Image too large for GPU memory in flow QC step (required: %d B, available: %d B)',
|
|
|
required_bytes, free_mem
|
|
|
)
|
|
|
raise MemoryError('Insufficient GPU memory for flow QC computation')
|
|
|
|
|
|
# Compute flow errors per mask label
|
|
|
flow_errors, _ = self.__compute_flow_error(mask, flow_network)
|
|
|
|
|
|
# Identify labels with error above threshold
|
|
|
bad_labels = np.nonzero(flow_errors > error_threshold)[0] + 1
|
|
|
|
|
|
# Remove bad masks by setting their label to 0
|
|
|
mask[np.isin(mask, bad_labels)] = 0
|
|
|
return mask
|
|
|
|
|
|
|
|
|
def __compute_flow_error(
|
|
|
self,
|
|
|
mask: np.ndarray,
|
|
|
flow_network: np.ndarray
|
|
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
|
"""
|
|
|
Compute mean squared error between network-predicted flows and flows derived from masks.
|
|
|
|
|
|
Args:
|
|
|
mask (np.ndarray): Integer masks, shape must match flow_network spatial dims.
|
|
|
flow_network (np.ndarray): Network predicted flows of shape [axis, ...].
|
|
|
|
|
|
Returns:
|
|
|
tuple(np.ndarray, np.ndarray):
|
|
|
- flow_errors: 1D array (length = max label) of mean squared error per label.
|
|
|
- computed_flows: Array of flows derived from the mask, same shape as flow_network.
|
|
|
|
|
|
Raises:
|
|
|
ValueError: If the spatial dimensions of `mask` and `flow_network` do not match.
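|
Example (worked, one label): if every pixel of a label has computed flow
(0.6, 0.8) while the network predicts (3.0, 4.0), the per-pixel error is
(0.6 - 3.0/5)^2 + (0.8 - 4.0/5)^2 = 0, so the label survives any positive
threshold.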
|
|
|
"""
|
|
|
# Ensure mask and flow shapes match
|
|
|
if flow_network.shape[1:] != mask.shape:
|
|
|
logger.error(
|
|
|
'Shape mismatch: network flow shape %s vs mask shape %s',
|
|
|
flow_network.shape[1:], mask.shape
|
|
|
)
|
|
|
raise ValueError('Network flow and mask shapes must match')
|
|
|
|
|
|
# Derive flows from the mask labels via __compute_flow_from_mask
|
|
|
computed_flows = self.__compute_flow_from_mask(mask[None, ...])
|
|
|
|
|
|
# Prepare array for errors (one value per mask label)
|
|
|
num_labels = int(mask.max())
|
|
|
flow_errors = np.zeros(num_labels, dtype=float)
|
|
|
|
|
|
# Accumulate mean squared error over each flow axis
|
|
|
for axis_index in range(computed_flows.shape[0]):
|
|
|
# MSE per label: mean((computed - predicted/5)^2)
|
|
|
flow_errors += mean(
|
|
|
(computed_flows[axis_index] - flow_network[axis_index] / 5.0) ** 2,
|
|
|
mask,
|
|
|
index=np.arange(1, num_labels + 1)
|
|
|
)
|
|
|
|
|
|
return flow_errors, computed_flows
|
|
|
|
|
|
|
|
|
def __fill_holes_and_prune_small_masks(
|
|
|
self,
|
|
|
masks: np.ndarray,
|
|
|
minimum_size: int = 15
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Fill holes in labeled masks and remove masks smaller than a given size.
|
|
|
|
|
|
This function performs two steps:
|
|
|
1. Fills internal holes in each labeled mask using `fill_voids.fill`.
|
|
|
2. Discards any mask whose pixel count is below `minimum_size`.
|
|
|
|
|
|
Args:
|
|
|
masks (np.ndarray): Integer mask array of dimension 2 or 3 (shape [H, W] or [D, H, W]).
|
|
|
Values: 0 = background; 1,2,... = mask labels.
|
|
|
minimum_size (int): Minimum number of pixels required to keep a mask.
|
|
|
Masks smaller than this will be removed.
|
|
|
Set to -1 to skip size-based pruning. Defaults to 15.
|
|
|
|
|
|
Returns:
|
|
|
np.ndarray: Processed mask array with holes filled and small masks removed.
|
|
|
|
|
|
Raises:
|
|
|
ValueError: If `masks` is not a 2D or 3D integer array.
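|
Example (illustrative):
    >>> m = np.zeros((5, 5), dtype=np.uint16)
    >>> m[1:4, 1:4] = 1
    >>> m[2, 2] = 0                         # one-pixel hole
    >>> out = self.__fill_holes_and_prune_small_masks(m, minimum_size=-1)
    >>> int(out[2, 2])
    1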
|
|
|
"""
|
|
|
# Validate input dimensions
|
|
|
if masks.ndim not in (2, 3):
|
|
|
msg = f"Expected 2D or 3D mask array, got {masks.ndim}D."
|
|
|
logger.error(msg)
|
|
|
raise ValueError(msg)
|
|
|
|
|
|
# Initial pruning of too-small masks
|
|
|
masks = self._prune_small_masks(masks, minimum_size)
|
|
|
|
|
|
# Find bounding boxes for each mask label
|
|
|
object_slices = find_objects(masks)
|
|
|
new_label = 1
|
|
|
output_masks = np.zeros_like(masks, dtype=masks.dtype)
|
|
|
|
|
|
# Loop over each original slice, fill holes, and assign new labels
|
|
|
for original_label, slc in enumerate(object_slices, start=1):
|
|
|
if slc is None:
|
|
|
continue
|
|
|
# Extract sub-volume or sub-image
|
|
|
region = masks[slc] == original_label
|
|
|
if not np.any(region):
|
|
|
continue
|
|
|
# Fill internal holes
|
|
|
filled_region = fill_voids.fill(region)
|
|
|
# Write back into output mask with sequential labels
|
|
|
output_masks[slc][filled_region] = new_label
|
|
|
new_label += 1
|
|
|
|
|
|
# Final pruning after hole filling
|
|
|
output_masks = self._prune_small_masks(output_masks, minimum_size)
|
|
|
return output_masks
|
|
|
|
|
|
|
|
|
def _prune_small_masks(
|
|
|
self,
|
|
|
masks: np.ndarray,
|
|
|
minimum_size: int
|
|
|
) -> np.ndarray:
|
|
|
"""
|
|
|
Remove labeled regions in `masks` whose pixel count is below `minimum_size`.
|
|
|
|
|
|
Args:
|
|
|
masks (np.ndarray): Integer mask array (any shape), 0=background.
|
|
|
minimum_size (int): Minimum pixel count; labels smaller are removed. If <0, skip pruning.
|
|
|
|
|
|
Returns:
|
|
|
np.ndarray: Mask array with small labels suppressed and labels renumbered.
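|
Example (illustrative):
    >>> m = np.array([[1, 1, 0],
    ...               [0, 2, 0],
    ...               [0, 0, 0]], dtype=np.uint16)
    >>> self._prune_small_masks(m, minimum_size=2)
    array([[1, 1, 0],
           [0, 0, 0],
           [0, 0, 0]], dtype=uint16)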
|
|
|
"""
|
|
|
if minimum_size < 0:
|
|
|
return masks
|
|
|
|
|
|
labels, counts = fastremap.unique(masks, return_counts=True)
|
|
|
# Skip background label at index 0
|
|
|
non_bg_labels = labels[1:]
|
|
|
non_bg_counts = counts[1:]
|
|
|
# Identify labels to remove
|
|
|
small_labels = non_bg_labels[non_bg_counts < minimum_size]
|
|
|
if small_labels.size > 0:
|
|
|
masks = fastremap.mask(masks, small_labels)
|
|
|
fastremap.renumber(masks, in_place=True)
|
|
|
return masks |