From 5d984dc7a99942d67b7e897b0b2c503639b0d399 Mon Sep 17 00:00:00 2001
From: laynholt
Date: Tue, 6 May 2025 02:30:48 +0000
Subject: [PATCH] Implement training, testing and prediction methods

---
 config/dataset_config.py         |   2 +-
 core/data/transforms/__init__.py |   8 +-
 core/losses/mse_with_bce.py      |   4 +-
 core/models/model_v.py           |  14 +-
 core/segmentator.py              | 899 ++++++++++++++++++++++++++++++-
 core/utils/measures.py           | 285 ++++++----
 6 files changed, 1084 insertions(+), 128 deletions(-)

diff --git a/config/dataset_config.py b/config/dataset_config.py
index 8aa83ae..6c62103 100644
--- a/config/dataset_config.py
+++ b/config/dataset_config.py
@@ -10,6 +10,7 @@ class DatasetCommonConfig(BaseModel):
     seed: Optional[int] = 0        # Seed for splitting if data is not pre-split (and all random operations)
     device: str = "cuda0"          # Device used for training/testing (e.g., 'cpu' or 'cuda')
     use_tta: bool = False          # Flag to use Test-Time Augmentation (TTA)
+    use_amp: bool = False          # Flag to use Automatic Mixed Precision (AMP)
     predictions_dir: str = "."     # Directory to save predictions
 
     @model_validator(mode="after")
@@ -70,7 +71,6 @@ class DatasetTrainingConfig(BaseModel):
     batch_size: int = 1            # Batch size for training
     num_epochs: int = 100          # Number of training epochs
     val_freq: int = 1              # Frequency of validation during training
-    use_amp: bool = False          # Flag to use Automatic Mixed Precision (AMP)
     pretrained_weights: str = ""   # Path to pretrained weights for training
 
 
diff --git a/core/data/transforms/__init__.py b/core/data/transforms/__init__.py
index 9cd2b06..9fa9255 100644
--- a/core/data/transforms/__init__.py
+++ b/core/data/transforms/__init__.py
@@ -133,8 +133,8 @@ def get_test_transforms():
     """
     test_transforms = Compose(
         [
-            # Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
-            CustomLoadImaged(keys=["image", "mask"], allow_missing_keys=True, image_only=True),
+            # Load image and label data in (H, W, C) format (allow missing keys).
+            CustomLoadImaged(keys=["image", "mask"], allow_missing_keys=True, image_only=False),
             # Normalize the (H, W, C) image using the specified percentiles.
             CustomNormalizeImaged(
                 keys=["image"],
@@ -169,8 +169,8 @@ def get_predict_transforms():
     """
     pred_transforms = Compose(
         [
-            # Load the image data in (H, W, C) format (image loaded as image-only).
-            CustomLoadImage(image_only=True),
+            # Load the image data in (H, W, C) format.
+            CustomLoadImage(image_only=False),
             # Normalize the (H, W, C) image using the specified percentiles.
             CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
             # Ensure the image is in channel-first format.
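
A note on the transforms above: CustomNormalizeImage(d) is a project-specific transform whose implementation is not part of this patch. A minimal sketch of what percentile normalization with percentiles=[0.0, 99.5] is commonly assumed to do (illustrative only, not the project's exact implementation):

import numpy as np

def normalize_percentile(image: np.ndarray, low: float = 0.0, high: float = 99.5) -> np.ndarray:
    # Map the [low, high] intensity percentiles to [0, 1], clipping outliers.
    lo, hi = np.percentile(image, [low, high])
    if hi - lo < 1e-8:  # guard against constant images
        return np.zeros_like(image, dtype=np.float32)
    return np.clip((image.astype(np.float32) - lo) / (hi - lo), 0.0, 1.0)
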
diff --git a/core/losses/mse_with_bce.py b/core/losses/mse_with_bce.py index 90ecb44..acd8ee9 100644 --- a/core/losses/mse_with_bce.py +++ b/core/losses/mse_with_bce.py @@ -84,12 +84,12 @@ class BCE_MSE_Loss(BaseLoss): # Cell Recognition Loss cellprob_loss = self.bce_loss( - outputs[:, -self.num_classes:], target[:, self.num_classes:2 * self.num_classes].float() + outputs[:, -self.num_classes:], (target[:, -self.num_classes:] > 0).float() ) # Cell Distinction Loss gradflow_loss = 0.5 * self.mse_loss( - outputs[:, :2 * self.num_classes], 5.0 * target[:, 2 * self.num_classes:] + outputs[:, :2 * self.num_classes], 5.0 * target[:, :2 * self.num_classes] ) # Total loss diff --git a/core/models/model_v.py b/core/models/model_v.py index a525daf..d5bb664 100644 --- a/core/models/model_v.py +++ b/core/models/model_v.py @@ -19,7 +19,7 @@ class ModelVParams(BaseModel): decoder_channels: List[int] = [1024, 512, 256, 128, 64] # Decoder configuration decoder_pab_channels: int = 256 # Decoder Pyramid Attention Block channels in_channels: int = 3 # Number of input channels - out_classes: int = 3 # Number of output classes + out_classes: int = 1 # Number of output classes def asdict(self): """ @@ -38,6 +38,8 @@ class ModelV(MAnet): def __init__(self, params: ModelVParams) -> None: # Initialize the MAnet model with provided parameters super().__init__(**params.asdict()) + + self.num_classes = params.out_classes # Remove the default segmentation head as it's not used in this architecture self.segmentation_head = None @@ -53,12 +55,6 @@ class ModelV(MAnet): self.gradflow_head = DeepSegmentationHead( in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes ) - - # self.gradflow_head = nn.ModuleList([ - # DeepSegmentationHead( - # in_channels=params.decoder_channels[-1], out_channels=2 - # ) for _ in range(params.out_classes) - # ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -74,10 +70,6 @@ class ModelV(MAnet): cellprob_mask = self.cellprob_head(decoder_output) gradflow_mask = self.gradflow_head(decoder_output) - # gradflow_masks = torch.cat( - # [head(decoder_output) for head in self.flow_heads], dim=1 # [B, 2*C, H, W] - # ) - # Concatenate the masks for output masks = torch.cat((gradflow_mask, cellprob_mask), dim=1) diff --git a/core/segmentator.py b/core/segmentator.py index 461b98e..f84795f 100644 --- a/core/segmentator.py +++ b/core/segmentator.py @@ -1,4 +1,3 @@ -import torch import random import numpy as np from numba import njit, prange @@ -10,23 +9,41 @@ from torch.utils.data import DataLoader import fastremap +import fill_voids from skimage import morphology +from skimage.segmentation import find_boundaries from scipy.ndimage import mean, find_objects -import fill_voids from monai.data.dataset import Dataset from monai.transforms import * # type: ignore +from monai.inferers.utils import sliding_window_inference +from monai.metrics.cumulative_average import CumulativeAverage + +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors import os import glob +import copy +import tifffile as tiff + from pprint import pformat -from typing import Optional, Tuple, List, Union +from tabulate import tabulate +from typing import Any, Dict, Literal, Optional, Tuple, List, Union + +from tqdm import tqdm +import wandb from config import Config from core.models import * from core.losses import * from core.optimizers import * from core.schedulers import * +from core.utils import ( + compute_batch_segmentation_tp_fp_fn, + compute_f1_score, + compute_average_precision_score 
+) from core.logger import get_logger @@ -40,6 +57,11 @@ class CellSegmentator: self.__parse_config(config) self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu") + self._scaler = ( + torch.amp.GradScaler(self._device.type) # type: ignore + if self._dataset_setup.is_training and self._dataset_setup.common.use_amp + else None + ) self._train_dataloader: Optional[DataLoader] = None self._valid_dataloader: Optional[DataLoader] = None @@ -214,17 +236,237 @@ class CellSegmentator: self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False) logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.") + + def print_data_info( + self, + loader_type: Literal["train", "valid", "test", "predict"], + index: Optional[int] = None + ) -> None: + """ + Prints statistics for a single sample from the specified dataloader. + + Args: + loader_type: One of "train", "valid", "test", "predict". + index: The sample index; if None, a random index is selected. + """ + # Retrieve the dataloader attribute, e.g., self._train_dataloader + loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None) + if loader is None: + logger.error(f"Dataloader '{loader_type}' is not initialized.") + return + + dataset = loader.dataset + total = len(dataset) # type: ignore + if total == 0: + logger.error(f"Dataset for '{loader_type}' is empty.") + return + + # Choose index + idx = index if index is not None else random.randint(0, total - 1) + if not (0 <= idx < total): + logger.error(f"Index {idx} is out of range [0, {total}).") + return + + # Fetch the sample and apply transforms + sample = dataset[idx] + # Expecting a dict with {'image': ..., 'mask': ...} or {'image': ...} + img = sample["image"] + # Convert tensor to numpy if needed + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + img = np.asarray(img) + + # Compute image statistics + img_min, img_max = img.min(), img.max() + img_mean, img_std = float(img.mean()), float(img.std()) + img_shape = img.shape + + # Prepare log lines + lines = [ + "=" * 40, + f"Dataloader: '{loader_type}', sample index: {idx} / {total - 1}", + f"Image — shape: {img_shape}, min: {img_min:.4f}, max: {img_max:.4f}, mean: {img_mean:.4f}, std: {img_std:.4f}" + ] + + # For 'predict', no mask is available + if loader_type != "predict": + mask = sample.get("mask", None) + if mask is not None: + # Convert tensor to numpy if needed + if isinstance(mask, torch.Tensor): + mask = mask.cpu().numpy() + mask = np.asarray(mask) + m_min, m_max = mask.min(), mask.max() + m_mean, m_std = float(mask.mean()), float(mask.std()) + m_shape = mask.shape + lines.append( + f"Mask — shape: {m_shape}, min: {m_min:.4f}, " + f"max: {m_max:.4f}, mean: {m_mean:.4f}, std: {m_std:.4f}" + ) + else: + lines.append("Mask — not available for this sample.") + + lines.append("=" * 40) + + # Output via logger + for l in lines: + logger.info(l) + def train(self) -> None: - pass + """ + Train the model over multiple epochs, including validation and test. 
+ """ + # Ensure training is enabled in dataset setup + if not self._dataset_setup.is_training: + raise RuntimeError("Dataset setup indicates training is disabled.") + + # Determine device name for logging + if self._device.type == "cpu": + device_name = "cpu" + else: + idx = self._device.index if hasattr(self._device, 'index') else torch.cuda.current_device() + device_name = torch.cuda.get_device_name(idx) + + logger.info(f"\n{'=' * 50}\n") + logger.info(f"Training starts on device: {device_name}") + logger.info(f"\n{'=' * 50}") + + best_f1_score = 0.0 + best_weights = None + + for epoch in range(1, self._dataset_setup.training.num_epochs + 1): + train_metrics = self.__run_epoch("train", epoch) + self.__print_with_logging(train_metrics, epoch) + + # Step the scheduler after training + if self._scheduler is not None: + self._scheduler.step() + + # Periodic validation or tuning + if epoch % self._dataset_setup.training.val_freq == 0: + if self._valid_dataloader is not None: + valid_metrics = self.__run_epoch("valid", epoch) + self.__print_with_logging(valid_metrics, epoch) + + # Update best model on improved F1 + f1 = valid_metrics.get("valid_f1_score", 0.0) + if f1 > best_f1_score: + best_f1_score = f1 + # Deep copy weights to avoid reference issues + best_weights = copy.deepcopy(self._model.state_dict()) + logger.info(f"Updated best model weights with F1 score: {f1:.4f}") + + # Restore best model weights if available + if best_weights is not None: + self._model.load_state_dict(best_weights) + + if self._test_dataloader is not None: + test_metrics = self.__run_epoch("test") + self.__print_with_logging(test_metrics, 0) def evaluate(self) -> None: - pass + """ + Run a full test epoch and display/log the resulting metrics. + """ + test_metrics = self.__run_epoch("test") + self.__print_with_logging(test_metrics, 0) def predict(self) -> None: - pass + """ + Run inference on the predict set and save the resulting instance masks. + """ + # Ensure the predict DataLoader has been set + if self._predict_dataloader is None: + raise RuntimeError("DataLoader for mode 'predict' is not set.") + + batch_counter = 0 + for batch in tqdm(self._predict_dataloader, desc="Predicting"): + # Move input images to the configured device (CPU/GPU) + inputs = batch["img"].to(self._device) + + # Use automatic mixed precision if enabled in dataset setup + with torch.amp.autocast( # type: ignore + self._device.type, + enabled=self._dataset_setup.common.use_amp + ): + # Disable gradient computation for inference + with torch.no_grad(): + # Run the model’s forward pass in ‘predict’ mode + raw_output = self.__run_inference(inputs, mode="predict") + + # Convert logits/probabilities to discrete instance masks + # ground_truth is not passed here; only predictions are needed + preds, _ = self.__post_process_predictions(raw_output) + + # Save out the predicted masks, using batch_counter to index files + self.__save_prediction_masks(batch, preds, batch_counter) + + # Increment counter by batch size for unique file naming + batch_counter += inputs.shape[0] + + + def run(self) -> None: + """ + Orchestrate the full workflow: + - If training is enabled in the dataset setup, start training. + - Otherwise, if a test DataLoader is provided, run evaluation. + - Else if a prediction DataLoader is provided, run inference/prediction. + - If neither loader is available in non‐training mode, raise an error. 
+ """ + # 1) TRAINING PATH + if self._dataset_setup.is_training: + # Launch the full training loop (with validation, scheduler steps, etc.) + self.train() + return + + # 2) NON-TRAINING PATH (TEST or PREDICT) + # Prefer test if available + if self._test_dataloader is not None: + # Run a single evaluation epoch on the test set and log metrics + self.evaluate() + return + + # If no test loader, fall back to prediction if available + if self._predict_dataloader is not None: + # Run inference on the predict set and save outputs + self.predict() + return + + # 3) ERROR: no appropriate loader found + raise RuntimeError( + "Neither test nor predict DataLoader is set for non‐training mode." + ) + + + def load_from_checkpoint(self, checkpoint_path: str) -> None: + """ + Loads model weights from a specified checkpoint into the current model. + + Args: + checkpoint_path (str): Path to the checkpoint file containing the model weights. + """ + # Load the checkpoint onto the correct device (CPU or GPU) + checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=True) + # Load the state dict into the model, allowing for missing keys + self._model.load_state_dict(checkpoint['state_dict'], strict=False) + + + def save_checkpoint(self, checkpoint_path: str) -> None: + """ + Saves the current model weights to a checkpoint file. + + Args: + checkpoint_path (str): Path where the checkpoint file will be saved. + """ + # Create a checkpoint dictionary containing the model’s state_dict + checkpoint = { + 'state_dict': self._model.state_dict() + } + # Write the checkpoint to disk + torch.save(checkpoint, checkpoint_path) def __parse_config(self, config: Config) -> None: @@ -255,12 +497,26 @@ class CellSegmentator: # Initialize model using the model registry self._model = ModelRegistry.get_model_class(model.name)(model.params) + # Loads model weights from a specified checkpoint + if config.dataset_config.is_training: + if config.dataset_config.training.pretrained_weights: + self.load_from_checkpoint(config.dataset_config.training.pretrained_weights) + # Initialize loss criterion if specified self._criterion = ( CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params) if criterion is not None else None ) + + if hasattr(self._criterion, "num_classes"): + nc_model = self._model.num_classes + nc_crit = getattr(self._criterion, "num_classes") + if nc_model != nc_crit: + raise ValueError( + f"Number of classes mismatch: model.num_classes={nc_model} " + f"but criterion.num_classes={nc_crit}" + ) # Initialize optimizer if specified self._optimizer = ( @@ -298,6 +554,7 @@ class CellSegmentator: logger.info("[COMMON]") logger.info(f"├─ Seed: {common.seed}") logger.info(f"├─ Device: {common.device}") + logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}") logger.info(f"└─ Predictions output dir: {common.predictions_dir}") if config.dataset_config.is_training: @@ -306,7 +563,6 @@ class CellSegmentator: logger.info(f"├─ Batch size: {training.batch_size}") logger.info(f"├─ Epochs: {training.num_epochs}") logger.info(f"├─ Validation frequency: {training.val_freq}") - logger.info(f"├─ Use AMP: {'yes' if training.use_amp else 'no'}") logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}") if training.is_split: @@ -433,6 +689,624 @@ class CellSegmentator: return Dataset(data, transforms) + + def __print_with_logging(self, results: Dict[str, float], step: int) -> None: + """ + Print metrics in a tabular format and log to W&B. 
+ + Args: + results (Dict[str, float]): results dictionary. + step (int): epoch index. + """ + table = tabulate( + tabular_data=results.items(), + headers=["Metric", "Value"], + floatfmt=".4f", + tablefmt="fancy_grid" + ) + print(table, "\n") + wandb.log(results, step=step) + + + def __run_epoch(self, + mode: Literal["train", "valid", "test"], + epoch: Optional[int] = None + ) -> Dict[str, float]: + """ + Execute one epoch of training, validation, or testing. + + Args: + mode (str): One of 'train', 'valid', or 'test'. + epoch (int, optional): Current epoch number for logging. + + Returns: + Dict[str, float]: Loss metrics and F1 score for valid/test. + """ + # Ensure required components are available + if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None): + raise RuntimeError("Optimizer and loss function must be initialized for train/valid.") + + # Set model mode and choose the appropriate data loader + if mode == "train": + self._model.train() + loader = self._train_dataloader + else: + self._model.eval() + loader = self._valid_dataloader if mode == "valid" else self._test_dataloader + + # Check that the data loader is provided + if loader is None: + raise RuntimeError(f"DataLoader for mode '{mode}' is not set.") + + all_tp, all_fp, all_fn = [], [], [] + + # Prepare tqdm description with epoch info if available + if epoch is not None: + desc = f"Epoch {epoch}/{self._dataset_setup.training.num_epochs} [{mode}]" + else: + desc = f"Epoch ({mode})" + + # Iterate over batches + batch_counter = 0 + for batch in tqdm(loader, desc=desc): + inputs = batch["img"].to(self._device) + targets = batch["label"].to(self._device) + + # Zero gradients for training + if self._optimizer is not None: + self._optimizer.zero_grad() + + # Mixed precision context if enabled + with torch.amp.autocast( # type: ignore + self._device.type, + enabled=self._dataset_setup.common.use_amp + ): + # Only compute gradients in training phase + with torch.set_grad_enabled(mode == "train"): + # Forward pass + raw_output = self.__run_inference(inputs, mode) + + if self._criterion is not None: + # Convert label masks to flow representations (one-hot) + flow_targets = self.__compute_flows_from_masks(targets) + + # Compute loss for this batch + batch_loss = self._criterion(raw_output, flow_targets) # type: ignore + + # Post-process and compute F1 during validation and testing + if mode in ("valid", "test"): + preds, labels_post = self.__post_process_predictions( + raw_output, targets + ) + + # Collecting statistics on the batch + tp, fp, fn = self.__compute_stats( + predicted_masks=preds, + ground_truth_masks=labels_post, # type: ignore + iou_threshold=0.5 + ) + all_tp.append(tp) + all_fp.append(fp) + all_fn.append(fn) + + if mode == "test": + self.__save_prediction_masks( + batch, preds, batch_counter + ) + + # Backpropagation and optimizer step in training + if mode == "train": + if self._dataset_setup.common.use_amp and self._scaler is not None: + self._scaler.scale(batch_loss).backward() # type: ignore + self._scaler.unscale_(self._optimizer.optim) # type: ignore + self._scaler.step(self._optimizer.optim) # type: ignore + self._scaler.update() + else: + batch_loss.backward() # type: ignore + if self._optimizer is not None: + self._optimizer.step() + + batch_counter += inputs.shape[0] + + if self._criterion is not None: + # Collect loss metrics + epoch_metrics = {f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()} + # Reset internal loss metrics accumulator + 
            self._criterion.reset_metrics()
+        else:
+            epoch_metrics = {}
+
+        # Include F1 and mAP for validation and testing
+        if mode in ("valid", "test"):
+            # Concatenate the per-batch stats: shape (num_batches*B, C)
+            tp_array = np.vstack(all_tp)
+            fp_array = np.vstack(all_fp)
+            fn_array = np.vstack(all_fn)
+
+            epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(  # type: ignore
+                tp_array, fp_array, fn_array, reduction="micro"
+            )
+            epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(  # type: ignore
+                tp_array, fp_array, fn_array, reduction="macro"
+            )
+
+        return epoch_metrics
+
+
+    def __run_inference(
+        self,
+        inputs: torch.Tensor,
+        mode: Literal["train", "valid", "test", "predict"] = "train"
+    ) -> torch.Tensor:
+        """
+        Perform model inference for different stages.
+
+        Args:
+            inputs (torch.Tensor): Input tensor of shape (B, C, H, W).
+            mode (Literal[...]): One of "train", "valid", "test", "predict".
+
+        Returns:
+            torch.Tensor: Model outputs tensor.
+        """
+        if mode != "train":
+            # Use sliding window inference for non-training phases
+            outputs = sliding_window_inference(
+                inputs,
+                roi_size=512,
+                sw_batch_size=4,
+                predictor=self._model,
+                padding_mode="constant",
+                mode="gaussian",
+                overlap=0.5,
+            )
+        else:
+            # Direct forward pass during training
+            outputs = self._model(inputs)
+        return outputs  # type: ignore
+
+
+    def __post_process_predictions(
+        self,
+        raw_outputs: torch.Tensor,
+        ground_truth: Optional[torch.Tensor] = None
+    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        """
+        Post-process raw network outputs to extract instance segmentation masks.
+
+        Args:
+            raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W).
+            ground_truth (torch.Tensor, optional): Ground truth masks of shape (B, C, H, W).
+
+        Returns:
+            Tuple[np.ndarray, Optional[np.ndarray]]:
+                - instance_masks: Instance-wise masks array of shape (B, C, H, W).
+                - labels_np: Converted ground truth of shape (B, C, H, W) or None if
+                  ground_truth was not provided.
+        """
+        # Move outputs to CPU and convert to numpy
+        outputs_np = raw_outputs.cpu().numpy()
+        # Split channels: gradient flows then class logits
+        gradflow = outputs_np[:, :2 * self._model.num_classes]
+        logits = outputs_np[:, -self._model.num_classes:]
+        # Apply sigmoid to logits to get probabilities
+        probabilities = self.__sigmoid(logits)
+
+        batch_size, _, height, width = probabilities.shape
+        # Prepare container for instance masks
+        instance_masks = np.zeros((batch_size, self._model.num_classes, height, width), dtype=np.uint16)
+        for idx in range(batch_size):
+            instance_masks[idx] = self.__segment_instances(
+                probability_map=probabilities[idx],
+                flow=gradflow[idx],
+                prob_threshold=0.0,
+                flow_threshold=0.4,
+                min_object_size=15
+            )
+
+        # Convert ground truth to numpy
+        labels_np = ground_truth.cpu().numpy() if ground_truth is not None else None
+        return instance_masks, labels_np
+
+
+    def __compute_stats(
+        self,
+        predicted_masks: np.ndarray,
+        ground_truth_masks: np.ndarray,
+        iou_threshold: float = 0.5
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Compute batch-wise true positives, false positives, and false negatives
+        for instance segmentation, using a configurable IoU threshold.
+
+        Args:
+            predicted_masks (np.ndarray): Predicted instance masks of shape (B, C, H, W).
+            ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W).
+            iou_threshold (float): Intersection-over-Union threshold for matching predictions
+                to ground truths (default: 0.5).
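+
+        Example (illustrative; empty masks yield zero counts):
+            >>> preds = np.zeros((2, 1, 256, 256), dtype=np.uint16)
+            >>> gts = np.zeros((2, 1, 256, 256), dtype=np.uint16)
+            >>> tp, fp, fn = self.__compute_stats(preds, gts)  # each has shape (2, 1)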
+ + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: + - tp: True positives per batch and class, shape (B, C) + - fp: False positives per batch and class, shape (B, C) + - fn: False negatives per batch and class, shape (B, C) + """ + stats = compute_batch_segmentation_tp_fp_fn( + batch_ground_truth=ground_truth_masks, + batch_prediction=predicted_masks, + iou_threshold=iou_threshold, + remove_boundary_objects=True + ) + tp = stats["tp"] + fp = stats["fp"] + fn = stats["fn"] + return tp, fp, fn + + + def __compute_f1_metric( + self, + true_positives: np.ndarray, + false_positives: np.ndarray, + false_negatives: np.ndarray, + reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro" + ) -> Union[float, np.ndarray]: + """ + Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes. + + Args: + true_positives: array of TP counts per sample and class. + false_positives: array of FP counts per sample and class. + false_negatives: array of FN counts per sample and class. + reduction: + - 'none': return F1 for each sample, class → shape (batch_size, num_classes) + - 'micro': global F1 over all samples & classes + - 'imagewise': F1 per sample (summing over classes), then average over samples + - 'macro': average class-wise F1 (classes summed over batch) + - 'weighted': class-wise F1 weighted by support (TP+FN) + Returns: + float for reductions 'micro', 'imagewise', 'macro', 'weighted'; + or np.ndarray of shape (batch_size, num_classes) if reduction='none'. + """ + batch_size, num_classes = true_positives.shape + + # 1) No reduction: compute F1 for each (sample, class) + if reduction == "none": + f1_matrix = np.zeros((batch_size, num_classes), dtype=float) + for i in range(batch_size): + for c in range(num_classes): + tp_val = int(true_positives[i, c]) + fp_val = int(false_positives[i, c]) + fn_val = int(false_negatives[i, c]) + _, _, f1_val = compute_f1_score(tp_val, fp_val, fn_val) + f1_matrix[i, c] = f1_val + return f1_matrix + + # 2) Micro: sum all TP/FP/FN and compute a single F1 + if reduction == "micro": + tp_total = int(true_positives.sum()) + fp_total = int(false_positives.sum()) + fn_total = int(false_negatives.sum()) + _, _, f1_global = compute_f1_score(tp_total, fp_total, fn_total) + return f1_global + + # 3) Imagewise: compute per-sample F1 (sum over classes), then average + if reduction == "imagewise": + f1_per_image = np.zeros(batch_size, dtype=float) + for i in range(batch_size): + tp_i = int(true_positives[i].sum()) + fp_i = int(false_positives[i].sum()) + fn_i = int(false_negatives[i].sum()) + _, _, f1_i = compute_f1_score(tp_i, fp_i, fn_i) + f1_per_image[i] = f1_i + return float(f1_per_image.mean()) + + # For macro/weighted, first aggregate per class across the batch + tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,) + fp_per_class = false_positives.sum(axis=0).astype(int) + fn_per_class = false_negatives.sum(axis=0).astype(int) + + # 4) Macro: average F1 across classes equally + if reduction == "macro": + f1_per_class = np.zeros(num_classes, dtype=float) + for c in range(num_classes): + _, _, f1_c = compute_f1_score( + tp_per_class[c], + fp_per_class[c], + fn_per_class[c] + ) + f1_per_class[c] = f1_c + return float(f1_per_class.mean()) + + # 5) Weighted: class-wise F1 weighted by support = TP + FN + if reduction == "weighted": + f1_per_class = np.zeros(num_classes, dtype=float) + support = np.zeros(num_classes, dtype=float) + for c in range(num_classes): + tp_c = tp_per_class[c] + fp_c = fp_per_class[c] + 
fn_c = fn_per_class[c] + _, _, f1_c = compute_f1_score(tp_c, fp_c, fn_c) + f1_per_class[c] = f1_c + support[c] = tp_c + fn_c + total_support = support.sum() + if total_support == 0: + # fallback to unweighted macro if no positives + return float(f1_per_class.mean()) + weights = support / total_support + return float((f1_per_class * weights).sum()) + + raise ValueError(f"Unknown reduction mode: {reduction}") + + + def __compute_average_precision_metric( + self, + true_positives: np.ndarray, + false_positives: np.ndarray, + false_negatives: np.ndarray, + reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro" + ) -> Union[float, np.ndarray]: + """ + Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes. + + AP is defined here as: + AP = TP / (TP + FP + FN) + + Args: + true_positives: array of true positives per sample and class. + false_positives: array of false positives per sample and class. + false_negatives: array of false negatives per sample and class. + reduction: + - 'none': return AP for each sample and class → shape (batch_size, num_classes) + - 'micro': global AP over all samples & classes + - 'imagewise': AP per sample (summing stats over classes), then average over batch + - 'macro': average class-wise AP (each class summed over batch) + - 'weighted': class-wise AP weighted by support (TP+FN) + + Returns: + float for reductions 'micro', 'imagewise', 'macro', 'weighted'; + or np.ndarray of shape (batch_size, num_classes) if reduction='none'. + """ + batch_size, num_classes = true_positives.shape + + # 1) No reduction: AP per (sample, class) + if reduction == "none": + ap_matrix = np.zeros((batch_size, num_classes), dtype=float) + for i in range(batch_size): + for c in range(num_classes): + tp_val = int(true_positives[i, c]) + fp_val = int(false_positives[i, c]) + fn_val = int(false_negatives[i, c]) + ap_val = compute_average_precision_score(tp_val, fp_val, fn_val) + ap_matrix[i, c] = ap_val + return ap_matrix + + # 2) Micro: sum all TP/FP/FN and compute one AP + if reduction == "micro": + tp_total = int(true_positives.sum()) + fp_total = int(false_positives.sum()) + fn_total = int(false_negatives.sum()) + return compute_average_precision_score(tp_total, fp_total, fn_total) + + # 3) Imagewise: compute per-sample AP (sum over classes), then mean + if reduction == "imagewise": + ap_per_image = np.zeros(batch_size, dtype=float) + for i in range(batch_size): + tp_i = int(true_positives[i].sum()) + fp_i = int(false_positives[i].sum()) + fn_i = int(false_negatives[i].sum()) + ap_per_image[i] = compute_average_precision_score(tp_i, fp_i, fn_i) + return float(ap_per_image.mean()) + + # For macro and weighted: first aggregate per class across batch + tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,) + fp_per_class = false_positives.sum(axis=0).astype(int) + fn_per_class = false_negatives.sum(axis=0).astype(int) + + # 4) Macro: average AP across classes equally + if reduction == "macro": + ap_per_class = np.zeros(num_classes, dtype=float) + for c in range(num_classes): + ap_per_class[c] = compute_average_precision_score( + tp_per_class[c], + fp_per_class[c], + fn_per_class[c] + ) + return float(ap_per_class.mean()) + + # 5) Weighted: class-wise AP weighted by support = TP + FN + if reduction == "weighted": + ap_per_class = np.zeros(num_classes, dtype=float) + support = np.zeros(num_classes, dtype=float) + for c in range(num_classes): + tp_c = tp_per_class[c] + fp_c = fp_per_class[c] + fn_c = 
fn_per_class[c]
+                ap_per_class[c] = compute_average_precision_score(tp_c, fp_c, fn_c)
+                support[c] = tp_c + fn_c
+            total_support = support.sum()
+            if total_support == 0:
+                # fallback to unweighted macro if no positive instances
+                return float(ap_per_class.mean())
+            weights = support / total_support
+            return float((ap_per_class * weights).sum())
+
+        raise ValueError(f"Unknown reduction mode: {reduction}")
+
+
+    @staticmethod
+    def __sigmoid(z: np.ndarray) -> np.ndarray:
+        """
+        Numerically stable sigmoid activation for numpy arrays.
+
+        Args:
+            z (np.ndarray): Input array.
+
+        Returns:
+            np.ndarray: Sigmoid of the input.
+        """
+        # Split by sign so np.exp never receives a large positive argument
+        # (a plain 1 / (1 + exp(-z)) overflows for large negative z).
+        result = np.empty_like(z)
+        positive = z >= 0
+        result[positive] = 1.0 / (1.0 + np.exp(-z[positive]))
+        exp_z = np.exp(z[~positive])
+        result[~positive] = exp_z / (1.0 + exp_z)
+        return result
+
+
+    def __save_prediction_masks(
+        self,
+        sample: Dict[str, Any],
+        predicted_mask: Union[np.ndarray, torch.Tensor],
+        start_index: int = 0,
+    ) -> None:
+        """
+        Save multi-channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders.
+
+        Args:
+            sample (Dict[str, Any]): Batch sample from MONAI LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
+            predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
+            start_index (int): Starting index for naming when metadata is missing.
+        """
+        # Determine base paths
+        base_output_dir = self._dataset_setup.common.predictions_dir
+        masks_dir = base_output_dir
+        plots_dir = os.path.join(base_output_dir, "plots")
+        os.makedirs(masks_dir, exist_ok=True)
+        os.makedirs(plots_dir, exist_ok=True)
+
+        # Extract image (C, H, W) or batch of images (B, C, H, W), and metadata
+        image_obj = sample.get("image")  # Expected shape: (C, H, W) or (B, C, H, W)
+        mask_obj = sample.get("mask")    # Expected shape: (C, H, W) or (B, C, H, W)
+        image_meta = sample.get("image_meta_dict")
+
+        # Convert tensors to numpy
+        def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+            if isinstance(x, torch.Tensor):
+                return x.detach().cpu().numpy()
+            return x
+
+        image_array = to_numpy(image_obj) if image_obj is not None else None
+        mask_array = to_numpy(mask_obj) if mask_obj is not None else None
+        pred_array = to_numpy(predicted_mask)
+
+        # Handle batch dimension: (B, C, H, W)
+        if pred_array.ndim == 4:
+            for idx in range(pred_array.shape[0]):
+                batch_sample = dict(sample)
+                if image_array is not None and image_array.ndim == 4:
+                    batch_sample["image"] = image_array[idx]
+                if isinstance(image_meta, list):
+                    batch_sample["image_meta_dict"] = image_meta[idx]
+                if mask_array is not None and mask_array.ndim == 4:
+                    batch_sample["mask"] = mask_array[idx]
+                self.__save_prediction_masks(
+                    batch_sample,
+                    pred_array[idx],
+                    start_index=start_index + idx
+                )
+            return
+
+        # Determine base filename
+        if image_meta and "filename_or_obj" in image_meta:
+            base_name = os.path.splitext(os.path.basename(image_meta["filename_or_obj"]))[0]
+        else:
+            # Use provided start_index when metadata missing
+            base_name = f"prediction_{start_index:04d}"
+
+        # Now pred_array shape is (C, H, W)
+        num_channels = pred_array.shape[0]
+        for channel_idx in range(num_channels):
+            channel_mask = pred_array[channel_idx]
+
+            # File names
+            mask_filename = f"{base_name}_ch{channel_idx:02d}.tif"
+            plot_filename = f"{base_name}_ch{channel_idx:02d}.png"
+            mask_path = os.path.join(masks_dir, mask_filename)
+            plot_path = os.path.join(plots_dir, plot_filename)
+
+            # Save mask TIFF (16-bit)
+            tiff.imwrite(mask_path, channel_mask.astype(np.uint16), compression="zlib")
+
+            # Extract corresponding true mask channel if exists
+            true_mask = None
+            if mask_array is not None and mask_array.ndim == 3:
+ true_mask = mask_array[channel_idx] + + # Generate and save visualization + self.__plot_mask( + file_path=plot_path, + image_data=image_array, # type: ignore + predicted_mask=channel_mask, + true_mask=true_mask, + ) + + + def __plot_mask( + self, + file_path: str, + image_data: np.ndarray, + predicted_mask: np.ndarray, + true_mask: Optional[np.ndarray] = None, + ) -> None: + """ + Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided. + """ + img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data + + if true_mask is None: + fig, axs = plt.subplots(1, 3, figsize=(15,5)) + plt.subplots_adjust(wspace=0.02, hspace=0) + self.__plot_panels(axs, img, predicted_mask, 'red', + ('Original Image','Predicted Mask','Predicted Contours')) + else: + fig, axs = plt.subplots(2,3,figsize=(15,10)) + plt.subplots_adjust(wspace=0.02, hspace=0.1) + # row 0: predicted + self.__plot_panels(axs[0], img, predicted_mask, 'red', + ('Original Image','Predicted Mask','Predicted Contours')) + # row 1: true + self.__plot_panels(axs[1], img, true_mask, 'blue', + ('Original Image','True Mask','True Contours')) + fig.savefig(file_path, bbox_inches='tight', dpi=300) + plt.close(fig) + + + def __plot_panels( + self, + axes, + img: np.ndarray, + mask: np.ndarray, + contour_color: str, + titles: Tuple[str, ...] + ): + """ + Plot a row of three panels: original image, mask, and mask boundaries on image. + + Args: + axes: list/array of three Axis objects. + img: Image array (H, W or H, W, C). + mask: Label mask (H, W). + contour_color: Color for boundaries. + titles: (title_img, title_mask, title_contours). + """ + # Panel 1: Original image + ax0, ax1, ax2 = axes + ax0.imshow(img, cmap='gray' if img.ndim == 2 else None) + ax0.set_title(titles[0]); ax0.axis('off') + + # Compute boundaries once + boundaries = find_boundaries(mask, mode='thick') + + # Panel 2: Mask with black boundaries + cmap = plt.get_cmap("gist_ncar") + cmap = mcolors.ListedColormap([ + cmap(i/len(np.unique(mask))) for i in range(len(np.unique(mask))) + ]) + ax1.imshow(mask, cmap=cmap) + ax1.contour(boundaries, colors='black', linewidths=0.5) + ax1.set_title(titles[1]) + ax1.axis('off') + + # Panel 3: Original image with black contour overlay + ax2.imshow(img) + # Draw boundaries as contour lines + ax2.contour(boundaries, colors=contour_color, linewidths=0.5) + ax2.set_title(titles[2]) + ax2.axis('off') + def __compute_flows_from_masks( self, @@ -445,9 +1319,8 @@ class CellSegmentator: true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks. Returns: - numpy array of concatenated [renumbered_true_masks, binary_masks, flow_vectors] per image. - renumbered_true_masks is labels, binary_masks is cell distance transform, flow_vectors[2] is Y flow, flows[k][3] is X flow, - and flow_vectors[4] is heat distribution. + numpy array of concatenated [flow_vectors, renumbered_true_masks] per image. + renumbered_true_masks is labels, flow_vectors[0] is Y flow, flow_vectors[1] is X flow. 
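+            For input masks of shape (B, C, H, W) the result has shape (B, 3*C, H, W):
+            the 2*C flow channels come first, followed by the C label channels
+            (matching the channel slicing used in BCE_MSE_Loss).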
""" # Move to CPU numpy _true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16) @@ -467,7 +1340,7 @@ class CellSegmentator: flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i]) for i in range(batch_size)]) - return np.concatenate((renumbered, renumbered > 0.5, flow_vectors), axis=1).astype(np.float32) + return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32) def __compute_flow_from_mask( @@ -879,7 +1752,7 @@ class CellSegmentator: prob_threshold: float = 0.0, flow_threshold: float = 0.4, num_iters: int = 200, - min_object_size: int = 0 + min_object_size: int = 15 ) -> np.ndarray: """ Generate instance segmentation masks from probability and flow fields. @@ -890,7 +1763,7 @@ class CellSegmentator: prob_threshold: threshold to binarize probability_map. (Default 0.0) flow_threshold: threshold for filtering bad flow masks. (Default 0.4) num_iters: number of iterations for flow-following. (Default 200) - min_object_size: minimum area to keep small instances. (Default 0) + min_object_size: minimum area to keep small instances. (Default 15) Returns: 3D array of uint16 instance labels for each channel. diff --git a/core/utils/measures.py b/core/utils/measures.py index e95bcf7..c157e06 100644 --- a/core/utils/measures.py +++ b/core/utils/measures.py @@ -13,12 +13,14 @@ from typing import Dict, List, Tuple, Any, Union __all__ = [ "compute_batch_segmentation_f1_metrics", "compute_batch_segmentation_average_precision_metrics", + "compute_batch_segmentation_tp_fp_fn", "compute_segmentation_f1_metrics", "compute_segmentation_average_precision_metrics", - "compute_confusion_matrix", "compute_f1_scores", "compute_average_precision_score" + "compute_segmentation_tp_fp_fn", + "compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score" ] -def compute_f1_scores( +def compute_f1_score( true_positives: int, false_positives: int, false_negatives: int @@ -103,7 +105,7 @@ def compute_confusion_matrix( return true_positive_count, false_positive_count, false_negative_count -def compute_segmentation_f1_metrics( +def compute_segmentation_tp_fp_fn( ground_truth_mask: np.ndarray, predicted_mask: np.ndarray, iou_threshold: float = 0.5, @@ -111,36 +113,32 @@ def compute_segmentation_f1_metrics( remove_boundary_objects: bool = True ) -> Dict[str, np.ndarray]: """ - Computes F1 metrics (precision, recall, F1-score) for segmentation on a single image. + Computes TP, FP and FN for segmentation on a single image. - If the input masks have shape (H, W), they are expanded to (H, W, 1). - For multi-channel inputs (H, W, C), each channel is processed independently, and the returned - metrics (precision, recall, f1_score, TP, FP, FN) are provided as NumPy arrays with shape (C,). + If the input masks have shape (H, W), they are expanded to (1, H, W). + For multi-channel inputs (C, H, W), each channel is processed independently, and the returned + metrics (TP, FP, FN) are provided as NumPy arrays with shape (C,). Optionally, if return_error_masks is True, binary error masks for true positives, false positives, - and false negatives are also returned with shape (H, W, C). + and false negatives are also returned with shape (C, H, W). Args: - ground_truth_mask: Ground truth segmentation mask (HxW or HxWxC). - predicted_mask: Predicted segmentation mask (HxW or HxWxC). + ground_truth_mask: Ground truth segmentation mask (HxW or CxHxW). + predicted_mask: Predicted segmentation mask (HxW or CxHxW). iou_threshold: IoU threshold for matching objects. 
return_error_masks: Whether to also return binary error masks. remove_boundary_objects: Whether to remove objects that touch the image boundary. Returns: A dictionary with the following keys: - - 'precision', 'recall', 'f1_score': arrays of shape (C,) - 'tp', 'fp', 'fn': arrays of shape (C,) - - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (H, W, C) + - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (C, H, W) """ # If the masks are 2D, add a singleton channel dimension. - ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=-1) - predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=-1) + ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=0) + predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=0) - num_channels = ground_truth_mask.shape[-1] - precision_list = [] - recall_list = [] - f1_score_list = [] + num_channels = ground_truth_mask.shape[0] true_positive_list = [] false_positive_list = [] false_negative_list = [] @@ -151,8 +149,8 @@ def compute_segmentation_f1_metrics( # Process each channel independently. for channel in range(num_channels): - channel_ground_truth = ground_truth_mask[..., channel] - channel_prediction = predicted_mask[..., channel] + channel_ground_truth = ground_truth_mask[channel, ...] + channel_prediction = predicted_mask[channel, ...] if np.prod(channel_ground_truth.shape) < (5000 * 5000): results = _process_instance_matching( channel_ground_truth, channel_prediction, iou_threshold, @@ -166,12 +164,7 @@ def compute_segmentation_f1_metrics( tp = results['tp'] fp = results['fp'] fn = results['fn'] - precision, recall, f1_score = compute_f1_scores( - tp, fp, fn # type: ignore - ) - precision_list.append(precision) - recall_list.append(recall) - f1_score_list.append(f1_score) + true_positive_list.append(tp) false_positive_list.append(fp) false_negative_list.append(fn) @@ -181,17 +174,75 @@ def compute_segmentation_f1_metrics( false_negative_mask_list.append(results.get('fn_mask')) # type: ignore output: Dict[str, np.ndarray] = { - 'precision': np.array(precision_list), - 'recall': np.array(recall_list), - 'f1_score': np.array(f1_score_list), 'tp': np.array(true_positive_list), 'fp': np.array(false_positive_list), 'fn': np.array(false_negative_list) } if return_error_masks: - output['tp_mask'] = np.stack(true_positive_mask_list, axis=-1) # type: ignore - output['fp_mask'] = np.stack(false_positive_mask_list, axis=-1) # type: ignore - output['fn_mask'] = np.stack(false_negative_mask_list, axis=-1) # type: ignore + output['tp_mask'] = np.stack(true_positive_mask_list, axis=0) # type: ignore + output['fp_mask'] = np.stack(false_positive_mask_list, axis=0) # type: ignore + output['fn_mask'] = np.stack(false_negative_mask_list, axis=0) # type: ignore + return output + + +def compute_segmentation_f1_metrics( + ground_truth_mask: np.ndarray, + predicted_mask: np.ndarray, + iou_threshold: float = 0.5, + return_error_masks: bool = False, + remove_boundary_objects: bool = True +) -> Dict[str, np.ndarray]: + """ + Computes F1 metrics (precision, recall, F1-score) for segmentation on a single image. + + If the input masks have shape (H, W), they are expanded to (1, H, W). + For multi-channel inputs (C, H, W), each channel is processed independently, and the returned + metrics (precision, recall, f1_score, TP, FP, FN) are provided as NumPy arrays with shape (C,). 
+
+    Optionally, if return_error_masks is True, binary error masks for true positives, false positives,
+    and false negatives are also returned with shape (C, H, W).
+
+    Args:
+        ground_truth_mask: Ground truth segmentation mask (HxW or CxHxW).
+        predicted_mask: Predicted segmentation mask (HxW or CxHxW).
+        iou_threshold: IoU threshold for matching objects.
+        return_error_masks: Whether to also return binary error masks.
+        remove_boundary_objects: Whether to remove objects that touch the image boundary.
+
+    Returns:
+        A dictionary with the following keys:
+            - 'precision', 'recall', 'f1_score': arrays of shape (C,)
+            - 'tp', 'fp', 'fn': arrays of shape (C,)
+            - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (C, H, W)
+    """
+    # Accept (H, W) inputs by promoting them to (1, H, W) before reading the channel count.
+    ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=0)
+    predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=0)
+
+    num_channels = ground_truth_mask.shape[0]
+    precision_list = []
+    recall_list = []
+    f1_score_list = []
+
+    results = compute_segmentation_tp_fp_fn(
+        ground_truth_mask, predicted_mask,
+        iou_threshold, return_error_masks,
+        remove_boundary_objects
+    )
+    # Process each channel independently.
+    for channel in range(num_channels):
+        tp = results['tp'][channel]
+        fp = results['fp'][channel]
+        fn = results['fn'][channel]
+        precision, recall, f1_score = compute_f1_score(
+            tp, fp, fn  # type: ignore
+        )
+        precision_list.append(precision)
+        recall_list.append(recall)
+        f1_score_list.append(f1_score)
+
+    output: Dict[str, np.ndarray] = {
+        'precision': np.array(precision_list),
+        'recall': np.array(recall_list),
+        'f1_score': np.array(f1_score_list),
+    }
+    output.update(results)
+    return output
@@ -205,16 +256,16 @@
     """
     Computes the average precision (AP) for segmentation on a single image.
 
-    If the input masks have shape (H, W), they are expanded to (H, W, 1).
-    For multi-channel inputs (H, W, C), each channel is processed independently and the returned
+    If the input masks have shape (H, W), they are expanded to (1, H, W).
+    For multi-channel inputs (C, H, W), each channel is processed independently and the returned
     metrics (average precision, TP, FP, FN) are provided as NumPy arrays with shape (C,).
 
     Optionally, if return_error_masks is True, binary error masks for true positives, false positives,
-    and false negatives are also returned with shape (H, W, C).
+    and false negatives are also returned with shape (C, H, W).
 
     Args:
-        ground_truth_mask: Ground truth segmentation mask (HxW or HxWxC).
-        predicted_mask: Predicted segmentation mask (HxW or HxWxC).
+        ground_truth_mask: Ground truth segmentation mask (HxW or CxHxW).
+        predicted_mask: Predicted segmentation mask (HxW or CxHxW).
         iou_threshold: IoU threshold for matching objects.
         return_error_masks: Whether to also return binary error masks.
         remove_boundary_objects: Whether to remove objects that touch the image boundary.
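
For reference, compute_f1_scores is renamed to compute_f1_score above and returns a (precision, recall, f1) tuple. A quick worked check, assuming the standard definitions precision = TP / (TP + FP), recall = TP / (TP + FN) and F1 = 2PR / (P + R):

from core.utils import compute_f1_score

precision, recall, f1 = compute_f1_score(8, 2, 4)
# precision = 8/10 = 0.80, recall = 8/12 ≈ 0.667,
# F1 = 2 * 0.80 * 0.667 / (0.80 + 0.667) ≈ 0.727
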
@@ -223,62 +274,99 @@
         A dictionary with the following keys:
             - 'avg_precision': array of shape (C,)
             - 'tp', 'fp', 'fn': arrays of shape (C,)
-            - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (H, W, C)
+            - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask' with shape (C, H, W)
     """
-    ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=-1)
-    predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=-1)
-
-    num_channels = ground_truth_mask.shape[-1]
+    # Accept (H, W) inputs by promoting them to (1, H, W) before reading the channel count.
+    ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=0)
+    predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=0)
+
+    num_channels = ground_truth_mask.shape[0]
     avg_precision_list = []
-    true_positive_list = []
-    false_positive_list = []
-    false_negative_list = []
-    if return_error_masks:
-        true_positive_mask_list = []
-        false_positive_mask_list = []
-        false_negative_mask_list = []
+
+    results = compute_segmentation_tp_fp_fn(
+        ground_truth_mask, predicted_mask,
+        iou_threshold, return_error_masks,
+        remove_boundary_objects
+    )
 
     # Process each channel independently.
     for channel in range(num_channels):
-        channel_ground_truth = ground_truth_mask[..., channel]
-        channel_prediction = predicted_mask[..., channel]
-        if np.prod(channel_ground_truth.shape) < (5000 * 5000):
-            results = _process_instance_matching(
-                channel_ground_truth, channel_prediction,
-                iou_threshold,
-                return_masks=return_error_masks, without_boundary_objects=remove_boundary_objects
-            )
-        else:
-            results = _compute_patch_based_metrics(
-                channel_ground_truth, channel_prediction,
-                iou_threshold,
-                return_masks=return_error_masks, without_boundary_objects=remove_boundary_objects
-            )
-        tp = results['tp']
-        fp = results['fp']
-        fn = results['fn']
+        tp = results['tp'][channel]
+        fp = results['fp'][channel]
+        fn = results['fn'][channel]
         avg_precision = compute_average_precision_score(
             tp, fp, fn  # type: ignore
         )
         avg_precision_list.append(avg_precision)
-        true_positive_list.append(tp)
-        false_positive_list.append(fp)
-        false_negative_list.append(fn)
+
+    output: Dict[str, np.ndarray] = {
+        'avg_precision': np.array(avg_precision_list)
+    }
+    output.update(results)
+    return output
+
+
+def compute_batch_segmentation_tp_fp_fn(
+    batch_ground_truth: np.ndarray,
+    batch_prediction: np.ndarray,
+    iou_threshold: float = 0.5,
+    return_error_masks: bool = False,
+    remove_boundary_objects: bool = True
+) -> Dict[str, np.ndarray]:
+    """
+    Computes segmentation TP, FP and FN for a batch of images.
+
+    Expects inputs with shape (B, C, H, W). For each image in the batch, the data is extracted
+    into (C, H, W) and then processed using compute_segmentation_tp_fp_fn. The results are stacked
+    so that each metric has shape (B, C). If error masks are returned, their shape will be (B, C, H, W).
+
+    Args:
+        batch_ground_truth: Batch of ground truth masks (BxCxHxW).
+        batch_prediction: Batch of predicted masks (BxCxHxW).
+        iou_threshold: IoU threshold for matching objects.
+        return_error_masks: Whether to also return binary error masks.
+        remove_boundary_objects: Whether to remove objects that touch the image boundary.
+ + Returns: + A dictionary with keys: + - 'tp', 'fp', 'fn': arrays of shape (B, C) + - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, C, H, W) + """ + batch_ground_truth = _ensure_ndim(batch_ground_truth, 4, insert_position=0) + batch_prediction = _ensure_ndim(batch_prediction, 4, insert_position=0) + + batch_size = batch_ground_truth.shape[0] + tp_list = [] + fp_list = [] + fn_list = [] + if return_error_masks: + tp_mask_list = [] + fp_mask_list = [] + fn_mask_list = [] + + for i in range(batch_size): + image_ground_truth = batch_ground_truth[i] + image_prediction = batch_prediction[i] + result = compute_segmentation_tp_fp_fn( + image_ground_truth, + image_prediction, + iou_threshold, + return_error_masks, + remove_boundary_objects + ) + tp_list.append(result['tp']) + fp_list.append(result['fp']) + fn_list.append(result['fn']) if return_error_masks: - true_positive_mask_list.append(results.get('tp_mask')) # type: ignore - false_positive_mask_list.append(results.get('fp_mask')) # type: ignore - false_negative_mask_list.append(results.get('fn_mask')) # type: ignore + tp_mask_list.append(result.get('tp_mask')) # type: ignore + fp_mask_list.append(result.get('fp_mask')) # type: ignore + fn_mask_list.append(result.get('fn_mask')) # type: ignore output: Dict[str, np.ndarray] = { - 'avg_precision': np.array(avg_precision_list), - 'tp': np.array(true_positive_list), - 'fp': np.array(false_positive_list), - 'fn': np.array(false_negative_list) + 'tp': np.stack(tp_list, axis=0), + 'fp': np.stack(fp_list, axis=0), + 'fn': np.stack(fn_list, axis=0) } if return_error_masks: - output['tp_mask'] = np.stack(true_positive_mask_list, axis=-1) # type: ignore - output['fp_mask'] = np.stack(false_positive_mask_list, axis=-1) # type: ignore - output['fn_mask'] = np.stack(false_negative_mask_list, axis=-1) # type: ignore + output['tp_mask'] = np.stack(tp_mask_list, axis=0) # type: ignore + output['fp_mask'] = np.stack(fp_mask_list, axis=0) # type: ignore + output['fn_mask'] = np.stack(fn_mask_list, axis=0) # type: ignore return output @@ -292,9 +380,9 @@ def compute_batch_segmentation_f1_metrics( """ Computes segmentation F1 metrics for a batch of images. - Expects inputs with shape (B, C, H, W). For each image in the batch, the data is transposed - to (H, W, C) and then processed with compute_segmentation_f1_metrics. The results are stacked - so that each metric has shape (B, C). If error masks are returned, their shape will be (B, H, W, C). + Expects inputs with shape (B, C, H, W). For each image in the batch, the data is extracted + into (C, H, W) and then processed using compute_segmentation_f1_metrics. The results are stacked + so that each metric has shape (B, C). If error masks are returned, their shape will be (B, C, H, W). Args: batch_ground_truth: Batch of ground truth masks (BxCxHxW). 
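
The new compute_batch_segmentation_tp_fp_fn entry point can be exercised on toy data; a minimal sketch, assuming instance matching behaves as documented (remove_boundary_objects=True by default):

import numpy as np
from core.utils import compute_batch_segmentation_tp_fp_fn

gt = np.zeros((2, 1, 64, 64), dtype=np.uint16)
pred = np.zeros((2, 1, 64, 64), dtype=np.uint16)
gt[0, 0, 10:20, 10:20] = 1    # one ground-truth instance, away from the border
pred[0, 0, 11:21, 11:21] = 1  # overlapping prediction, IoU = 81/119 ≈ 0.68
stats = compute_batch_segmentation_tp_fp_fn(gt, pred, iou_threshold=0.5)
# stats["tp"], stats["fp"], stats["fn"] each have shape (2, 1);
# image 0 gives tp=1, fp=0, fn=0 at IoU 0.5; the empty image 1 gives zeros.
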
@@ -306,10 +394,10 @@ def compute_batch_segmentation_f1_metrics( Returns: A dictionary with keys: - 'precision', 'recall', 'f1_score', 'tp', 'fp', 'fn': arrays of shape (B, C) - - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, H, W, C) + - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, C, H, W) """ - batch_ground_truth = _ensure_ndim(batch_ground_truth, 4) - batch_prediction = _ensure_ndim(batch_prediction, 4) + batch_ground_truth = _ensure_ndim(batch_ground_truth, 4, insert_position=0) + batch_prediction = _ensure_ndim(batch_prediction, 4, insert_position=0) batch_size = batch_ground_truth.shape[0] precision_list = [] @@ -324,9 +412,8 @@ def compute_batch_segmentation_f1_metrics( fn_mask_list = [] for i in range(batch_size): - # Each image is expected to have shape (C, H, W); transpose to (H, W, C) - image_ground_truth = np.transpose(batch_ground_truth[i], (1, 2, 0)) - image_prediction = np.transpose(batch_prediction[i], (1, 2, 0)) + image_ground_truth = batch_ground_truth[i] + image_prediction = batch_prediction[i] result = compute_segmentation_f1_metrics( image_ground_truth, image_prediction, @@ -370,9 +457,9 @@ def compute_batch_segmentation_average_precision_metrics( """ Computes segmentation average precision metrics for a batch of images. - Expects inputs with shape (B, C, H, W). For each image in the batch, the data is transposed - to (H, W, C) and then processed with compute_segmentation_average_precision_metrics. The results are stacked - so that each metric has shape (B, C). If error masks are returned, their shape will be (B, H, W, C). + Expects inputs with shape (B, C, H, W). For each image in the batch, the data is extracted + into (C, H, W) and then processed using compute_segmentation_average_precision_metrics. The results are stacked + so that each metric has shape (B, C). If error masks are returned, their shape will be (B, C, H, W). Args: batch_ground_truth: Batch of ground truth masks (BxCxHxW). @@ -384,10 +471,10 @@ def compute_batch_segmentation_average_precision_metrics( Returns: A dictionary with keys: - 'avg_precision', 'tp', 'fp', 'fn': arrays of shape (B, C) - - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, H, W, C) + - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask': arrays of shape (B, C, H, W) """ - batch_ground_truth = _ensure_ndim(batch_ground_truth, 4) - batch_prediction = _ensure_ndim(batch_prediction, 4) + batch_ground_truth = _ensure_ndim(batch_ground_truth, 4, insert_position=0) + batch_prediction = _ensure_ndim(batch_prediction, 4, insert_position=0) batch_size = batch_ground_truth.shape[0] avg_precision_list = [] @@ -400,10 +487,14 @@ def compute_batch_segmentation_average_precision_metrics( fn_mask_list = [] for i in range(batch_size): - ground_truth_mask = np.transpose(batch_ground_truth[i], (1, 2, 0)) - prediction_mask = np.transpose(batch_prediction[i], (1, 2, 0)) + ground_truth_mask = batch_ground_truth[i] + prediction_mask = batch_prediction[i] result = compute_segmentation_average_precision_metrics( - ground_truth_mask, prediction_mask, iou_threshold, return_error_masks, remove_boundary_objects + ground_truth_mask, + prediction_mask, + iou_threshold, + return_error_masks, + remove_boundary_objects ) avg_precision_list.append(result['avg_precision']) tp_list.append(result['tp'])
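
Downstream, CellSegmentator.__run_epoch stacks these per-batch (B, C) count arrays and reduces them to scalar metrics. A sketch of the "micro" reduction used for the F1 score (same convention as __compute_f1_metric, assuming the standard F1 definition):

import numpy as np
from core.utils import compute_f1_score

# Per-batch (B, C) count arrays, as accumulated over an epoch.
all_tp = [np.array([[3], [1]]), np.array([[2], [0]])]
all_fp = [np.array([[1], [0]]), np.array([[0], [1]])]
all_fn = [np.array([[0], [2]]), np.array([[1], [0]])]

tp = int(np.vstack(all_tp).sum())  # 6
fp = int(np.vstack(all_fp).sum())  # 2
fn = int(np.vstack(all_fn).sum())  # 3
_, _, f1 = compute_f1_score(tp, fp, fn)
# micro F1 = 2*6 / (2*6 + 2 + 3) ≈ 0.706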