"""
Instance-segmentation evaluation metrics (F1 score and average precision).

Adapted from:
[1] https://github.com/JunMa11/NeurIPS-CellSeg/blob/main/baseline/compute_metric.py
[2] https://github.com/stardist/stardist/blob/master/stardist/matching.py
"""

import numpy as np
from numpy.typing import NDArray
from scipy.optimize import linear_sum_assignment
from typing import Dict, List, Tuple, Any, Union

# numba and scikit-image are optional accelerators. When they are missing we
# fall back to behavior-identical pure-NumPy implementations so the module
# stays importable (e.g. in lightweight evaluation environments).
try:
    from numba import jit
except ImportError:  # pragma: no cover - optional dependency
    def jit(*_args, **_kwargs):
        """No-op stand-in for numba.jit: returns the function unchanged."""
        def decorator(func):
            return func
        return decorator

try:
    from skimage import segmentation
except ImportError:  # pragma: no cover - optional dependency
    segmentation = None

__all__ = [
    "compute_batch_segmentation_f1_metrics", "compute_batch_segmentation_average_precision_metrics",
    "compute_segmentation_f1_metrics", "compute_segmentation_average_precision_metrics",
    "compute_confusion_matrix", "compute_f1_scores", "compute_average_precision_score"
]


def compute_f1_scores(
    true_positives: int,
    false_positives: int,
    false_negatives: int
) -> Tuple[float, float, float]:
    """
    Computes the precision, recall, and F1-score given the numbers of
    true positives, false positives, and false negatives.

    Args:
        true_positives: Number of true positive detections.
        false_positives: Number of false positive detections.
        false_negatives: Number of false negative detections.

    Returns:
        A tuple (precision, recall, f1_score). All three are 0.0 when there
        are no true positives (avoids division by zero).
    """
    if true_positives == 0:
        return 0.0, 0.0, 0.0
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)
    f1_score = 2 * (precision * recall) / (precision + recall)
    return precision, recall, f1_score


def compute_average_precision_score(
    true_positives: int,
    false_positives: int,
    false_negatives: int
) -> float:
    """
    Computes the average precision score using the formula:

        Average Precision = TP / (TP + FP + FN)

    If the denominator is zero, returns 0.

    Args:
        true_positives: Number of true positive detections.
        false_positives: Number of false positive detections.
        false_negatives: Number of false negative detections.

    Returns:
        Average precision score as a float.
    """
    denominator = true_positives + false_positives + false_negatives
    return 0.0 if denominator == 0 else true_positives / denominator


def compute_confusion_matrix(
    ground_truth_mask: np.ndarray,
    predicted_mask: np.ndarray,
    iou_threshold: float = 0.5
) -> Tuple[int, int, int]:
    """
    Computes the confusion matrix elements (true positives, false positives,
    false negatives) for a single image given the ground truth and predicted
    instance masks (0 = background, positive integers = object labels).

    Args:
        ground_truth_mask: Ground truth segmentation mask.
        predicted_mask: Predicted segmentation mask.
        iou_threshold: IoU threshold for matching objects.

    Returns:
        A tuple (TP, FP, FN).
    """
    # Labels are assumed to be sequential, so the max label is the count.
    num_ground_truth = np.max(ground_truth_mask)
    num_predictions = np.max(predicted_mask)

    # With no predictions there can be no matches: every ground truth object
    # is a false negative (previously this incorrectly reported FN = 0).
    if num_predictions == 0:
        print("No segmentation results!")
        return 0, 0, num_ground_truth

    # Compute the IoU matrix; background (row/column 0) is excluded inside
    # the matching step.
    iou_matrix = _calculate_iou(ground_truth_mask, predicted_mask)
    # Count true positives based on optimal (Hungarian) matching.
    true_positive_count = _calculate_true_positive(iou_matrix, iou_threshold)
    # Derive false positives and false negatives from the object counts.
    false_positive_count = num_predictions - true_positive_count
    false_negative_count = num_ground_truth - true_positive_count
    return true_positive_count, false_positive_count, false_negative_count


