Implement the training, testing and prediction methods

master
laynholt 11 hours ago
parent 60aebe5921
commit 5d984dc7a9

@@ -10,6 +10,7 @@ class DatasetCommonConfig(BaseModel):
seed: Optional[int] = 0 # Seed for data splitting (when not pre-split) and for all other random operations
device: str = "cuda0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_tta: bool = False # Flag to use Test-Time Augmentation (TTA)
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
predictions_dir: str = "." # Directory to save predictions
@model_validator(mode="after")
@@ -70,7 +71,6 @@ class DatasetTrainingConfig(BaseModel):
batch_size: int = 1 # Batch size for training
num_epochs: int = 100 # Number of training epochs
val_freq: int = 1 # Frequency of validation during training
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
pretrained_weights: str = "" # Path to pretrained weights for training

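Moving `use_amp` from the training config into the common config makes sense because the test/predict paths also wrap the forward pass in autocast (see `__run_epoch` and `predict` below). A minimal sketch of the pattern this flag gates, using the standard `torch.amp` API on a recent PyTorch (variable names here are illustrative):

import torch

use_amp = True  # DatasetCommonConfig.use_amp
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
scaler = torch.amp.GradScaler(device.type) if use_amp else None

model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)

# Forward pass runs in reduced precision when AMP is enabled.
with torch.amp.autocast(device.type, enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)

if scaler is not None:
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
else:
    loss.backward()
    optimizer.step()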
@@ -133,8 +133,8 @@ def get_test_transforms():
"""
test_transforms = Compose(
[
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
CustomLoadImaged(keys=["image", "mask"], allow_missing_keys=True, image_only=True),
# Load image and label data in (H, W, C) format (allow missing keys).
CustomLoadImaged(keys=["image", "mask"], allow_missing_keys=True, image_only=False),
# Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImaged(
keys=["image"],
@@ -169,8 +169,8 @@ def get_predict_transforms():
"""
pred_transforms = Compose(
[
# Load the image data in (H, W, C) format (image loaded as image-only).
CustomLoadImage(image_only=True),
# Load the image data in (H, W, C) format.
CustomLoadImage(image_only=False),
# Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
# Ensure the image is in channel-first format.

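Note the switch from `image_only=True` to `image_only=False` in both pipelines: in MONAI's loaders this setting controls whether image metadata is kept alongside the array, and `__save_prediction_masks` below reads `image_meta_dict["filename_or_obj"]` to name output files. A rough sketch with MONAI's stock `LoadImaged` (the `Custom*` transforms are project-specific wrappers, so treating them the same way is an assumption; the file path is hypothetical):

from monai.transforms import LoadImaged

loader = LoadImaged(keys=["image"], image_only=False)
sample = loader({"image": "cells_0001.tif"})  # hypothetical input file
# With image_only=False, MONAI versions that use *_meta_dict keys keep the metadata:
print(sample["image"].shape)
print(sample["image_meta_dict"]["filename_or_obj"])  # -> "cells_0001.tif"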
@@ -84,12 +84,12 @@ class BCE_MSE_Loss(BaseLoss):
# Cell Recognition Loss
cellprob_loss = self.bce_loss(
outputs[:, -self.num_classes:], target[:, self.num_classes:2 * self.num_classes].float()
outputs[:, -self.num_classes:], (target[:, -self.num_classes:] > 0).float()
)
# Cell Distinction Loss
gradflow_loss = 0.5 * self.mse_loss(
outputs[:, :2 * self.num_classes], 5.0 * target[:, 2 * self.num_classes:]
outputs[:, :2 * self.num_classes], 5.0 * target[:, :2 * self.num_classes]
)
# Total loss

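The corrected slicing assumes one channel convention shared by `__compute_flows_from_masks` (targets laid out as [2C flow channels | C instance-label channels], see below) and `ModelV.forward` (outputs laid out as [2C gradflow | C cellprob logits]). A dummy-tensor sketch of the layout the fixed loss indexes into:

import torch

B, C, H, W = 2, 1, 64, 64  # C = num_classes
outputs = torch.randn(B, 3 * C, H, W)  # [2C gradflow | C cellprob logits]
target = torch.randn(B, 3 * C, H, W)   # [2C flows    | C instance labels]

gradflow_pred = outputs[:, :2 * C]  # compared against 5.0 * target[:, :2 * C]
cellprob_pred = outputs[:, -C:]     # compared against (target[:, -C:] > 0).float()
assert gradflow_pred.shape == target[:, :2 * C].shape
assert cellprob_pred.shape == target[:, -C:].shape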
@@ -19,7 +19,7 @@ class ModelVParams(BaseModel):
decoder_channels: List[int] = [1024, 512, 256, 128, 64] # Decoder configuration
decoder_pab_channels: int = 256 # Decoder Pyramid Attention Block channels
in_channels: int = 3 # Number of input channels
out_classes: int = 3 # Number of output classes
out_classes: int = 1 # Number of output classes
def asdict(self):
"""
@@ -38,6 +38,8 @@ class ModelV(MAnet):
def __init__(self, params: ModelVParams) -> None:
# Initialize the MAnet model with provided parameters
super().__init__(**params.asdict())
self.num_classes = params.out_classes
# Remove the default segmentation head as it's not used in this architecture
self.segmentation_head = None
@@ -53,12 +55,6 @@ class ModelV(MAnet):
self.gradflow_head = DeepSegmentationHead(
in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes
)
# self.gradflow_head = nn.ModuleList([
# DeepSegmentationHead(
# in_channels=params.decoder_channels[-1], out_channels=2
# ) for _ in range(params.out_classes)
# ])
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -74,10 +70,6 @@
cellprob_mask = self.cellprob_head(decoder_output)
gradflow_mask = self.gradflow_head(decoder_output)
# gradflow_masks = torch.cat(
# [head(decoder_output) for head in self.flow_heads], dim=1 # [B, 2*C, H, W]
# )
# Concatenate the masks for output
masks = torch.cat((gradflow_mask, cellprob_mask), dim=1)

@@ -1,4 +1,3 @@
import torch
import random
import numpy as np
from numba import njit, prange
@@ -10,23 +9,41 @@ from torch.utils.data import DataLoader
import fastremap
import fill_voids
from skimage import morphology
from skimage.segmentation import find_boundaries
from scipy.ndimage import mean, find_objects
from monai.data.dataset import Dataset
from monai.transforms import * # type: ignore
from monai.inferers.utils import sliding_window_inference
from monai.metrics.cumulative_average import CumulativeAverage
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os
import glob
import copy
import tifffile as tiff
from pprint import pformat
from typing import Optional, Tuple, List, Union
from tabulate import tabulate
from typing import Any, Dict, Literal, Optional, Tuple, List, Union
from tqdm import tqdm
import wandb
from config import Config
from core.models import *
from core.losses import *
from core.optimizers import *
from core.schedulers import *
from core.utils import (
compute_batch_segmentation_tp_fp_fn,
compute_f1_score,
compute_average_precision_score
)
from core.logger import get_logger
@@ -40,6 +57,11 @@ class CellSegmentator:
self.__parse_config(config)
self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu")
self._scaler = (
torch.amp.GradScaler(self._device.type) # type: ignore
if self._dataset_setup.is_training and self._dataset_setup.common.use_amp
else None
)
self._train_dataloader: Optional[DataLoader] = None
self._valid_dataloader: Optional[DataLoader] = None
@@ -214,17 +236,237 @@ class CellSegmentator:
self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False)
logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.")
def print_data_info(
self,
loader_type: Literal["train", "valid", "test", "predict"],
index: Optional[int] = None
) -> None:
"""
Prints statistics for a single sample from the specified dataloader.
Args:
loader_type: One of "train", "valid", "test", "predict".
index: The sample index; if None, a random index is selected.
"""
# Retrieve the dataloader attribute, e.g., self._train_dataloader
loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None)
if loader is None:
logger.error(f"Dataloader '{loader_type}' is not initialized.")
return
dataset = loader.dataset
total = len(dataset) # type: ignore
if total == 0:
logger.error(f"Dataset for '{loader_type}' is empty.")
return
# Choose index
idx = index if index is not None else random.randint(0, total - 1)
if not (0 <= idx < total):
logger.error(f"Index {idx} is out of range [0, {total}).")
return
# Fetch the sample and apply transforms
sample = dataset[idx]
# Expecting a dict with {'image': ..., 'mask': ...} or {'image': ...}
img = sample["image"]
# Convert tensor to numpy if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
img = np.asarray(img)
# Compute image statistics
img_min, img_max = img.min(), img.max()
img_mean, img_std = float(img.mean()), float(img.std())
img_shape = img.shape
# Prepare log lines
lines = [
"=" * 40,
f"Dataloader: '{loader_type}', sample index: {idx} / {total - 1}",
f"Image — shape: {img_shape}, min: {img_min:.4f}, max: {img_max:.4f}, mean: {img_mean:.4f}, std: {img_std:.4f}"
]
# For 'predict', no mask is available
if loader_type != "predict":
mask = sample.get("mask", None)
if mask is not None:
# Convert tensor to numpy if needed
if isinstance(mask, torch.Tensor):
mask = mask.cpu().numpy()
mask = np.asarray(mask)
m_min, m_max = mask.min(), mask.max()
m_mean, m_std = float(mask.mean()), float(mask.std())
m_shape = mask.shape
lines.append(
f"Mask — shape: {m_shape}, min: {m_min:.4f}, "
f"max: {m_max:.4f}, mean: {m_mean:.4f}, std: {m_std:.4f}"
)
else:
lines.append("Mask — not available for this sample.")
lines.append("=" * 40)
# Output via logger
for l in lines:
logger.info(l)
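A quick usage sketch (assuming a constructed `segmentator` instance):

segmentator.print_data_info("train")          # random sample from the training loader
segmentator.print_data_info("test", index=0)  # first sample from the test loader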
def train(self) -> None:
"""
Train the model over multiple epochs, with periodic validation and a final test pass.
"""
# Ensure training is enabled in dataset setup
if not self._dataset_setup.is_training:
raise RuntimeError("Dataset setup indicates training is disabled.")
# Determine device name for logging
if self._device.type == "cpu":
device_name = "cpu"
else:
idx = self._device.index if hasattr(self._device, 'index') else torch.cuda.current_device()
device_name = torch.cuda.get_device_name(idx)
logger.info(f"\n{'=' * 50}\n")
logger.info(f"Training starts on device: {device_name}")
logger.info(f"\n{'=' * 50}")
best_f1_score = 0.0
best_weights = None
for epoch in range(1, self._dataset_setup.training.num_epochs + 1):
train_metrics = self.__run_epoch("train", epoch)
self.__print_with_logging(train_metrics, epoch)
# Step the scheduler after training
if self._scheduler is not None:
self._scheduler.step()
# Periodic validation or tuning
if epoch % self._dataset_setup.training.val_freq == 0:
if self._valid_dataloader is not None:
valid_metrics = self.__run_epoch("valid", epoch)
self.__print_with_logging(valid_metrics, epoch)
# Update best model on improved F1
f1 = valid_metrics.get("valid_f1_score", 0.0)
if f1 > best_f1_score:
best_f1_score = f1
# Deep copy weights to avoid reference issues
best_weights = copy.deepcopy(self._model.state_dict())
logger.info(f"Updated best model weights with F1 score: {f1:.4f}")
# Restore best model weights if available
if best_weights is not None:
self._model.load_state_dict(best_weights)
if self._test_dataloader is not None:
test_metrics = self.__run_epoch("test")
self.__print_with_logging(test_metrics, 0)
def evaluate(self) -> None:
"""
Run a full test epoch and display/log the resulting metrics.
"""
test_metrics = self.__run_epoch("test")
self.__print_with_logging(test_metrics, 0)
def predict(self) -> None:
"""
Run inference on the predict set and save the resulting instance masks.
"""
# Ensure the predict DataLoader has been set
if self._predict_dataloader is None:
raise RuntimeError("DataLoader for mode 'predict' is not set.")
batch_counter = 0
for batch in tqdm(self._predict_dataloader, desc="Predicting"):
# Move input images to the configured device (CPU/GPU)
inputs = batch["img"].to(self._device)
# Use automatic mixed precision if enabled in dataset setup
with torch.amp.autocast( # type: ignore
self._device.type,
enabled=self._dataset_setup.common.use_amp
):
# Disable gradient computation for inference
with torch.no_grad():
# Run the model's forward pass in predict mode
raw_output = self.__run_inference(inputs, mode="predict")
# Convert logits/probabilities to discrete instance masks
# ground_truth is not passed here; only predictions are needed
preds, _ = self.__post_process_predictions(raw_output)
# Save out the predicted masks, using batch_counter to index files
self.__save_prediction_masks(batch, preds, batch_counter)
# Increment counter by batch size for unique file naming
batch_counter += inputs.shape[0]
def run(self) -> None:
"""
Orchestrate the full workflow:
- If training is enabled in the dataset setup, start training.
- Otherwise, if a test DataLoader is provided, run evaluation.
- Else if a prediction DataLoader is provided, run inference/prediction.
- If neither loader is available in non-training mode, raise an error.
"""
# 1) TRAINING PATH
if self._dataset_setup.is_training:
# Launch the full training loop (with validation, scheduler steps, etc.)
self.train()
return
# 2) NON-TRAINING PATH (TEST or PREDICT)
# Prefer test if available
if self._test_dataloader is not None:
# Run a single evaluation epoch on the test set and log metrics
self.evaluate()
return
# If no test loader, fall back to prediction if available
if self._predict_dataloader is not None:
# Run inference on the predict set and save outputs
self.predict()
return
# 3) ERROR: no appropriate loader found
raise RuntimeError(
"Neither test nor predict DataLoader is set for nontraining mode."
)
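Taken together, `run` makes the class a single entry point. A hypothetical driver script (the module path and the config loader are assumptions, not part of this commit):

from config import Config                     # as imported at the top of this file
from core.segmentator import CellSegmentator  # hypothetical module path

config = Config.load("configs/train.yaml")    # hypothetical loader
segmentator = CellSegmentator(config)
segmentator.run()  # trains, or falls back to evaluate/predict per the config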
def load_from_checkpoint(self, checkpoint_path: str) -> None:
"""
Loads model weights from a specified checkpoint into the current model.
Args:
checkpoint_path (str): Path to the checkpoint file containing the model weights.
"""
# Load the checkpoint onto the correct device (CPU or GPU)
checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=True)
# Load the state dict into the model, allowing for missing keys
self._model.load_state_dict(checkpoint['state_dict'], strict=False)
def save_checkpoint(self, checkpoint_path: str) -> None:
"""
Saves the current model weights to a checkpoint file.
Args:
checkpoint_path (str): Path where the checkpoint file will be saved.
"""
# Create a checkpoint dictionary containing the model's state_dict
checkpoint = {
'state_dict': self._model.state_dict()
}
# Write the checkpoint to disk
torch.save(checkpoint, checkpoint_path)
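Since the checkpoint is a plain dict with a single 'state_dict' key, a save/load round trip looks like this (path is illustrative):

segmentator.save_checkpoint("checkpoints/best_model.pth")
# ... later, e.g. in a fresh process with the same model config:
segmentator.load_from_checkpoint("checkpoints/best_model.pth")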
def __parse_config(self, config: Config) -> None:
@@ -255,12 +497,26 @@ class CellSegmentator:
# Initialize model using the model registry
self._model = ModelRegistry.get_model_class(model.name)(model.params)
# Loads model weights from a specified checkpoint
if config.dataset_config.is_training:
if config.dataset_config.training.pretrained_weights:
self.load_from_checkpoint(config.dataset_config.training.pretrained_weights)
# Initialize loss criterion if specified
self._criterion = (
CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params)
if criterion is not None
else None
)
if hasattr(self._criterion, "num_classes"):
nc_model = self._model.num_classes
nc_crit = getattr(self._criterion, "num_classes")
if nc_model != nc_crit:
raise ValueError(
f"Number of classes mismatch: model.num_classes={nc_model} "
f"but criterion.num_classes={nc_crit}"
)
# Initialize optimizer if specified
self._optimizer = (
@@ -298,6 +554,7 @@ class CellSegmentator:
logger.info("[COMMON]")
logger.info(f"├─ Seed: {common.seed}")
logger.info(f"├─ Device: {common.device}")
logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
logger.info(f"└─ Predictions output dir: {common.predictions_dir}")
if config.dataset_config.is_training:
@@ -306,7 +563,6 @@ class CellSegmentator:
logger.info(f"├─ Batch size: {training.batch_size}")
logger.info(f"├─ Epochs: {training.num_epochs}")
logger.info(f"├─ Validation frequency: {training.val_freq}")
logger.info(f"├─ Use AMP: {'yes' if training.use_amp else 'no'}")
logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}")
if training.is_split:
@@ -433,6 +689,624 @@ class CellSegmentator:
return Dataset(data, transforms)
def __print_with_logging(self, results: Dict[str, float], step: int) -> None:
"""
Print metrics in a tabular format and log to W&B.
Args:
results (Dict[str, float]): results dictionary.
step (int): epoch index.
"""
table = tabulate(
tabular_data=results.items(),
headers=["Metric", "Value"],
floatfmt=".4f",
tablefmt="fancy_grid"
)
print(table, "\n")
wandb.log(results, step=step)
def __run_epoch(self,
mode: Literal["train", "valid", "test"],
epoch: Optional[int] = None
) -> Dict[str, float]:
"""
Execute one epoch of training, validation, or testing.
Args:
mode (str): One of 'train', 'valid', or 'test'.
epoch (int, optional): Current epoch number for logging.
Returns:
Dict[str, float]: Loss metrics and F1 score for valid/test.
"""
# Ensure required components are available
if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
raise RuntimeError("Optimizer and loss function must be initialized for train/valid.")
# Set model mode and choose the appropriate data loader
if mode == "train":
self._model.train()
loader = self._train_dataloader
else:
self._model.eval()
loader = self._valid_dataloader if mode == "valid" else self._test_dataloader
# Check that the data loader is provided
if loader is None:
raise RuntimeError(f"DataLoader for mode '{mode}' is not set.")
all_tp, all_fp, all_fn = [], [], []
# Prepare tqdm description with epoch info if available
if epoch is not None:
desc = f"Epoch {epoch}/{self._dataset_setup.training.num_epochs} [{mode}]"
else:
desc = f"Epoch ({mode})"
# Iterate over batches
batch_counter = 0
for batch in tqdm(loader, desc=desc):
inputs = batch["img"].to(self._device)
targets = batch["label"].to(self._device)
# Zero gradients for training
if self._optimizer is not None:
self._optimizer.zero_grad()
# Mixed precision context if enabled
with torch.amp.autocast( # type: ignore
self._device.type,
enabled=self._dataset_setup.common.use_amp
):
# Only compute gradients in training phase
with torch.set_grad_enabled(mode == "train"):
# Forward pass
raw_output = self.__run_inference(inputs, mode)
if self._criterion is not None:
# Convert label masks to flow representations (one-hot)
flow_targets = self.__compute_flows_from_masks(targets)
# Compute loss for this batch
batch_loss = self._criterion(raw_output, flow_targets) # type: ignore
# Post-process and compute F1 during validation and testing
if mode in ("valid", "test"):
preds, labels_post = self.__post_process_predictions(
raw_output, targets
)
# Collecting statistics on the batch
tp, fp, fn = self.__compute_stats(
predicted_masks=preds,
ground_truth_masks=labels_post, # type: ignore
iou_threshold=0.5
)
all_tp.append(tp)
all_fp.append(fp)
all_fn.append(fn)
if mode == "test":
self.__save_prediction_masks(
batch, preds, batch_counter
)
# Backpropagation and optimizer step in training
if mode == "train":
if self._dataset_setup.common.use_amp and self._scaler is not None:
self._scaler.scale(batch_loss).backward() # type: ignore
self._scaler.unscale_(self._optimizer.optim) # type: ignore
self._scaler.step(self._optimizer.optim) # type: ignore
self._scaler.update()
else:
batch_loss.backward() # type: ignore
if self._optimizer is not None:
self._optimizer.step()
batch_counter += inputs.shape[0]
if self._criterion is not None:
# Collect loss metrics
epoch_metrics = {f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()}
# Reset internal loss metrics accumulator
self._criterion.reset_metrics()
else:
epoch_metrics = {}
# Include F1 and mAP for validation and testing
if mode in ("valid", "test"):
# Concatenating by batch: shape (num_batches*B, C)
tp_array = np.vstack(all_tp)
fp_array = np.vstack(all_fp)
fn_array = np.vstack(all_fn)
epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( # type: ignore
tp_array, fp_array, fn_array, reduction="micro"
)
epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( # type: ignore
tp_array, fp_array, fn_array, reduction="macro"
)
return epoch_metrics
def __run_inference(
self,
inputs: torch.Tensor,
mode: Literal["train", "valid", "test", "predict"] = "train"
) -> torch.Tensor:
"""
Perform model inference for different stages.
Args:
inputs (torch.Tensor): Input tensor of shape (B, C, H, W).
mode (Literal[...]): One of "train", "valid", "test", "predict".
Returns:
torch.Tensor: Model outputs tensor.
"""
if mode != "train":
# Use sliding window inference for non-training phases
outputs = sliding_window_inference(
inputs,
roi_size=512,
sw_batch_size=4,
predictor=self._model,
padding_mode="constant",
mode="gaussian",
overlap=0.5,
)
else:
# Direct forward pass during training
outputs = self._model(inputs)
return outputs # type: ignore
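For reference, the MONAI helper can be exercised standalone. Here `mode="gaussian"` sets the blending weights for overlapping windows, `padding_mode` pads inputs smaller than `roi_size`, and `overlap=0.5` halves the window stride (a sketch with a stand-in predictor, not the exact call site):

import torch
from monai.inferers.utils import sliding_window_inference

predictor = torch.nn.Conv2d(3, 3, kernel_size=1)  # stand-in for the trained model
image = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    out = sliding_window_inference(
        image, roi_size=512, sw_batch_size=4, predictor=predictor,
        overlap=0.5, mode="gaussian", padding_mode="constant",
    )
print(out.shape)  # torch.Size([1, 3, 1024, 1024])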
def __post_process_predictions(
self,
raw_outputs: torch.Tensor,
ground_truth: Optional[torch.Tensor] = None
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Post-process raw network outputs to extract instance segmentation masks.
Args:
raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W).
ground_truth (torch.Tensor, optional): Ground truth masks of shape (B, C, H, W).
Returns:
Tuple[np.ndarray, Optional[np.ndarray]]:
- instance_masks: Instance-wise masks array of shape (B, C, H, W).
- labels_np: Converted ground truth of shape (B, C, H, W) or None if
ground_truth was not provided.
"""
# Move outputs to CPU and convert to numpy
outputs_np = raw_outputs.cpu().numpy()
# Split channels: gradient flows then class logits
gradflow = outputs_np[:, :2 * self._model.num_classes]
logits = outputs_np[:, -self._model.num_classes :]
# Apply sigmoid to logits to get probabilities
probabilities = self.__sigmoid(logits)
batch_size, _, height, width = probabilities.shape
# Prepare container for instance masks
instance_masks = np.zeros((batch_size, self._model.num_classes, height, width), dtype=np.uint16)
for idx in range(batch_size):
instance_masks[idx] = self.__segment_instances(
probability_map=probabilities[idx],
flow=gradflow[idx],
prob_threshold=0.0,
flow_threshold=0.4,
min_object_size=15
)
# Convert ground truth to numpy
labels_np = ground_truth.cpu().numpy() if ground_truth is not None else None
return instance_masks, labels_np
def __compute_stats(
self,
predicted_masks: np.ndarray,
ground_truth_masks: np.ndarray,
iou_threshold: float = 0.5
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Compute batch-wise true positives, false positives, and false negatives
for instance segmentation, using a configurable IoU threshold.
Args:
predicted_masks (np.ndarray): Predicted instance masks of shape (B, C, H, W).
ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W).
iou_threshold (float): Intersection-over-Union threshold for matching predictions
to ground truths (default: 0.5).
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]:
- tp: True positives per batch and class, shape (B, C)
- fp: False positives per batch and class, shape (B, C)
- fn: False negatives per batch and class, shape (B, C)
"""
stats = compute_batch_segmentation_tp_fp_fn(
batch_ground_truth=ground_truth_masks,
batch_prediction=predicted_masks,
iou_threshold=iou_threshold,
remove_boundary_objects=True
)
tp = stats["tp"]
fp = stats["fp"]
fn = stats["fn"]
return tp, fp, fn
def __compute_f1_metric(
self,
true_positives: np.ndarray,
false_positives: np.ndarray,
false_negatives: np.ndarray,
reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
) -> Union[float, np.ndarray]:
"""
Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes.
Args:
true_positives: array of TP counts per sample and class.
false_positives: array of FP counts per sample and class.
false_negatives: array of FN counts per sample and class.
reduction:
- 'none': return F1 for each (sample, class) pair, shape (batch_size, num_classes)
- 'micro': global F1 over all samples & classes
- 'imagewise': F1 per sample (summing over classes), then average over samples
- 'macro': average class-wise F1 (classes summed over batch)
- 'weighted': class-wise F1 weighted by support (TP+FN)
Returns:
float for reductions 'micro', 'imagewise', 'macro', 'weighted';
or np.ndarray of shape (batch_size, num_classes) if reduction='none'.
"""
batch_size, num_classes = true_positives.shape
# 1) No reduction: compute F1 for each (sample, class)
if reduction == "none":
f1_matrix = np.zeros((batch_size, num_classes), dtype=float)
for i in range(batch_size):
for c in range(num_classes):
tp_val = int(true_positives[i, c])
fp_val = int(false_positives[i, c])
fn_val = int(false_negatives[i, c])
_, _, f1_val = compute_f1_score(tp_val, fp_val, fn_val)
f1_matrix[i, c] = f1_val
return f1_matrix
# 2) Micro: sum all TP/FP/FN and compute a single F1
if reduction == "micro":
tp_total = int(true_positives.sum())
fp_total = int(false_positives.sum())
fn_total = int(false_negatives.sum())
_, _, f1_global = compute_f1_score(tp_total, fp_total, fn_total)
return f1_global
# 3) Imagewise: compute per-sample F1 (sum over classes), then average
if reduction == "imagewise":
f1_per_image = np.zeros(batch_size, dtype=float)
for i in range(batch_size):
tp_i = int(true_positives[i].sum())
fp_i = int(false_positives[i].sum())
fn_i = int(false_negatives[i].sum())
_, _, f1_i = compute_f1_score(tp_i, fp_i, fn_i)
f1_per_image[i] = f1_i
return float(f1_per_image.mean())
# For macro/weighted, first aggregate per class across the batch
tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,)
fp_per_class = false_positives.sum(axis=0).astype(int)
fn_per_class = false_negatives.sum(axis=0).astype(int)
# 4) Macro: average F1 across classes equally
if reduction == "macro":
f1_per_class = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
_, _, f1_c = compute_f1_score(
tp_per_class[c],
fp_per_class[c],
fn_per_class[c]
)
f1_per_class[c] = f1_c
return float(f1_per_class.mean())
# 5) Weighted: class-wise F1 weighted by support = TP + FN
if reduction == "weighted":
f1_per_class = np.zeros(num_classes, dtype=float)
support = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
tp_c = tp_per_class[c]
fp_c = fp_per_class[c]
fn_c = fn_per_class[c]
_, _, f1_c = compute_f1_score(tp_c, fp_c, fn_c)
f1_per_class[c] = f1_c
support[c] = tp_c + fn_c
total_support = support.sum()
if total_support == 0:
# fallback to unweighted macro if no positives
return float(f1_per_class.mean())
weights = support / total_support
return float((f1_per_class * weights).sum())
raise ValueError(f"Unknown reduction mode: {reduction}")
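A tiny worked example of how the reductions differ, using the plain F1 formula rather than the project helper:

import numpy as np

tp = np.array([[8, 0], [4, 2]])  # shape (batch=2, classes=2)
fp = np.array([[2, 1], [0, 1]])
fn = np.array([[2, 3], [2, 0]])

def f1(tp, fp, fn):
    return 2 * tp / max(2 * tp + fp + fn, 1)

micro = f1(tp.sum(), fp.sum(), fn.sum())  # pool all counts: 28 / 39 ≈ 0.718
macro = np.mean([f1(t, p, n) for t, p, n in zip(tp.sum(0), fp.sum(0), fn.sum(0))])
print(f"micro={micro:.3f}, macro={macro:.3f}")  # macro averages per-class F1: ≈ 0.622

Micro weights every instance equally, so the abundant class dominates; macro weights every class equally, so rare classes pull the score down.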
def __compute_average_precision_metric(
self,
true_positives: np.ndarray,
false_positives: np.ndarray,
false_negatives: np.ndarray,
reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
) -> Union[float, np.ndarray]:
"""
Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes.
AP is defined here as:
AP = TP / (TP + FP + FN)
Args:
true_positives: array of true positives per sample and class.
false_positives: array of false positives per sample and class.
false_negatives: array of false negatives per sample and class.
reduction:
- 'none': return AP for each (sample, class) pair, shape (batch_size, num_classes)
- 'micro': global AP over all samples & classes
- 'imagewise': AP per sample (summing stats over classes), then average over batch
- 'macro': average class-wise AP (each class summed over batch)
- 'weighted': class-wise AP weighted by support (TP+FN)
Returns:
float for reductions 'micro', 'imagewise', 'macro', 'weighted';
or np.ndarray of shape (batch_size, num_classes) if reduction='none'.
"""
batch_size, num_classes = true_positives.shape
# 1) No reduction: AP per (sample, class)
if reduction == "none":
ap_matrix = np.zeros((batch_size, num_classes), dtype=float)
for i in range(batch_size):
for c in range(num_classes):
tp_val = int(true_positives[i, c])
fp_val = int(false_positives[i, c])
fn_val = int(false_negatives[i, c])
ap_val = compute_average_precision_score(tp_val, fp_val, fn_val)
ap_matrix[i, c] = ap_val
return ap_matrix
# 2) Micro: sum all TP/FP/FN and compute one AP
if reduction == "micro":
tp_total = int(true_positives.sum())
fp_total = int(false_positives.sum())
fn_total = int(false_negatives.sum())
return compute_average_precision_score(tp_total, fp_total, fn_total)
# 3) Imagewise: compute per-sample AP (sum over classes), then mean
if reduction == "imagewise":
ap_per_image = np.zeros(batch_size, dtype=float)
for i in range(batch_size):
tp_i = int(true_positives[i].sum())
fp_i = int(false_positives[i].sum())
fn_i = int(false_negatives[i].sum())
ap_per_image[i] = compute_average_precision_score(tp_i, fp_i, fn_i)
return float(ap_per_image.mean())
# For macro and weighted: first aggregate per class across batch
tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,)
fp_per_class = false_positives.sum(axis=0).astype(int)
fn_per_class = false_negatives.sum(axis=0).astype(int)
# 4) Macro: average AP across classes equally
if reduction == "macro":
ap_per_class = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
ap_per_class[c] = compute_average_precision_score(
tp_per_class[c],
fp_per_class[c],
fn_per_class[c]
)
return float(ap_per_class.mean())
# 5) Weighted: class-wise AP weighted by support = TP + FN
if reduction == "weighted":
ap_per_class = np.zeros(num_classes, dtype=float)
support = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
tp_c = tp_per_class[c]
fp_c = fp_per_class[c]
fn_c = fn_per_class[c]
ap_per_class[c] = compute_average_precision_score(tp_c, fp_c, fn_c)
support[c] = tp_c + fn_c
total_support = support.sum()
if total_support == 0:
# fallback to unweighted macro if no positive instances
return float(ap_per_class.mean())
weights = support / total_support
return float((ap_per_class * weights).sum())
raise ValueError(f"Unknown reduction mode: {reduction}")
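For example, TP=8, FP=2, FN=2 gives AP = 8 / (8 + 2 + 2) ≈ 0.667. This counts-based score (the form commonly used in cell-segmentation benchmarks) is not the area under a precision-recall curve, despite the name.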
@staticmethod
def __sigmoid(z: np.ndarray) -> np.ndarray:
"""
Numerically stable sigmoid activation for numpy arrays.
Args:
z (np.ndarray): Input array.
Returns:
np.ndarray: Sigmoid of the input.
"""
# Piecewise form so np.exp never overflows (and never warns) for large |z|.
positive = z >= 0
result = np.empty_like(z, dtype=np.float64)
result[positive] = 1.0 / (1.0 + np.exp(-z[positive]))
exp_z = np.exp(z[~positive])
result[~positive] = exp_z / (1.0 + exp_z)
return result
def __save_prediction_masks(
self,
sample: Dict[str, Any],
predicted_mask: Union[np.ndarray, torch.Tensor],
start_index: int = 0,
) -> None:
"""
Save multi-channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders.
Args:
sample (Dict[str, Any]): Batch sample from MONAI LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
start_index (int): Starting index for naming when metadata is missing.
"""
# Determine base paths
base_output_dir = self._dataset_setup.common.predictions_dir
masks_dir = base_output_dir
plots_dir = os.path.join(base_output_dir, "plots")
os.makedirs(masks_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)
# Extract image (C, H, W) or batch of images (B, C, H, W), and metadata
image_obj = sample.get("image") # Expected shape: (C, H, W) or (B, C, H, W)
mask_obj = sample.get("mask") # Expected shape: (C, H, W) or (B, C, H, W)
image_meta = sample.get("image_meta_dict")
# Convert tensors to numpy
def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
if isinstance(x, torch.Tensor):
return x.detach().cpu().numpy()
return x
image_array = to_numpy(image_obj) if image_obj is not None else None
mask_array = to_numpy(mask_obj) if mask_obj is not None else None
pred_array = to_numpy(predicted_mask)
# Handle batch dimension: (B, C, H, W)
if pred_array.ndim == 4:
for idx in range(pred_array.shape[0]):
batch_sample = dict(sample)
if image_array is not None and image_array.ndim == 4:
batch_sample["image"] = image_array[idx]
if isinstance(image_meta, list):
batch_sample["image_meta_dict"] = image_meta[idx]
if mask_array is not None and mask_array.ndim == 4:
batch_sample["mask"] = mask_array[idx]
self.__save_prediction_masks(
batch_sample,
pred_array[idx],
start_index=start_index+idx
)
return
# Determine base filename
if image_meta and "filename_or_obj" in image_meta:
base_name = os.path.splitext(os.path.basename(image_meta["filename_or_obj"]))[0]
else:
# Use provided start_index when metadata missing
base_name = f"prediction_{start_index:04d}"
# Now pred_array shape is (C, H, W)
num_channels = pred_array.shape[0]
for channel_idx in range(num_channels):
channel_mask = pred_array[channel_idx]
# File names
mask_filename = f"{base_name}_ch{channel_idx:02d}.tif"
plot_filename = f"{base_name}_ch{channel_idx:02d}.png"
mask_path = os.path.join(masks_dir, mask_filename)
plot_path = os.path.join(plots_dir, plot_filename)
# Save mask TIFF (16-bit)
tiff.imwrite(mask_path, channel_mask.astype(np.uint16), compression="zlib")
# Extract corresponding true mask channel if exists
true_mask = None
if mask_array is not None and mask_array.ndim == 3:
true_mask = mask_array[channel_idx]
# Generate and save visualization
self.__plot_mask(
file_path=plot_path,
image_data=image_array, # type: ignore
predicted_mask=channel_mask,
true_mask=true_mask,
)
def __plot_mask(
self,
file_path: str,
image_data: np.ndarray,
predicted_mask: np.ndarray,
true_mask: Optional[np.ndarray] = None,
) -> None:
"""
Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided.
"""
img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data
if true_mask is None:
fig, axs = plt.subplots(1, 3, figsize=(15,5))
plt.subplots_adjust(wspace=0.02, hspace=0)
self.__plot_panels(axs, img, predicted_mask, 'red',
('Original Image','Predicted Mask','Predicted Contours'))
else:
fig, axs = plt.subplots(2,3,figsize=(15,10))
plt.subplots_adjust(wspace=0.02, hspace=0.1)
# row 0: predicted
self.__plot_panels(axs[0], img, predicted_mask, 'red',
('Original Image','Predicted Mask','Predicted Contours'))
# row 1: true
self.__plot_panels(axs[1], img, true_mask, 'blue',
('Original Image','True Mask','True Contours'))
fig.savefig(file_path, bbox_inches='tight', dpi=300)
plt.close(fig)
def __plot_panels(
self,
axes,
img: np.ndarray,
mask: np.ndarray,
contour_color: str,
titles: Tuple[str, ...]
):
"""
Plot a row of three panels: original image, mask, and mask boundaries on image.
Args:
axes: list/array of three Axis objects.
img: Image array (H, W or H, W, C).
mask: Label mask (H, W).
contour_color: Color for boundaries.
titles: (title_img, title_mask, title_contours).
"""
# Panel 1: Original image
ax0, ax1, ax2 = axes
ax0.imshow(img, cmap='gray' if img.ndim == 2 else None)
ax0.set_title(titles[0])
ax0.axis('off')
# Compute boundaries once
boundaries = find_boundaries(mask, mode='thick')
# Panel 2: Mask with black boundaries
cmap = plt.get_cmap("gist_ncar")
cmap = mcolors.ListedColormap([
cmap(i/len(np.unique(mask))) for i in range(len(np.unique(mask)))
])
ax1.imshow(mask, cmap=cmap)
ax1.contour(boundaries, colors='black', linewidths=0.5)
ax1.set_title(titles[1])
ax1.axis('off')
# Panel 3: Original image with contours overlaid in the given color
ax2.imshow(img)
# Draw boundaries as contour lines
ax2.contour(boundaries, colors=contour_color, linewidths=0.5)
ax2.set_title(titles[2])
ax2.axis('off')
def __compute_flows_from_masks(
self,
@@ -445,9 +1319,8 @@ class CellSegmentator:
true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks.
Returns:
numpy array of concatenated [renumbered_true_masks, binary_masks, flow_vectors] per image.
renumbered_true_masks is labels, binary_masks is cell distance transform, flow_vectors[2] is Y flow, flows[k][3] is X flow,
and flow_vectors[4] is heat distribution.
numpy array of concatenated [flow_vectors, renumbered_true_masks] per image.
renumbered_true_masks is labels, flow_vectors[0] is Y flow, flow_vectors[1] is X flow.
"""
# Move to CPU numpy
_true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16)
@@ -467,7 +1340,7 @@ class CellSegmentator:
flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i])
for i in range(batch_size)])
return np.concatenate((renumbered, renumbered > 0.5, flow_vectors), axis=1).astype(np.float32)
return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32)
def __compute_flow_from_mask(
@@ -879,7 +1752,7 @@ class CellSegmentator:
prob_threshold: float = 0.0,
flow_threshold: float = 0.4,
num_iters: int = 200,
min_object_size: int = 0
min_object_size: int = 15
) -> np.ndarray:
"""
Generate instance segmentation masks from probability and flow fields.
@@ -890,7 +1763,7 @@ class CellSegmentator:
prob_threshold: threshold to binarize probability_map. (Default 0.0)
flow_threshold: threshold for filtering bad flow masks. (Default 0.4)
num_iters: number of iterations for flow-following. (Default 200)
min_object_size: minimum area to keep small instances. (Default 0)
min_object_size: minimum area to keep small instances. (Default 15)
Returns:
3D array of uint16 instance labels for each channel.

@@ -13,12 +13,14 @@ from typing import Dict, List, Tuple, Any, Union
__all__ = [
"compute_batch_segmentation_f1_metrics", "compute_batch_segmentation_average_precision_metrics",
"compute_batch_segmentation_tp_fp_fn",
"compute_segmentation_f1_metrics", "compute_segmentation_average_precision_metrics",
"compute_confusion_matrix", "compute_f1_scores", "compute_average_precision_score"
"compute_segmentation_tp_fp_fn",
"compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score"
]
def compute_f1_scores(
def compute_f1_score(
true_positives: int,
false_positives: int,
false_negatives: int
@@ -103,7 +105,7 @@ def compute_confusion_matrix(
return true_positive_count, false_positive_count, false_negative_count
def compute_segmentation_f1_metrics(
def compute_segmentation_tp_fp_fn(
ground_truth_mask: np.ndarray,
predicted_mask: np.ndarray,
iou_threshold: float = 0.5,
@@ -111,36 +113,32 @@ def compute_segmentation_f1_metrics(
remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
"""
Computes F1 metrics (precision, recall, F1-score) for segmentation on a single image.
Computes TP, FP and FN for segmentation on a single image.
If the input masks have shape (H, W), they are expanded to (H, W, 1).
For multi-channel inputs (H, W, C), each channel is processed independently, and the returned
metrics (precision, recall, f1_score, TP, FP, FN) are provided as NumPy arrays with shape (C,).
If the input masks have shape (H, W), they are expanded to (1, H, W).
For multi-channel inputs (C, H, W), each channel is processed independently, and the returned
metrics (TP, FP, FN) are provided as NumPy arrays with shape (C,).
Optionally, if return_error_masks is True, binary error masks for true positives, false positives,
and false negatives are also returned with shape (H, W, C).
and false negatives are also returned with shape (C, H, W).
Args:
ground_truth_mask: Ground truth segmentation mask (HxW or HxWxC).
predicted_mask: Predicted segmentation mask (HxW or HxWxC).
ground_truth_mask: Ground truth segmentation mask (HxW or CxHxW).
predicted_mask: Predicted segmentation mask (HxW or CxHxW).
iou_threshold: IoU threshold for matching objects.
return_error_masks: Whether to also return binary error masks.
remove_boundary_objects: Whether to remove objects that touch the image boundary.
Returns:
A dictionary with the following keys:
- 'precision', 'recall', 'f1_score': arrays of shape (C,)
- 'tp', 'fp', 'fn': arrays of shape (C,)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (H, W, C)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (C, H, W)
"""
# If the masks are 2D, add a singleton channel dimension.
ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=-1)
predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=-1)
ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=0)
predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=0)
num_channels = ground_truth_mask.shape[-1]
precision_list = []
recall_list = []
f1_score_list = []
num_channels = ground_truth_mask.shape[0]
true_positive_list = []
false_positive_list = []
false_negative_list = []
@@ -151,8 +149,8 @@ def compute_segmentation_f1_metrics(
# Process each channel independently.
for channel in range(num_channels):
channel_ground_truth = ground_truth_mask[..., channel]
channel_prediction = predicted_mask[..., channel]
channel_ground_truth = ground_truth_mask[channel, ...]
channel_prediction = predicted_mask[channel, ...]
if np.prod(channel_ground_truth.shape) < (5000 * 5000):
results = _process_instance_matching(
channel_ground_truth, channel_prediction, iou_threshold,
@@ -166,12 +164,7 @@ def compute_segmentation_f1_metrics(
tp = results['tp']
fp = results['fp']
fn = results['fn']
precision, recall, f1_score = compute_f1_scores(
tp, fp, fn # type: ignore
)
precision_list.append(precision)
recall_list.append(recall)
f1_score_list.append(f1_score)
true_positive_list.append(tp)
false_positive_list.append(fp)
false_negative_list.append(fn)
@@ -181,17 +174,75 @@ def compute_segmentation_f1_metrics(
false_negative_mask_list.append(results.get('fn_mask')) # type: ignore
output: Dict[str, np.ndarray] = {
'precision': np.array(precision_list),
'recall': np.array(recall_list),
'f1_score': np.array(f1_score_list),
'tp': np.array(true_positive_list),
'fp': np.array(false_positive_list),
'fn': np.array(false_negative_list)
}
if return_error_masks:
output['tp_mask'] = np.stack(true_positive_mask_list, axis=-1) # type: ignore
output['fp_mask'] = np.stack(false_positive_mask_list, axis=-1) # type: ignore
output['fn_mask'] = np.stack(false_negative_mask_list, axis=-1) # type: ignore
output['tp_mask'] = np.stack(true_positive_mask_list, axis=0) # type: ignore
output['fp_mask'] = np.stack(false_positive_mask_list, axis=0) # type: ignore
output['fn_mask'] = np.stack(false_negative_mask_list, axis=0) # type: ignore
return output
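A minimal check on synthetic labels (the function accepts (H, W) or (C, H, W) instance masks; the import path is assumed to match the trainer's `core.utils` imports above):

import numpy as np
from core.utils import compute_segmentation_tp_fp_fn  # assumed module path

gt = np.zeros((32, 32), dtype=np.uint16)
gt[2:10, 2:10] = 1    # one ground-truth instance, away from the border
pred = np.zeros_like(gt)
pred[3:11, 3:11] = 1  # overlapping prediction, IoU ≈ 0.62 > 0.5
pred[20:24, 20:24] = 2  # spurious extra instance

stats = compute_segmentation_tp_fp_fn(gt, pred, iou_threshold=0.5)
print(stats["tp"], stats["fp"], stats["fn"])  # expected: [1] [1] [0]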
def compute_segmentation_f1_metrics(
ground_truth_mask: np.ndarray,
predicted_mask: np.ndarray,
iou_threshold: float = 0.5,
return_error_masks: bool = False,
remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
"""
Computes F1 metrics (precision, recall, F1-score) for segmentation on a single image.
If the input masks have shape (H, W), they are expanded to (1, H, W).
For multi-channel inputs (C, H, W), each channel is processed independently, and the returned
metrics (precision, recall, f1_score, TP, FP, FN) are provided as NumPy arrays with shape (C,).
Optionally, if return_error_masks is True, binary error masks for true positives, false positives,
and false negatives are also returned with shape (C, H, W).
Args:
ground_truth_mask: Ground truth segmentation mask (HxW or CxHxW).
predicted_mask: Predicted segmentation mask (HxW or CxHxW).
iou_threshold: IoU threshold for matching objects.
return_error_masks: Whether to also return binary error masks.
remove_boundary_objects: Whether to remove objects that touch the image boundary.
Returns:
A dictionary with the following keys:
- 'precision', 'recall', 'f1_score': arrays of shape (C,)
- 'tp', 'fp', 'fn': arrays of shape (C,)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (C, H, W)
"""
# (H, W) inputs are expanded to (1, H, W) inside compute_segmentation_tp_fp_fn.
num_channels = ground_truth_mask.shape[0] if ground_truth_mask.ndim == 3 else 1
precision_list = []
recall_list = []
f1_score_list = []
results = compute_segmentation_tp_fp_fn(
ground_truth_mask, predicted_mask,
iou_threshold, return_error_masks,
remove_boundary_objects
)
# Process each channel independently.
for channel in range(num_channels):
tp = results['tp'][channel]
fp = results['fp'][channel]
fn = results['fn'][channel]
precision, recall, f1_score = compute_f1_score(
tp, fp, fn # type: ignore
)
precision_list.append(precision)
recall_list.append(recall)
f1_score_list.append(f1_score)
output: Dict[str, np.ndarray] = {
'precision': np.array(precision_list),
'recall': np.array(recall_list),
'f1_score': np.array(f1_score_list),
}
output.update(results)
return output
@@ -205,16 +256,16 @@ def compute_segmentation_average_precision_metrics(
"""
Computes the average precision (AP) for segmentation on a single image.
If the input masks have shape (H, W), they are expanded to (H, W, 1).
For multi-channel inputs (H, W, C), each channel is processed independently and the returned
If the input masks have shape (H, W), they are expanded to (1, H, W).
For multi-channel inputs (C, H, W), each channel is processed independently and the returned
metrics (average precision, TP, FP, FN) are provided as NumPy arrays with shape (C,).
Optionally, if return_error_masks is True, binary error masks for true positives, false positives,
and false negatives are also returned with shape (H, W, C).
and false negatives are also returned with shape (C, H, W).
Args:
ground_truth_mask: Ground truth segmentation mask (HxW or HxWxC).
predicted_mask: Predicted segmentation mask (HxW or HxWxC).
ground_truth_mask: Ground truth segmentation mask (HxW or CxHxW).
predicted_mask: Predicted segmentation mask (HxW or CxHxW).
iou_threshold: IoU threshold for matching objects.
return_error_masks: Whether to also return binary error masks.
remove_boundary_objects: Whether to remove objects that touch the image boundary.
@@ -223,62 +274,99 @@ def compute_segmentation_average_precision_metrics(
A dictionary with the following keys:
- 'avg_precision': array of shape (C,)
- 'tp', 'fp', 'fn': arrays of shape (C,)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (H, W, C)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (C, H, W)
"""
ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=-1)
predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=-1)
num_channels = ground_truth_mask.shape[-1]
num_channels = ground_truth_mask.shape[0] if ground_truth_mask.ndim == 3 else 1
avg_precision_list = []
true_positive_list = []
false_positive_list = []
false_negative_list = []
if return_error_masks:
true_positive_mask_list = []
false_positive_mask_list = []
false_negative_mask_list = []
results = compute_segmentation_tp_fp_fn(
ground_truth_mask, predicted_mask,
iou_threshold, return_error_masks,
remove_boundary_objects
)
# Process each channel independently.
for channel in range(num_channels):
channel_ground_truth = ground_truth_mask[..., channel]
channel_prediction = predicted_mask[..., channel]
if np.prod(channel_ground_truth.shape) < (5000 * 5000):
results = _process_instance_matching(
channel_ground_truth, channel_prediction,
iou_threshold,
return_masks=return_error_masks, without_boundary_objects=remove_boundary_objects
)
else:
results = _compute_patch_based_metrics(
channel_ground_truth, channel_prediction,
iou_threshold,
return_masks=return_error_masks, without_boundary_objects=remove_boundary_objects
)
tp = results['tp']
fp = results['fp']
fn = results['fn']
tp = results['tp'][channel]
fp = results['fp'][channel]
fn = results['fn'][channel]
avg_precision = compute_average_precision_score(
tp, fp, fn # type: ignore
)
avg_precision_list.append(avg_precision)
true_positive_list.append(tp)
false_positive_list.append(fp)
false_negative_list.append(fn)
output: Dict[str, np.ndarray] = {
'avg_precision': np.array(avg_precision_list)
}
output.update(results)
return output
def compute_batch_segmentation_tp_fp_fn(
batch_ground_truth: np.ndarray,
batch_prediction: np.ndarray,
iou_threshold: float = 0.5,
return_error_masks: bool = False,
remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
"""
Computes segmentation TP, FP and FN for a batch of images.
Expects inputs with shape (B, C, H, W). For each image in the batch, the data is extracted
into (C, H, W) and then processed using compute_segmentation_tp_fp_fn. The results are stacked
so that each metric has shape (B, C). If error masks are returned, their shape will be (B, C, H, W).
Args:
batch_ground_truth: Batch of ground truth masks (BxCxHxW).
batch_prediction: Batch of predicted masks (BxCxHxW).
iou_threshold: IoU threshold for matching objects.
return_error_masks: Whether to also return binary error masks.
remove_boundary_objects: Whether to remove objects that touch the image boundary.
Returns:
A dictionary with keys:
- 'tp', 'fp', 'fn': arrays of shape (B, C)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, C, H, W)
"""
batch_ground_truth = _ensure_ndim(batch_ground_truth, 4, insert_position=0)
batch_prediction = _ensure_ndim(batch_prediction, 4, insert_position=0)
batch_size = batch_ground_truth.shape[0]
tp_list = []
fp_list = []
fn_list = []
if return_error_masks:
tp_mask_list = []
fp_mask_list = []
fn_mask_list = []
for i in range(batch_size):
image_ground_truth = batch_ground_truth[i]
image_prediction = batch_prediction[i]
result = compute_segmentation_tp_fp_fn(
image_ground_truth,
image_prediction,
iou_threshold,
return_error_masks,
remove_boundary_objects
)
tp_list.append(result['tp'])
fp_list.append(result['fp'])
fn_list.append(result['fn'])
if return_error_masks:
true_positive_mask_list.append(results.get('tp_mask')) # type: ignore
false_positive_mask_list.append(results.get('fp_mask')) # type: ignore
false_negative_mask_list.append(results.get('fn_mask')) # type: ignore
tp_mask_list.append(result.get('tp_mask')) # type: ignore
fp_mask_list.append(result.get('fp_mask')) # type: ignore
fn_mask_list.append(result.get('fn_mask')) # type: ignore
output: Dict[str, np.ndarray] = {
'avg_precision': np.array(avg_precision_list),
'tp': np.array(true_positive_list),
'fp': np.array(false_positive_list),
'fn': np.array(false_negative_list)
'tp': np.stack(tp_list, axis=0),
'fp': np.stack(fp_list, axis=0),
'fn': np.stack(fn_list, axis=0)
}
if return_error_masks:
output['tp_mask'] = np.stack(true_positive_mask_list, axis=-1) # type: ignore
output['fp_mask'] = np.stack(false_positive_mask_list, axis=-1) # type: ignore
output['fn_mask'] = np.stack(false_negative_mask_list, axis=-1) # type: ignore
output['tp_mask'] = np.stack(tp_mask_list, axis=0) # type: ignore
output['fp_mask'] = np.stack(fp_mask_list, axis=0) # type: ignore
output['fn_mask'] = np.stack(fn_mask_list, axis=0) # type: ignore
return output
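And the batch variant consumes (B, C, H, W) and returns (B, C) count arrays (same assumed import path):

import numpy as np
from core.utils import compute_batch_segmentation_tp_fp_fn  # assumed module path

gt = np.zeros((2, 1, 32, 32), dtype=np.uint16)  # (B, C, H, W)
gt[:, 0, 2:10, 2:10] = 1
pred = gt.copy()  # perfect predictions

stats = compute_batch_segmentation_tp_fp_fn(gt, pred, iou_threshold=0.5)
print(stats["tp"].shape)  # (2, 1); every entry is 1 here, with zero FP/FN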
@@ -292,9 +380,9 @@ def compute_batch_segmentation_f1_metrics(
"""
Computes segmentation F1 metrics for a batch of images.
Expects inputs with shape (B, C, H, W). For each image in the batch, the data is transposed
to (H, W, C) and then processed with compute_segmentation_f1_metrics. The results are stacked
so that each metric has shape (B, C). If error masks are returned, their shape will be (B, H, W, C).
Expects inputs with shape (B, C, H, W). For each image in the batch, the data is extracted
into (C, H, W) and then processed using compute_segmentation_f1_metrics. The results are stacked
so that each metric has shape (B, C). If error masks are returned, their shape will be (B, C, H, W).
Args:
batch_ground_truth: Batch of ground truth masks (BxCxHxW).
@@ -306,10 +394,10 @@ def compute_batch_segmentation_f1_metrics(
Returns:
A dictionary with keys:
- 'precision', 'recall', 'f1_score', 'tp', 'fp', 'fn': arrays of shape (B, C)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, H, W, C)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, C, H, W)
"""
batch_ground_truth = _ensure_ndim(batch_ground_truth, 4)
batch_prediction = _ensure_ndim(batch_prediction, 4)
batch_ground_truth = _ensure_ndim(batch_ground_truth, 4, insert_position=0)
batch_prediction = _ensure_ndim(batch_prediction, 4, insert_position=0)
batch_size = batch_ground_truth.shape[0]
precision_list = []
@@ -324,9 +412,8 @@ def compute_batch_segmentation_f1_metrics(
fn_mask_list = []
for i in range(batch_size):
# Each image is expected to have shape (C, H, W)
image_ground_truth = np.transpose(batch_ground_truth[i], (1, 2, 0))
image_prediction = np.transpose(batch_prediction[i], (1, 2, 0))
image_ground_truth = batch_ground_truth[i]
image_prediction = batch_prediction[i]
result = compute_segmentation_f1_metrics(
image_ground_truth,
image_prediction,
@@ -370,9 +457,9 @@ def compute_batch_segmentation_average_precision_metrics(
"""
Computes segmentation average precision metrics for a batch of images.
Expects inputs with shape (B, C, H, W). For each image in the batch, the data is transposed
to (H, W, C) and then processed with compute_segmentation_average_precision_metrics. The results are stacked
so that each metric has shape (B, C). If error masks are returned, their shape will be (B, H, W, C).
Expects inputs with shape (B, C, H, W). For each image in the batch, the data is extracted
into (C, H, W) and then processed using compute_segmentation_average_precision_metrics. The results are stacked
so that each metric has shape (B, C). If error masks are returned, their shape will be (B, C, H, W).
Args:
batch_ground_truth: Batch of ground truth masks (BxCxHxW).
@@ -384,10 +471,10 @@ def compute_batch_segmentation_average_precision_metrics(
Returns:
A dictionary with keys:
- 'avg_precision', 'tp', 'fp', 'fn': arrays of shape (B, C)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, H, W, C)
- If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, C, H, W)
"""
batch_ground_truth = _ensure_ndim(batch_ground_truth, 4)
batch_prediction = _ensure_ndim(batch_prediction, 4)
batch_ground_truth = _ensure_ndim(batch_ground_truth, 4, insert_position=0)
batch_prediction = _ensure_ndim(batch_prediction, 4, insert_position=0)
batch_size = batch_ground_truth.shape[0]
avg_precision_list = []
@@ -400,10 +487,14 @@ def compute_batch_segmentation_average_precision_metrics(
fn_mask_list = []
for i in range(batch_size):
ground_truth_mask = np.transpose(batch_ground_truth[i], (1, 2, 0))
prediction_mask = np.transpose(batch_prediction[i], (1, 2, 0))
ground_truth_mask = batch_ground_truth[i]
prediction_mask = batch_prediction[i]
result = compute_segmentation_average_precision_metrics(
ground_truth_mask, prediction_mask, iou_threshold, return_error_masks, remove_boundary_objects
ground_truth_mask,
prediction_mask,
iou_threshold,
return_error_masks,
remove_boundary_objects
)
avg_precision_list.append(result['avg_precision'])
tp_list.append(result['tp'])