def compute_segmentation_f1_metrics(
    ground_truth_mask: np.ndarray,
    predicted_mask: np.ndarray,
    iou_threshold: float = 0.5,
    return_error_masks: bool = False,
    remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
    """
    Computes F1 metrics (precision, recall, F1-score) for segmentation on a
    single image.

    If the input masks have shape (H, W), they are expanded to (H, W, 1).
    For multi-channel inputs (H, W, C), each channel is processed
    independently, and the returned metrics (precision, recall, f1_score,
    TP, FP, FN) are provided as NumPy arrays with shape (C,).

    Optionally, if return_error_masks is True, binary error masks for true
    positives, false positives, and false negatives are also returned with
    shape (H, W, C).

    Args:
        ground_truth_mask: Ground truth segmentation mask (HxW or HxWxC).
        predicted_mask: Predicted segmentation mask (HxW or HxWxC).
        iou_threshold: IoU threshold for matching objects.
        return_error_masks: Whether to also return binary error masks.
        remove_boundary_objects: Whether to remove objects that touch the image boundary.

    Returns:
        A dictionary with the following keys:
            - 'precision', 'recall', 'f1_score': arrays of shape (C,)
            - 'tp', 'fp', 'fn': arrays of shape (C,)
            - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask'
              with shape (H, W, C)
    """
    # If the masks are 2D, add a singleton channel dimension at the end.
    ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=-1)
    predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=-1)

    num_channels = ground_truth_mask.shape[-1]
    precision_list = []
    recall_list = []
    f1_score_list = []
    true_positive_list = []
    false_positive_list = []
    false_negative_list = []
    if return_error_masks:
        true_positive_mask_list = []
        false_positive_mask_list = []
        false_negative_mask_list = []

    # Process each channel independently.
    for channel in range(num_channels):
        channel_ground_truth = ground_truth_mask[..., channel]
        channel_prediction = predicted_mask[..., channel]
        # Very large images (>= 5000x5000 pixels) are evaluated patch-wise to
        # bound the size of the IoU/overlap matrices.
        if np.prod(channel_ground_truth.shape) < (5000 * 5000):
            results = _process_instance_matching(
                channel_ground_truth, channel_prediction, iou_threshold,
                return_masks=return_error_masks, without_boundary_objects=remove_boundary_objects
            )
        else:
            results = _compute_patch_based_metrics(
                channel_ground_truth, channel_prediction, iou_threshold,
                return_masks=return_error_masks, without_boundary_objects=remove_boundary_objects
            )
        tp = results['tp']
        fp = results['fp']
        fn = results['fn']
        precision, recall, f1_score = compute_f1_scores(
            tp, fp, fn  # type: ignore
        )
        precision_list.append(precision)
        recall_list.append(recall)
        f1_score_list.append(f1_score)
        true_positive_list.append(tp)
        false_positive_list.append(fp)
        false_negative_list.append(fn)
        if return_error_masks:
            true_positive_mask_list.append(results.get('tp_mask'))  # type: ignore
            false_positive_mask_list.append(results.get('fp_mask'))  # type: ignore
            false_negative_mask_list.append(results.get('fn_mask'))  # type: ignore

    output: Dict[str, np.ndarray] = {
        'precision': np.array(precision_list),
        'recall': np.array(recall_list),
        'f1_score': np.array(f1_score_list),
        'tp': np.array(true_positive_list),
        'fp': np.array(false_positive_list),
        'fn': np.array(false_negative_list)
    }
    if return_error_masks:
        output['tp_mask'] = np.stack(true_positive_mask_list, axis=-1)  # type: ignore
        output['fp_mask'] = np.stack(false_positive_mask_list, axis=-1)  # type: ignore
        output['fn_mask'] = np.stack(false_negative_mask_list, axis=-1)  # type: ignore
    return output


def compute_segmentation_average_precision_metrics(
    ground_truth_mask: np.ndarray,
    predicted_mask: np.ndarray,
    iou_threshold: float = 0.5,
    return_error_masks: bool = False,
    remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
    """
    Computes the average precision (AP) for segmentation on a single image.

    If the input masks have shape (H, W), they are expanded to (H, W, 1).
    For multi-channel inputs (H, W, C), each channel is processed
    independently and the returned metrics (average precision, TP, FP, FN)
    are provided as NumPy arrays with shape (C,).

    Optionally, if return_error_masks is True, binary error masks for true
    positives, false positives, and false negatives are also returned with
    shape (H, W, C).

    Args:
        ground_truth_mask: Ground truth segmentation mask (HxW or HxWxC).
        predicted_mask: Predicted segmentation mask (HxW or HxWxC).
        iou_threshold: IoU threshold for matching objects.
        return_error_masks: Whether to also return binary error masks.
        remove_boundary_objects: Whether to remove objects that touch the image boundary.

    Returns:
        A dictionary with the following keys:
            - 'avg_precision': array of shape (C,)
            - 'tp', 'fp', 'fn': arrays of shape (C,)
            - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask'
              with shape (H, W, C)
    """
    ground_truth_mask = _ensure_ndim(ground_truth_mask, 3, insert_position=-1)
    predicted_mask = _ensure_ndim(predicted_mask, 3, insert_position=-1)

    num_channels = ground_truth_mask.shape[-1]
    avg_precision_list = []
    true_positive_list = []
    false_positive_list = []
    false_negative_list = []
    if return_error_masks:
        true_positive_mask_list = []
        false_positive_mask_list = []
        false_negative_mask_list = []

    # Process each channel independently.
    for channel in range(num_channels):
        channel_ground_truth = ground_truth_mask[..., channel]
        channel_prediction = predicted_mask[..., channel]
        # Same size cutoff as the F1 variant: huge images go patch-wise.
        if np.prod(channel_ground_truth.shape) < (5000 * 5000):
            results = _process_instance_matching(
                channel_ground_truth, channel_prediction,
                iou_threshold,
                return_masks=return_error_masks, without_boundary_objects=remove_boundary_objects
            )
        else:
            results = _compute_patch_based_metrics(
                channel_ground_truth, channel_prediction,
                iou_threshold,
                return_masks=return_error_masks, without_boundary_objects=remove_boundary_objects
            )
        tp = results['tp']
        fp = results['fp']
        fn = results['fn']
        avg_precision = compute_average_precision_score(
            tp, fp, fn  # type: ignore
        )
        avg_precision_list.append(avg_precision)
        true_positive_list.append(tp)
        false_positive_list.append(fp)
        false_negative_list.append(fn)
        if return_error_masks:
            true_positive_mask_list.append(results.get('tp_mask'))  # type: ignore
            false_positive_mask_list.append(results.get('fp_mask'))  # type: ignore
            false_negative_mask_list.append(results.get('fn_mask'))  # type: ignore

    output: Dict[str, np.ndarray] = {
        'avg_precision': np.array(avg_precision_list),
        'tp': np.array(true_positive_list),
        'fp': np.array(false_positive_list),
        'fn': np.array(false_negative_list)
    }
    if return_error_masks:
        output['tp_mask'] = np.stack(true_positive_mask_list, axis=-1)  # type: ignore
        output['fp_mask'] = np.stack(false_positive_mask_list, axis=-1)  # type: ignore
        output['fn_mask'] = np.stack(false_negative_mask_list, axis=-1)  # type: ignore
    return output


def compute_batch_segmentation_f1_metrics(
    batch_ground_truth: np.ndarray,
    batch_prediction: np.ndarray,
    iou_threshold: float = 0.5,
    return_error_masks: bool = False,
    remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
    """
    Computes segmentation F1 metrics for a batch of images.

    Expects inputs with shape (B, C, H, W). For each image in the batch, the
    data is transposed to (H, W, C) and then processed with
    compute_segmentation_f1_metrics. The results are stacked so that each
    metric has shape (B, C). If error masks are returned, their shape will be
    (B, H, W, C).

    Args:
        batch_ground_truth: Batch of ground truth masks (BxCxHxW).
        batch_prediction: Batch of predicted masks (BxCxHxW).
        iou_threshold: IoU threshold for matching objects.
        return_error_masks: Whether to also return binary error masks.
        remove_boundary_objects: Whether to remove objects that touch the image boundary.

    Returns:
        A dictionary with keys:
            - 'precision', 'recall', 'f1_score', 'tp', 'fp', 'fn': arrays of shape (B, C)
            - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask':
              arrays of shape (B, H, W, C)
    """
    batch_ground_truth = _ensure_ndim(batch_ground_truth, 4)
    batch_prediction = _ensure_ndim(batch_prediction, 4)

    batch_size = batch_ground_truth.shape[0]
    precision_list = []
    recall_list = []
    f1_score_list = []
    tp_list = []
    fp_list = []
    fn_list = []
    if return_error_masks:
        tp_mask_list = []
        fp_mask_list = []
        fn_mask_list = []

    for i in range(batch_size):
        # Each image is expected to have shape (C, H, W); transpose to (H, W, C).
        image_ground_truth = np.transpose(batch_ground_truth[i], (1, 2, 0))
        image_prediction = np.transpose(batch_prediction[i], (1, 2, 0))
        result = compute_segmentation_f1_metrics(
            image_ground_truth,
            image_prediction,
            iou_threshold,
            return_error_masks,
            remove_boundary_objects
        )
        precision_list.append(result['precision'])
        recall_list.append(result['recall'])
        f1_score_list.append(result['f1_score'])
        tp_list.append(result['tp'])
        fp_list.append(result['fp'])
        fn_list.append(result['fn'])
        if return_error_masks:
            tp_mask_list.append(result.get('tp_mask'))  # type: ignore
            fp_mask_list.append(result.get('fp_mask'))  # type: ignore
            fn_mask_list.append(result.get('fn_mask'))  # type: ignore

    output: Dict[str, np.ndarray] = {
        'precision': np.stack(precision_list, axis=0),
        'recall': np.stack(recall_list, axis=0),
        'f1_score': np.stack(f1_score_list, axis=0),
        'tp': np.stack(tp_list, axis=0),
        'fp': np.stack(fp_list, axis=0),
        'fn': np.stack(fn_list, axis=0)
    }
    if return_error_masks:
        output['tp_mask'] = np.stack(tp_mask_list, axis=0)  # type: ignore
        output['fp_mask'] = np.stack(fp_mask_list, axis=0)  # type: ignore
        output['fn_mask'] = np.stack(fn_mask_list, axis=0)  # type: ignore
    return output


def compute_batch_segmentation_average_precision_metrics(
    batch_ground_truth: np.ndarray,
    batch_prediction: np.ndarray,
    iou_threshold: float = 0.5,
    return_error_masks: bool = False,
    remove_boundary_objects: bool = True
) -> Dict[str, NDArray]:
    """
    Computes segmentation average precision metrics for a batch of images.

    Expects inputs with shape (B, C, H, W). For each image in the batch, the
    data is transposed to (H, W, C) and then processed with
    compute_segmentation_average_precision_metrics. The results are stacked
    so that each metric has shape (B, C). If error masks are returned, their
    shape will be (B, H, W, C).

    Args:
        batch_ground_truth: Batch of ground truth masks (BxCxHxW).
        batch_prediction: Batch of predicted masks (BxCxHxW).
        iou_threshold: IoU threshold for matching objects.
        return_error_masks: Whether to also return binary error masks.
        remove_boundary_objects: Whether to remove objects that touch the image boundary.

    Returns:
        A dictionary with keys:
            - 'avg_precision', 'tp', 'fp', 'fn': arrays of shape (B, C)
            - If return_error_masks is True: 'tp_mask', 'fp_mask', 'fn_mask':
              arrays of shape (B, H, W, C)
    """
    batch_ground_truth = _ensure_ndim(batch_ground_truth, 4)
    batch_prediction = _ensure_ndim(batch_prediction, 4)

    batch_size = batch_ground_truth.shape[0]
    avg_precision_list = []
    tp_list = []
    fp_list = []
    fn_list = []
    if return_error_masks:
        tp_mask_list = []
        fp_mask_list = []
        fn_mask_list = []

    for i in range(batch_size):
        ground_truth_mask = np.transpose(batch_ground_truth[i], (1, 2, 0))
        prediction_mask = np.transpose(batch_prediction[i], (1, 2, 0))
        result = compute_segmentation_average_precision_metrics(
            ground_truth_mask, prediction_mask, iou_threshold, return_error_masks, remove_boundary_objects
        )
        avg_precision_list.append(result['avg_precision'])
        tp_list.append(result['tp'])
        fp_list.append(result['fp'])
        fn_list.append(result['fn'])
        if return_error_masks:
            tp_mask_list.append(result.get('tp_mask'))  # type: ignore
            fp_mask_list.append(result.get('fp_mask'))  # type: ignore
            fn_mask_list.append(result.get('fn_mask'))  # type: ignore

    output: Dict[str, NDArray] = {
        'avg_precision': np.stack(avg_precision_list, axis=0),
        'tp': np.stack(tp_list, axis=0),
        'fp': np.stack(fp_list, axis=0),
        'fn': np.stack(fn_list, axis=0)
    }
    if return_error_masks:
        output['tp_mask'] = np.stack(tp_mask_list, axis=0)  # type: ignore
        output['fp_mask'] = np.stack(fp_mask_list, axis=0)  # type: ignore
        output['fn_mask'] = np.stack(fn_mask_list, axis=0)  # type: ignore
    return output


# ===================== INTERNAL HELPER FUNCTIONS =====================

def _ensure_ndim(array: np.ndarray, target_ndim: int, insert_position: int = 0) -> np.ndarray:
    """
    Makes sure that the array has the right dimension by adding axes if necessary.

    Args:
        array: Input array.
        target_ndim: The expected number of axes.
        insert_position: Where to add axes (0 = front, -1 = back).

    Returns:
        An array with the desired number of dimensions.

    Raises:
        ValueError: If the array cannot be cast to target_ndim in a valid way
            (i.e. it already has more than target_ndim axes).
    """
    while array.ndim < target_ndim:
        array = np.expand_dims(array, axis=insert_position)

    if array.ndim != target_ndim:
        raise ValueError(
            f"Expected ndim={target_ndim}, but got ndim={array.ndim} and shape={array.shape}"
        )

    return array


def _process_instance_matching(
    ground_truth_mask: np.ndarray,
    predicted_mask: np.ndarray,
    iou_threshold: float = 0.5,
    return_masks: bool = False,
    without_boundary_objects: bool = True
) -> Dict[str, Union[int, NDArray[np.uint8]]]:
    """
    Processes instance matching on a full image by performing the following steps:
        - Optionally removes objects that touch the image boundary and reindexes the masks.
        - Computes the IoU matrix between instances (ignoring background).
        - Computes optimal matching via linear assignment based on the IoU matrix.

    If return_masks is True, binary error masks (TP, FP, FN) are also generated.

    Args:
        ground_truth_mask: Ground truth instance mask.
        predicted_mask: Predicted instance mask.
        iou_threshold: IoU threshold for matching.
        return_masks: Whether to generate binary error masks.
        without_boundary_objects: Whether to remove objects touching the image boundary.

    Returns:
        A dictionary with keys:
            - 'tp', 'fp', 'fn': integer counts.
            - If return_masks is True, also 'tp_mask', 'fp_mask', and 'fn_mask'.
    """
    # Optionally remove boundary objects (the astype call also makes a copy,
    # so the caller's arrays are never modified).
    if without_boundary_objects:
        processed_ground_truth = _remove_boundary_objects(ground_truth_mask.astype(np.int32))
        processed_prediction = _remove_boundary_objects(predicted_mask.astype(np.int32))
    else:
        processed_ground_truth = ground_truth_mask.astype(np.int32)
        processed_prediction = predicted_mask.astype(np.int32)

    num_ground_truth = np.max(processed_ground_truth)
    num_prediction = np.max(processed_prediction)

    # If no predictions are found, return with all ground truth as false negatives.
    if num_prediction == 0:
        print("No segmentation results!")
        result = {'tp': 0, 'fp': 0, 'fn': num_ground_truth}
        if return_masks:
            tp_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8)
            fp_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8)
            fn_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8)
            # Mark all ground truth objects as false negatives.
            fn_mask[ground_truth_mask > 0] = 1
            result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask})
        return result

    # Compute the IoU matrix for the processed masks.
    iou_matrix = _calculate_iou(processed_ground_truth, processed_prediction)
    # Compute optimal matching pairs using linear assignment.
    matching_pairs = _compute_optimal_matching_pairs(iou_matrix, iou_threshold)

    true_positive_count = len(matching_pairs)
    false_positive_count = num_prediction - true_positive_count
    false_negative_count = num_ground_truth - true_positive_count
    result = {'tp': true_positive_count, 'fp': false_positive_count, 'fn': false_negative_count}

    if return_masks:
        # Initialize binary masks for error visualization.
        tp_mask = np.zeros_like(processed_ground_truth, dtype=np.uint8)
        fp_mask = np.zeros_like(processed_ground_truth, dtype=np.uint8)
        fn_mask = np.zeros_like(processed_ground_truth, dtype=np.uint8)

        # Record which labels were matched.
        matched_ground_truth_labels = {gt for gt, _ in matching_pairs}
        matched_prediction_labels = {pred for _, pred in matching_pairs}

        # For each matching pair, mark the intersection as true positive.
        for gt_label, pred_label in matching_pairs:
            gt_region = (processed_ground_truth == gt_label)
            prediction_region = (processed_prediction == pred_label)
            intersection = gt_region & prediction_region
            tp_mask[intersection] = 1
            # Mark parts of the ground truth not in the intersection as false negatives.
            fn_mask[gt_region & ~intersection] = 1
            # Mark parts of the prediction not in the intersection as false positives.
            fp_mask[prediction_region & ~intersection] = 1

        # Mark entire regions for objects with no match.
        all_ground_truth_labels = set(np.unique(processed_ground_truth)) - {0}
        for gt_label in (all_ground_truth_labels - matched_ground_truth_labels):
            fn_mask[processed_ground_truth == gt_label] = 1

        all_prediction_labels = set(np.unique(processed_prediction)) - {0}
        for pred_label in (all_prediction_labels - matched_prediction_labels):
            fp_mask[processed_prediction == pred_label] = 1

        result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask})
    return result


def _compute_optimal_matching_pairs(iou_matrix: np.ndarray, iou_threshold: float) -> List[Any]:
    """
    Computes the optimal matching pairs between ground truth and predicted masks
    using the IoU matrix.

    Args:
        iou_matrix: The IoU matrix between ground truth and predicted masks
            (including the background row/column 0).
        iou_threshold: The IoU threshold for considering a valid match.

    Returns:
        A list of pairs (ground_truth_label, predicted_label) representing matched pairs.
    """
    # Exclude the background (first row and column).
    iou_without_background = iou_matrix[1:, 1:]

    if iou_without_background.size == 0:
        return []

    # Determine the number of possible matches.
    num_possible_matches = min(iou_without_background.shape[0], iou_without_background.shape[1])

    # Create a cost matrix where lower costs indicate better matches: pairs
    # above the threshold get a -1 bonus, and the (scaled) IoU itself breaks
    # ties in favor of higher overlap.
    cost_matrix = -(iou_without_background >= iou_threshold).astype(np.float64) - iou_without_background / (2 * num_possible_matches)

    # Solve the assignment problem using the Hungarian algorithm.
    matched_ground_truth_indices, matched_prediction_indices = linear_sum_assignment(cost_matrix)

    # Only accept matches that meet the IoU threshold.
    matched_pairs_arr = np.stack([matched_ground_truth_indices, matched_prediction_indices], axis=1)
    ious = iou_without_background[matched_ground_truth_indices, matched_prediction_indices]
    valid_mask = ious >= iou_threshold

    # Apply the filter and shift indices by +1 (background was dropped above).
    return (matched_pairs_arr[valid_mask] + 1).tolist()


def _compute_patch_based_metrics(
    ground_truth_mask: np.ndarray,
    predicted_mask: np.ndarray,
    iou_threshold: float = 0.5,
    return_masks: bool = False,
    without_boundary_objects: bool = True
) -> Dict[str, Union[int, NDArray[np.uint8]]]:
    """
    Computes segmentation metrics using a patch-based approach for very large images.

    The image is divided into fixed-size patches (2000x2000 pixels). For each
    patch, instance matching is performed and the statistics (TP, FP, FN) are
    accumulated. If error masks are requested, they are also collected and
    cropped to the original image size.

    Args:
        ground_truth_mask: Ground truth segmentation mask.
        predicted_mask: Predicted segmentation mask.
        iou_threshold: IoU threshold for matching objects.
        return_masks: Whether to generate binary error masks.
        without_boundary_objects: Whether to remove objects that touch the image boundary.

    Returns:
        A dictionary with accumulated 'tp', 'fp', 'fn'. If return_masks is True,
        also includes 'tp_mask', 'fp_mask', and 'fn_mask'.
    """
    H, W = ground_truth_mask.shape
    patch_size = 2000

    # Calculate number of patches needed in height and width (ceiling division).
    num_patches_height = H // patch_size + (H % patch_size != 0)
    num_patches_width = W // patch_size + (W % patch_size != 0)
    padded_height, padded_width = patch_size * num_patches_height, patch_size * num_patches_width

    # Create zero-padded images so every patch has the full size. Each padded
    # array keeps its own source dtype (the prediction previously inherited
    # the ground-truth dtype, which could truncate labels).
    padded_ground_truth = np.zeros((padded_height, padded_width), dtype=ground_truth_mask.dtype)
    padded_prediction = np.zeros((padded_height, padded_width), dtype=predicted_mask.dtype)
    padded_ground_truth[:H, :W] = ground_truth_mask
    padded_prediction[:H, :W] = predicted_mask

    total_tp, total_fp, total_fn = 0, 0, 0
    if return_masks:
        padded_tp_mask = np.zeros((padded_height, padded_width), dtype=np.uint8)
        padded_fp_mask = np.zeros((padded_height, padded_width), dtype=np.uint8)
        padded_fn_mask = np.zeros((padded_height, padded_width), dtype=np.uint8)

    # Loop over all patches.
    for i in range(num_patches_height):
        for j in range(num_patches_width):
            y_start, y_end = patch_size * i, patch_size * (i + 1)
            x_start, x_end = patch_size * j, patch_size * (j + 1)
            # Extract the patch from both ground truth and prediction.
            patch_ground_truth = padded_ground_truth[y_start:y_end, x_start:x_end]
            patch_prediction = padded_prediction[y_start:y_end, x_start:x_end]
            # Process the patch and accumulate the results.
            patch_results = _process_instance_matching(
                patch_ground_truth, patch_prediction, iou_threshold,
                return_masks=return_masks, without_boundary_objects=without_boundary_objects
            )
            total_tp += patch_results['tp']
            total_fp += patch_results['fp']
            total_fn += patch_results['fn']
            if return_masks:
                padded_tp_mask[y_start:y_end, x_start:x_end] = patch_results.get('tp_mask', 0)  # type: ignore
                padded_fp_mask[y_start:y_end, x_start:x_end] = patch_results.get('fp_mask', 0)  # type: ignore
                padded_fn_mask[y_start:y_end, x_start:x_end] = patch_results.get('fn_mask', 0)  # type: ignore

    results: Dict[str, Union[int, np.ndarray]] = {'tp': total_tp, 'fp': total_fp, 'fn': total_fn}
    if return_masks:
        # Crop the padded masks back to the original image size.
        results.update({
            'tp_mask': padded_tp_mask[:H, :W],  # type: ignore
            'fp_mask': padded_fp_mask[:H, :W],  # type: ignore
            'fn_mask': padded_fn_mask[:H, :W]  # type: ignore
        })
    return results


def _remove_boundary_objects(mask: np.ndarray) -> np.ndarray:
    """
    Removes objects that touch the image boundary and reindexes the mask.

    A border of 2 pixels is defined around the image; any object that touches
    this border is removed. The input array is not modified.

    Args:
        mask: Segmentation mask where 0 represents the background and positive
            integers represent object labels.

    Returns:
        A reindexed mask (sequential labels) with boundary objects removed.
    """
    H, W = mask.shape
    # Create a mask with a border (value 1 in border, 0 in interior).
    border_mask = np.ones((H, W), dtype=np.uint8)
    border_mask[2:H - 2, 2:W - 2] = 0
    # Multiply the mask with the border mask to identify boundary labels.
    border_labels = np.unique(mask * border_mask)

    # Remove objects (set to 0) that appear in the border; work on a copy so
    # the caller's array stays intact.
    mask = mask.copy()
    mask[np.isin(mask, border_labels[1:])] = 0

    # Reindex the mask so that labels are sequential (1..n).
    if segmentation is not None:
        new_mask, _, _ = segmentation.relabel_sequential(mask)
        return new_mask
    # Pure-NumPy fallback equivalent to skimage's relabel_sequential: map the
    # sorted surviving labels onto 1..n.
    remaining_labels = np.unique(mask)
    remaining_labels = remaining_labels[remaining_labels > 0]
    new_mask = np.zeros_like(mask)
    for new_label, old_label in enumerate(remaining_labels, start=1):
        new_mask[mask == old_label] = new_label
    return new_mask


def _calculate_true_positive(iou_matrix: np.ndarray, iou_threshold: float = 0.5) -> int:
    """
    Calculates the number of true positive instances based on the IoU matrix.

    Args:
        iou_matrix: IoU matrix between ground truth and predicted masks
            (background row/column included; excluded during matching).
        iou_threshold: IoU threshold for matching.

    Returns:
        The number of true positive matches.
    """
    matching_pairs = _compute_optimal_matching_pairs(iou_matrix, iou_threshold)
    return len(matching_pairs)


def _calculate_iou(ground_truth_mask: np.ndarray, predicted_mask: np.ndarray) -> NDArray[np.float32]:
    """
    Computes the Intersection over Union (IoU) matrix between ground truth and
    predicted masks.

    Args:
        ground_truth_mask: Ground truth mask with integer labels.
        predicted_mask: Predicted mask with integer labels.

    Returns:
        An IoU matrix of shape (num_ground_truth+1, num_prediction+1).
    """
    # Compute the pixel-overlap matrix between the two masks.
    overlap_matrix = _calculate_label_overlap(ground_truth_mask, predicted_mask)

    # Total number of pixels in each predicted object (sum over columns).
    pixels_per_prediction = np.sum(overlap_matrix, axis=0, keepdims=True)
    # Total number of pixels in each ground truth object (sum over rows).
    pixels_per_ground_truth = np.sum(overlap_matrix, axis=1, keepdims=True)

    # Compute the union for each pair: |A| + |B| - |A ∩ B|.
    union_matrix = pixels_per_prediction + pixels_per_ground_truth - overlap_matrix

    # Avoid division by zero for empty label pairs.
    iou = np.zeros_like(union_matrix, dtype=np.float32)
    valid = union_matrix > 0
    iou[valid] = overlap_matrix[valid] / union_matrix[valid]
    return iou


@jit(nopython=True)
def _calculate_label_overlap(mask_x: np.ndarray, mask_y: np.ndarray) -> NDArray[np.uint32]:
    """
    Computes the overlap (number of common pixels) between labels in two masks.

    Args:
        mask_x: First mask (integer labels with 0 as background).
        mask_y: Second mask (integer labels with 0 as background).

    Returns:
        An overlap matrix of shape [mask_x.max()+1, mask_y.max()+1].
    """
    flat_x = mask_x.ravel()
    flat_y = mask_y.ravel()

    # Create an empty overlap matrix sized by the maximum label in each mask.
    overlap = np.zeros((1 + flat_x.max(), 1 + flat_y.max()), dtype=np.uint32)

    # Count co-occurring label pairs pixel by pixel.
    for i in range(flat_x.shape[0]):
        overlap[flat_x[i], flat_y[i]] += 1
    return overlap