From 28f978956ca499470a3acdb01094c59e6003dc0f Mon Sep 17 00:00:00 2001
From: laynholt
Date: Fri, 9 May 2025 17:09:48 +0000
Subject: [PATCH] Add per-class error mask creation; add options to save only
 prediction masks or to skip saving outputs entirely; add execution time
 reporting; move the pretrained_weights parameter to the common config;
 remove ensemble-related parameters.

---
 config/dataset_config.py            |  40 +-
 core/data/transforms/cell_aware.py  |   2 +-
 core/data/transforms/random_crop.py | 193 ++++++++++
 core/segmentator.py                 | 549 +++++++++++++++++++++-------
 core/utils/measures.py              |   2 +-
 main.py                             |  45 ++-
 6 files changed, 647 insertions(+), 184 deletions(-)
 create mode 100644 core/data/transforms/random_crop.py

diff --git a/config/dataset_config.py b/config/dataset_config.py
index 9a43858..ef659bd 100644
--- a/config/dataset_config.py
+++ b/config/dataset_config.py
@@ -8,11 +8,12 @@ class DatasetCommonConfig(BaseModel):
     Common configuration fields shared by both training and testing.
     """
     seed: Optional[int] = 0  # Seed for splitting if data is not pre-split (and all random operations)
-    device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
+    device: str = "cuda:0"   # Device used for training/testing (e.g., 'cpu' or 'cuda')
     use_tta: bool = False    # Flag to use Test-Time Augmentation (TTA)
     use_amp: bool = False    # Flag to use Automatic Mixed Precision (AMP)
     masks_subdir: str = ""   # Subdirectory where the required masks are located, e.g. 'masks/cars'
     predictions_dir: str = "."  # Directory to save predictions
+    pretrained_weights: str = ""  # Path to pretrained weights

     @model_validator(mode="after")
     def validate_common(self) -> "DatasetCommonConfig":
@@ -39,7 +40,7 @@ class TrainingPreSplitInfo(BaseModel):
     Contains:
     - train_dir, valid_dir, test_dir: Directories for training, validation, and testing data.
     """
-    train_dir: str = "." # Directory for training data if data is pre-split
+    train_dir: str = "."  # Directory for training data if data is pre-split
     valid_dir: str = ""   # Directory for validation data if data is pre-split
     test_dir: str = ""    # Directory for testing data if data is pre-split

@@ -72,7 +73,6 @@ class DatasetTrainingConfig(BaseModel):
     batch_size: int = 1    # Batch size for training
     num_epochs: int = 100  # Number of training epochs
     val_freq: int = 1      # Frequency of validation during training
-    pretrained_weights: str = ""  # Path to pretrained weights for training


     @field_validator("train_size", "valid_size", "test_size", mode="before")
@@ -137,15 +137,6 @@
             raise ValueError("offsets must be >= 0")
         return self

-    @model_validator(mode="after")
-    def validate_pretrained(self) -> "DatasetTrainingConfig":
-        """
-        Validates that pretrained_weights is provided and exists.
-        """
-        if self.pretrained_weights and not os.path.exists(self.pretrained_weights):
-            raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
-        return self
-

 class DatasetTestingConfig(BaseModel):
     """
@@ -155,11 +146,6 @@
     test_size: Union[int, float] = 1.0  # Testing data size (int for static, float in (0,1] for dynamic)
     test_offset: int = 0  # Offset for testing data
     shuffle: bool = True  # Shuffle data
-
-    use_ensemble: bool = False  # Flag to use ensemble mode in testing
-    ensemble_pretrained_weights1: str = "."
-    ensemble_pretrained_weights2: str = "."
-    pretrained_weights: str = "."
     @field_validator("test_size", mode="before")
     def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]:
@@ -191,25 +177,11 @@
         """
         Validates the testing configuration:
         - test_dir must be non-empty and exist.
-        - If use_ensemble is True, both ensemble_pretrained_weights1 and ensemble_pretrained_weights2 must be provided and exist.
-        - If use_ensemble is False, pretrained_weights must be provided and exist.
         """
         if not self.test_dir:
             raise ValueError("In testing configuration, test_dir must be provided and non-empty")
         if not os.path.exists(self.test_dir):
             raise ValueError(f"Path for test_dir does not exist: {self.test_dir}")
-        if self.use_ensemble:
-            for field in ["ensemble_pretrained_weights1", "ensemble_pretrained_weights2"]:
-                value = getattr(self, field)
-                if not value:
-                    raise ValueError(f"When use_ensemble is True, {field} must be provided and non-empty")
-                if not os.path.exists(value):
-                    raise ValueError(f"Path for {field} does not exist: {value}")
-        else:
-            if not self.pretrained_weights:
-                raise ValueError("When use_ensemble is False, pretrained_weights must be provided and non-empty")
-            if not os.path.exists(self.pretrained_weights):
-                raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
         if self.test_offset < 0:
             raise ValueError("test_offset must be >= 0")
         return self
@@ -237,11 +209,17 @@
         if (self.training.is_split and self.training.pre_split.test_dir) or (not self.training.is_split):
             if self.training.test_size > 0 and not self.common.predictions_dir:
                 raise ValueError("predictions_dir must be provided when test_size is non-zero")
+            if self.common.pretrained_weights and not os.path.exists(self.common.pretrained_weights):
+                raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}")
         else:
             if self.testing is None:
                 raise ValueError("Testing configuration must be provided when is_training is False")
             if self.testing.test_size > 0 and not self.common.predictions_dir:
                 raise ValueError("predictions_dir must be provided when test_size is non-zero")
+            if not self.common.pretrained_weights:
+                raise ValueError("When testing, pretrained_weights must be provided and non-empty")
+            if not os.path.exists(self.common.pretrained_weights):
+                raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}")
         return self

     def model_dump(self, **kwargs) -> Dict[str, Any]:
diff --git a/core/data/transforms/cell_aware.py b/core/data/transforms/cell_aware.py
index 21c606e..d0d14d6 100644
--- a/core/data/transforms/cell_aware.py
+++ b/core/data/transforms/cell_aware.py
@@ -10,7 +10,7 @@ from core.logger import get_logger

 __all__ = ["BoundaryExclusion", "IntensityDiversification"]

-logger = get_logger("cell_aware")
+logger = get_logger(__name__)


 class BoundaryExclusion(MapTransform):
diff --git a/core/data/transforms/random_crop.py b/core/data/transforms/random_crop.py
new file mode 100644
index 0000000..634b864
--- /dev/null
+++ b/core/data/transforms/random_crop.py
@@ -0,0 +1,193 @@
+import torch
+import numpy as np
+from typing import Hashable, List, Sequence, Optional, Tuple
+
+from monai.utils.misc import fall_back_tuple
+from monai.data.meta_tensor import MetaTensor
+from monai.data.utils import get_random_patch, get_valid_patch_size
+from monai.transforms import Randomizable, RandCropd, Crop  # type: ignore
+
+from core.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+def _compute_multilabel_bbox(
+    mask: np.ndarray
+) -> Optional[Tuple[List[int], List[int], List[int], List[int]]]:
+    """
+    Compute per-channel bounding-box constraints and return lists of limits for each axis.
+
+    Args:
+        mask: multi-channel instance mask of shape (C, H, W).
+
+    Returns:
+        A tuple of four lists, collected over the non-empty channels:
+            - top_mins: r_max per channel (these bound the crop top edge from below)
+            - top_maxs: r_min per channel (these bound the crop top edge from above)
+            - left_mins: c_max per channel (these bound the crop left edge from below)
+            - left_maxs: c_min per channel (these bound the crop left edge from above)
+        Or None if the mask contains no positive labels.
+    """
+    channels, rows, cols = np.nonzero(mask)
+    if channels.size == 0:
+        return None
+
+    top_mins: List[int] = []
+    top_maxs: List[int] = []
+    left_mins: List[int] = []
+    left_maxs: List[int] = []
+    C = mask.shape[0]
+    for ch in range(C):
+        rs, cs = np.nonzero(mask[ch])
+        if rs.size == 0:
+            continue
+        r_min, r_max = int(rs.min()), int(rs.max())
+        c_min, c_max = int(cs.min()), int(cs.max())
+        # For each channel, record the row/col extents
+        top_mins.append(r_max)
+        top_maxs.append(r_min)
+        left_mins.append(c_max)
+        left_maxs.append(c_min)
+
+    return top_mins, top_maxs, left_mins, left_maxs
+
+
+class SpatialCropAllClasses(Randomizable, Crop):
+    """
+    Cropper for multi-label instance masks and images: ensures each label-channel's
+    instances lie within the crop if possible.
+
+    Must be called on a mask tensor first to compute the crop, then on the image.
+
+    Args:
+        roi_size: desired crop size (height, width).
+        num_candidates: fallback samples when no single crop fits all instances.
+        lazy: defer actual cropping.
+    """
+    def __init__(
+        self,
+        roi_size: Sequence[int],
+        num_candidates: int = 10,
+        lazy: bool = False,
+    ) -> None:
+        super().__init__(lazy=lazy)
+        self.roi_size = tuple(roi_size)
+        self.num_candidates = num_candidates
+        self._slices: Optional[Tuple[slice, ...]] = None
+
+    def randomize(self, img_size: Sequence[int]) -> None:  # type: ignore
+        """
+        Choose crop offsets so that each non-empty channel is included if possible.
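+
+        A worked example with hypothetical numbers: for img_size=(100, 100) and
+        roi_size=(64, 64), a channel whose instances span rows 10..80 requires
+        top <= 10 (to keep the top of its bbox) and top >= 80 - 64 + 1 = 17
+        (to keep the bottom); that range is empty, so the candidate-sampling
+        fallback below is used instead.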
+ """ + height, width = img_size + crop_h, crop_w = self.roi_size + max_top = max(0, height - crop_h) + max_left = max(0, width - crop_w) + + # Compute per-channel bbox constraints + mask = self._img + bboxes = _compute_multilabel_bbox(mask) + if bboxes is None: + # no labels: random patch using MONAI utils + logger.warning("No labels found; using random patch.") + # determine actual patch size (fallback) + self._size = fall_back_tuple(self.roi_size, img_size) + # compute valid size for random patch + valid_size = get_valid_patch_size(img_size, self._size) + # directly get random patch slices + self._slices = get_random_patch(img_size, valid_size, self.R) + return + else: + top_mins, top_maxs, left_mins, left_maxs = bboxes + # Convert to allowable windows + # top_min_global = max(r_max - crop_h +1 for each channel) + global_top_min = max(0, max(r_max - crop_h + 1 for r_max in top_mins)) + # top_max_global = min(r_min for each channel) + global_top_max = min(min(top_maxs), max_top) + # same for left + global_left_min = max(0, max(c_max - crop_w + 1 for c_max in left_mins)) + global_left_max = min(min(left_maxs), max_left) + + if global_top_min <= global_top_max and global_left_min <= global_left_max: + # there is a window covering all channels fully + top = self.R.randint(global_top_min, global_top_max + 1) + left = self.R.randint(global_left_min, global_left_max + 1) + else: + # fallback: sample candidates to maximize channel coverage + logger.warning( + f"Cannot fit all instances; sampling {self.num_candidates} candidates." + ) + best_cover = -1 + best_top = best_left = 0 + C = mask.shape[0] + for _ in range(self.num_candidates): + cand_top = self.R.randint(0, max_top + 1) + cand_left = self.R.randint(0, max_left + 1) + window = mask[:, cand_top : cand_top + crop_h, cand_left : cand_left + crop_w] + cover = sum(int(window[ch].any()) for ch in range(C)) + if cover > best_cover: + best_cover = cover + best_top, best_left = cand_top, cand_left + logger.info(f"Selected crop covering {best_cover}/{C} channels.") + top, left = best_top, best_left + + # store slices for use on both mask and image + self._slices = ( + slice(None), + slice(top, top + crop_h), + slice(left, left + crop_w), + ) + + def __call__(self, img: torch.Tensor, lazy: Optional[bool] = None) -> torch.Tensor: # type: ignore + """ + On first call (mask), computes crop. On subsequent (image), just applies. + Raises if mask not provided first. + """ + # Determine tensor shape + img_size = ( + img.peek_pending_shape()[1:] + if isinstance(img, MetaTensor) + else img.shape[1:] + ) + # First call must be mask to compute slices + if self._slices is None: + if not torch.is_floating_point(img) and img.dtype in (torch.uint8, torch.int16, torch.int32, torch.int64): + # assume integer mask + self._img = img.cpu().numpy() + self.randomize(img_size) + else: + raise RuntimeError( + "Mask tensor must be passed first for computing crop bounds." + ) + # Now apply stored slice + if self._slices is None: + raise RuntimeError("Crop slices not computed; call on mask first.") + lazy_exec = self.lazy if lazy is None else lazy + return super().__call__(img=img, slices=self._slices, lazy=lazy_exec) + + +class RandSpatialCropAllClassesd(RandCropd): + """ + Dict-based wrapper: applies SpatialCropAllClasses to mask then image. + Requires mask present or raises. 
+ """ + def __init__( + self, + keys: Sequence, + roi_size: Sequence[int], + num_candidates: int = 10, + allow_missing_keys: bool = False, + lazy: bool = False, + ): + cropper = SpatialCropAllClasses( + roi_size=roi_size, + num_candidates=num_candidates, + lazy=lazy, + ) + super().__init__( + keys=keys, + cropper=cropper, + allow_missing_keys=allow_missing_keys, + lazy=lazy, + ) diff --git a/core/segmentator.py b/core/segmentator.py index e76c7d2..eda58f0 100644 --- a/core/segmentator.py +++ b/core/segmentator.py @@ -1,3 +1,4 @@ +import time import random import numpy as np from numba import njit, prange @@ -69,6 +70,8 @@ class CellSegmentator: self._valid_dataloader: Optional[DataLoader] = None self._test_dataloader: Optional[DataLoader] = None self._predict_dataloader: Optional[DataLoader] = None + + self._best_weights = None def create_dataloaders( @@ -315,9 +318,14 @@ class CellSegmentator: logger.info(l) - def train(self) -> None: + def train(self, save_results: bool = True, only_masks: bool = False) -> None: """ Train the model over multiple epochs, including validation and test. + + Args: + save_results (bool): If True, the predicted masks and test metrics will be saved. + only_masks (bool): If True and save_results is True, only raw predicted masks are saved, + without visualization overlays. """ # Ensure training is enabled in dataset setup if not self._dataset_setup.is_training: @@ -335,7 +343,6 @@ class CellSegmentator: logger.info(f"\n{'=' * 50}") best_f1_score = 0.0 - best_weights = None for epoch in range(1, self._dataset_setup.training.num_epochs + 1): train_metrics = self.__run_epoch("train", epoch) @@ -356,29 +363,38 @@ class CellSegmentator: if f1 > best_f1_score: best_f1_score = f1 # Deep copy weights to avoid reference issues - best_weights = copy.deepcopy(self._model.state_dict()) + self._best_weights = copy.deepcopy(self._model.state_dict()) logger.info(f"Updated best model weights with F1 score: {f1:.4f}") # Restore best model weights if available - if best_weights is not None: - self._model.load_state_dict(best_weights) + if self._best_weights is not None: + self._model.load_state_dict(self._best_weights) if self._test_dataloader is not None: - test_metrics = self.__run_epoch("test") + test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks) self.__print_with_logging(test_metrics, 0) - def evaluate(self) -> None: + def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None: """ Run a full test epoch and display/log the resulting metrics. + + Args: + save_results (bool): If True, the predicted masks and test metrics will be saved. + only_masks (bool): If True and save_results is True, only raw predicted masks are saved, + without visualization overlays. """ - test_metrics = self.__run_epoch("test") + test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks) self.__print_with_logging(test_metrics, 0) - def predict(self) -> None: + def predict(self, only_masks: bool = False) -> None: """ Run inference on the predict set and save the resulting instance masks. + + Args: + only_masks (bool): If True, only raw predicted masks are saved, + without visualization overlays. 
""" # Ensure the predict DataLoader has been set if self._predict_dataloader is None: @@ -404,43 +420,62 @@ class CellSegmentator: preds, _ = self.__post_process_predictions(raw_output) # Save out the predicted masks, using batch_counter to index files - self.__save_prediction_masks(batch, preds, batch_counter) + self.__save_prediction_masks( + sample=batch, + predicted_mask=preds, + start_index=batch_counter, + only_masks=only_masks + ) # Increment counter by batch size for unique file naming batch_counter += inputs.shape[0] - def run(self) -> None: + def run(self, save_results: bool = True, only_masks: bool = False) -> None: """ - Orchestrate the full workflow: + Orchestrate the full workflow and report execution time: - If training is enabled in the dataset setup, start training. - Otherwise, if a test DataLoader is provided, run evaluation. - Else if a prediction DataLoader is provided, run inference/prediction. - If neither loader is available in non‐training mode, raise an error. + + Args: + save_results (bool): If True, the predicted masks and test metrics will be saved. + only_masks (bool): If True and save_results is True, only raw predicted masks are saved, + without visualization overlays. """ + start_time = time.time() + # 1) TRAINING PATH if self._dataset_setup.is_training: # Launch the full training loop (with validation, scheduler steps, etc.) - self.train() - return - - # 2) NON-TRAINING PATH (TEST or PREDICT) - # Prefer test if available - if self._test_dataloader is not None: - # Run a single evaluation epoch on the test set and log metrics - self.evaluate() - return - - # If no test loader, fall back to prediction if available - if self._predict_dataloader is not None: - # Run inference on the predict set and save outputs - self.predict() - return + self.train(save_results=save_results, only_masks=only_masks) + else: + # 2) NON-TRAINING PATH (TEST or PREDICT) + if self._test_dataloader is not None: + # Run a single evaluation epoch on the test set and log metrics + self.evaluate(save_results=save_results, only_masks=only_masks) + elif self._predict_dataloader is not None: + # Run inference on the predict set and save outputs + self.predict(only_masks=only_masks) + else: + # 3) ERROR: no appropriate loader found + raise RuntimeError( + "Neither test nor predict DataLoader is set for non‐training mode." + ) - # 3) ERROR: no appropriate loader found - raise RuntimeError( - "Neither test nor predict DataLoader is set for non‐training mode." 
-        )
+
+        elapsed = time.time() - start_time
+        if elapsed < 60:
+            logger.info(f"Total execution time: {elapsed:.2f} seconds")
+        elif elapsed < 3600:
+            minutes = int(elapsed // 60)
+            seconds = elapsed % 60
+            logger.info(f"Total execution time: {minutes} min {seconds:.2f} sec")
+        else:
+            hours = int(elapsed // 3600)
+            minutes = int((elapsed % 3600) // 60)
+            seconds = elapsed % 60
+            logger.info(f"Total execution time: {hours} h {minutes} min {seconds:.2f} sec")


     def load_from_checkpoint(self, checkpoint_path: str) -> None:
@@ -490,7 +525,12 @@
         """
         # Write the checkpoint to disk
         os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
-        torch.save(self._model.state_dict(), checkpoint_path)
+        torch.save((
+            self._model.state_dict()
+            if self._best_weights is None
+            else self._best_weights),
+            checkpoint_path
+        )


     def __parse_config(self, config: Config) -> None:
@@ -523,11 +563,7 @@
         self._model = ModelRegistry.get_model_class(model.name)(model.params)

         # Loads model weights from a specified checkpoint
-        pretrained_weights = (
-            config.dataset_config.training.pretrained_weights
-            if config.dataset_config.is_training
-            else config.dataset_config.testing.pretrained_weights
-        )
+        pretrained_weights = config.dataset_config.common.pretrained_weights
         if pretrained_weights:
             self.load_from_checkpoint(pretrained_weights)
             logger.info(f"Loaded pre-trained weights from: {pretrained_weights}")
@@ -589,6 +625,7 @@
         logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
         logger.info(f"├─ Masks subdirectory: {common.masks_subdir}")
+        logger.info(f"├─ Pretrained weights: {common.pretrained_weights or 'None'}")
         logger.info(f"└─ Predictions output dir: {common.predictions_dir}")

         if config.dataset_config.is_training:
             training = config.dataset_config.training
@@ -596,7 +633,6 @@
             logger.info(f"├─ Batch size: {training.batch_size}")
             logger.info(f"├─ Epochs: {training.num_epochs}")
             logger.info(f"├─ Validation frequency: {training.val_freq}")
-            logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}")

             if training.is_split:
                 logger.info(f"├─ Using pre-split directories:")
@@ -619,11 +655,6 @@
             logger.info(f"├─ Test dir: {testing.test_dir}")
             logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})")
-            logger.info(f"├─ Shuffle: {'yes' if testing.shuffle else 'no'}")
-            logger.info(f"├─ Use ensemble: {'yes' if testing.use_ensemble else 'no'}")
-            logger.info(f"└─ Pretrained weights:")
-            logger.info(f"   ├─ Single model: {testing.pretrained_weights}")
-            logger.info(f"   ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
-            logger.info(f"   └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")
+            logger.info(f"└─ Shuffle: {'yes' if testing.shuffle else 'no'}")

         self._wandb_config = config.wandb_config
         if self._wandb_config.use_wandb:
@@ -747,38 +778,62 @@
         return Dataset(data, transforms)


-    def __print_with_logging(self, results: Dict[str, float], step: int) -> None:
+    def __print_with_logging(self, results: Dict[str, Union[float, np.ndarray]], step: int) -> None:
         """
         Print metrics in a tabular format and log to W&B.

         Args:
-            results (Dict[str, float]): results dictionary.
+            results (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
+                to either a float or an N-D numpy array.
             step (int): epoch index.
         """
+        rows: list[tuple[str, str]] = []
+        for key, val in results.items():
+            if isinstance(val, np.ndarray):
+                # Convert array to string, e.g.
'[0.2, 0.8, 0.5]' + val_str = np.array2string(val, separator=', ') + else: + # Format scalar with 4 decimal places + val_str = f"{val:.4f}" + rows.append((key, val_str)) + table = tabulate( - tabular_data=results.items(), + tabular_data=rows, headers=["Metric", "Value"], floatfmt=".4f", tablefmt="fancy_grid" ) print(table, "\n") + if self._wandb_config.use_wandb: - wandb.log(results, step=step) + # Keep only scalar values + scalar_results: dict[str, float] = {} + for key, val in results.items(): + if isinstance(val, np.ndarray): + continue + # Ensure float type + scalar_results[key] = float(val) + wandb.log(scalar_results, step=step) def __run_epoch(self, mode: Literal["train", "valid", "test"], - epoch: Optional[int] = None - ) -> Dict[str, float]: + epoch: Optional[int] = None, + save_results: bool = True, + only_masks: bool = False + ) -> Dict[str, Union[float, np.ndarray]]: """ Execute one epoch of training, validation, or testing. Args: mode (str): One of 'train', 'valid', or 'test'. epoch (int, optional): Current epoch number for logging. + save_results (bool): If True, the predicted masks and test metrics will be saved. + only_masks (bool): If True and save_results is True, only raw predicted masks are saved, + without visualization overlays. Returns: - Dict[str, float]: Loss metrics and F1 score for valid/test. + Dict[str, Union[float, np.ndarray]]: Metrics for valid/test. """ # Ensure required components are available if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None): @@ -841,18 +896,23 @@ class CellSegmentator: ) # Collecting statistics on the batch - tp, fp, fn = self.__compute_stats( + tp, fp, fn, tp_masks, fp_masks, fn_masks = self.__compute_stats( predicted_masks=preds, ground_truth_masks=labels_post, # type: ignore - iou_threshold=0.5 + iou_threshold=0.5, + return_error_masks=(mode == "test") ) all_tp.append(tp) all_fp.append(fp) all_fn.append(fn) - if mode == "test": + if mode == "test" and save_results is True: self.__save_prediction_masks( - batch, preds, batch_counter + sample=batch, + predicted_mask=preds, + start_index=batch_counter, + only_masks=only_masks, + masks=(tp_masks, fp_masks, fn_masks) # type: ignore ) # Backpropagation and optimizer step in training @@ -871,7 +931,9 @@ class CellSegmentator: if self._criterion is not None: # Collect loss metrics - epoch_metrics = {f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()} + epoch_metrics: Dict[str, Union[float, np.ndarray]] = { + f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items() + } # Reset internal loss metrics accumulator self._criterion.reset_metrics() else: @@ -884,13 +946,13 @@ class CellSegmentator: fp_array = np.vstack(all_fp) fn_array = np.vstack(all_fn) - epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( # type: ignore + epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( tp_array, fp_array, fn_array, reduction="micro" ) - epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric( # type: ignore + epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric( tp_array, fp_array, fn_array, reduction="imagewise" ) - epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( # type: ignore + epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( tp_array, fp_array, fn_array, reduction="macro" ) @@ -976,8 +1038,10 @@ class CellSegmentator: self, predicted_masks: np.ndarray, ground_truth_masks: np.ndarray, - iou_threshold: float = 0.5 - ) -> 
Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        iou_threshold: float = 0.5,
+        return_error_masks: bool = False
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray,
+               Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
         """
         Compute batch-wise true positives, false positives, and false negatives
         for instance segmentation, using a configurable IoU threshold.

         Args:
             predicted_masks (np.ndarray): Predicted instance masks of shape (B, C, H, W).
             ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W).
             iou_threshold (float): Intersection-over-Union threshold for matching predictions
                 to ground truths (default: 0.5).
+            return_error_masks (bool): Whether to also return binary error masks.

         Returns:
-            Tuple[np.ndarray, np.ndarray, np.ndarray]:
+            Tuple[np.ndarray, np.ndarray, np.ndarray,
+                  Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
                 - tp: True positives per batch and class, shape (B, C)
                 - fp: False positives per batch and class, shape (B, C)
                 - fn: False negatives per batch and class, shape (B, C)
+                - tp_masks: True positives mask per batch and class, shape (B, C, H, W)
+                - fp_masks: False positives mask per batch and class, shape (B, C, H, W)
+                - fn_masks: False negatives mask per batch and class, shape (B, C, H, W)
         """
         stats = compute_batch_segmentation_tp_fp_fn(
             batch_ground_truth=ground_truth_masks,
             batch_prediction=predicted_masks,
             iou_threshold=iou_threshold,
+            return_error_masks=return_error_masks,
             remove_boundary_objects=True
         )
         tp = stats["tp"]
         fp = stats["fp"]
         fn = stats["fn"]
-        return tp, fp, fn
+
+        tp_mask = stats["tp_mask"] if return_error_masks else None
+        fp_mask = stats["fp_mask"] if return_error_masks else None
+        fn_mask = stats["fn_mask"] if return_error_masks else None
+        return tp, fp, fn, tp_mask, fp_mask, fn_mask


     def __compute_f1_metric(
         self,
         true_positives: np.ndarray,
         false_positives: np.ndarray,
         false_negatives: np.ndarray,
-        reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
+        reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
     ) -> Union[float, np.ndarray]:
         """
         Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes.
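For intuition, the reductions aggregate raw TP/FP/FN counts and then compute a
single F1 per group, rather than averaging per-sample F1 scores. A minimal
standalone sketch with hypothetical counts (the implementation itself delegates
to compute_f1_score in core/utils/measures.py):

    import numpy as np

    # Hypothetical per-sample, per-class counts, shape (batch_size, num_classes)
    tp = np.array([[8, 2], [5, 0]])
    fp = np.array([[1, 1], [2, 3]])
    fn = np.array([[1, 0], [0, 2]])

    def f1(tp_, fp_, fn_):
        # F1 = 2*TP / (2*TP + FP + FN), defined as 0 when the denominator is 0
        denom = 2 * tp_ + fp_ + fn_
        return 2.0 * tp_ / denom if denom else 0.0

    # 'micro': one global score over all samples and classes
    micro = f1(int(tp.sum()), int(fp.sum()), int(fn.sum()))          # 0.75
    # 'per_class': aggregate over the batch first, one score per class
    per_class = [f1(int(tp[:, c].sum()), int(fp[:, c].sum()), int(fn[:, c].sum()))
                 for c in range(tp.shape[1])]                        # [0.867, 0.4]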
@@ -1023,8 +1097,9 @@ class CellSegmentator: reduction: - 'none': return F1 for each sample, class → shape (batch_size, num_classes) - 'micro': global F1 over all samples & classes - - 'imagewise': F1 per sample (summing over classes), then average over samples - 'macro': average class-wise F1 (classes summed over batch) + - 'imagewise': F1 per sample (summing over classes), then average over samples + - 'per_class': F1 per class (summing over batch), return vector of shape (num_classes,) - 'weighted': class-wise F1 weighted by support (TP+FN) Returns: float for reductions 'micro', 'imagewise', 'macro', 'weighted'; @@ -1040,7 +1115,11 @@ class CellSegmentator: tp_val = int(true_positives[i, c]) fp_val = int(false_positives[i, c]) fn_val = int(false_negatives[i, c]) - _, _, f1_val = compute_f1_score(tp_val, fp_val, fn_val) + _, _, f1_val = compute_f1_score( + tp_val, + fp_val, + fn_val + ) f1_matrix[i, c] = f1_val return f1_matrix @@ -1049,7 +1128,11 @@ class CellSegmentator: tp_total = int(true_positives.sum()) fp_total = int(false_positives.sum()) fn_total = int(false_negatives.sum()) - _, _, f1_global = compute_f1_score(tp_total, fp_total, fn_total) + _, _, f1_global = compute_f1_score( + tp_total, + fp_total, + fn_total + ) return f1_global # 3) Imagewise: compute per-sample F1 (sum over classes), then average @@ -1059,16 +1142,31 @@ class CellSegmentator: tp_i = int(true_positives[i].sum()) fp_i = int(false_positives[i].sum()) fn_i = int(false_negatives[i].sum()) - _, _, f1_i = compute_f1_score(tp_i, fp_i, fn_i) + _, _, f1_i = compute_f1_score( + tp_i, + fp_i, + fn_i + ) f1_per_image[i] = f1_i return float(f1_per_image.mean()) - # For macro/weighted, first aggregate per class across the batch + # Aggregate per class across the batch for per_class, macro, weighted tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,) fp_per_class = false_positives.sum(axis=0).astype(int) fn_per_class = false_negatives.sum(axis=0).astype(int) - # 4) Macro: average F1 across classes equally + # 4) Per-class: compute F1 for each class and return vector + if reduction == "per_class": + f1_per_class = np.zeros(num_classes, dtype=float) + for c in range(num_classes): + _, _, f1_per_class[c] = compute_f1_score( + tp_per_class[c], + fp_per_class[c], + fn_per_class[c] + ) + return f1_per_class + + # 5) Macro: average F1 across classes equally if reduction == "macro": f1_per_class = np.zeros(num_classes, dtype=float) for c in range(num_classes): @@ -1080,7 +1178,7 @@ class CellSegmentator: f1_per_class[c] = f1_c return float(f1_per_class.mean()) - # 5) Weighted: class-wise F1 weighted by support = TP + FN + # 6) Weighted: class-wise F1 weighted by support = TP + FN if reduction == "weighted": f1_per_class = np.zeros(num_classes, dtype=float) support = np.zeros(num_classes, dtype=float) @@ -1088,7 +1186,11 @@ class CellSegmentator: tp_c = tp_per_class[c] fp_c = fp_per_class[c] fn_c = fn_per_class[c] - _, _, f1_c = compute_f1_score(tp_c, fp_c, fn_c) + _, _, f1_c = compute_f1_score( + tp_c, + fp_c, + fn_c + ) f1_per_class[c] = f1_c support[c] = tp_c + fn_c total_support = support.sum() @@ -1106,7 +1208,7 @@ class CellSegmentator: true_positives: np.ndarray, false_positives: np.ndarray, false_negatives: np.ndarray, - reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro" + reduction: Literal["micro", "macro", "weighted", "imagewise", 'per_class', "none"] = "micro" ) -> Union[float, np.ndarray]: """ Compute Average Precision (AP) from batch-wise TP/FP/FN using various 
aggregation schemes.
@@ -1121,8 +1223,9 @@
             reduction:
                 - 'none': return AP for each sample and class → shape (batch_size, num_classes)
                 - 'micro': global AP over all samples & classes
-                - 'imagewise': AP per sample (summing stats over classes), then average over batch
                 - 'macro': average class-wise AP (each class summed over batch)
+                - 'imagewise': AP per sample (summing stats over classes), then average over batch
+                - 'per_class': AP per class (summing over batch), return vector of shape (num_classes,)
                 - 'weighted': class-wise AP weighted by support (TP+FN)

         Returns:
@@ -1139,7 +1242,11 @@
                     tp_val = int(true_positives[i, c])
                     fp_val = int(false_positives[i, c])
                     fn_val = int(false_negatives[i, c])
-                    ap_val = compute_average_precision_score(tp_val, fp_val, fn_val)
+                    ap_val = compute_average_precision_score(
+                        tp_val,
+                        fp_val,
+                        fn_val
+                    )
                     ap_matrix[i, c] = ap_val
             return ap_matrix

@@ -1148,7 +1255,11 @@
             tp_total = int(true_positives.sum())
             fp_total = int(false_positives.sum())
             fn_total = int(false_negatives.sum())
-            return compute_average_precision_score(tp_total, fp_total, fn_total)
+            return compute_average_precision_score(
+                tp_total,
+                fp_total,
+                fn_total
+            )

         # 3) Imagewise: compute per-sample AP (sum over classes), then mean
         if reduction == "imagewise":
@@ -1157,7 +1268,11 @@
                 tp_i = int(true_positives[i].sum())
                 fp_i = int(false_positives[i].sum())
                 fn_i = int(false_negatives[i].sum())
-                ap_per_image[i] = compute_average_precision_score(tp_i, fp_i, fn_i)
+                ap_per_image[i] = compute_average_precision_score(
+                    tp_i,
+                    fp_i,
+                    fn_i
+                )
             return float(ap_per_image.mean())

         # For macro and weighted: first aggregate per class across batch
         tp_per_class = true_positives.sum(axis=0).astype(int)
         fp_per_class = false_positives.sum(axis=0).astype(int)
         fn_per_class = false_negatives.sum(axis=0).astype(int)

-        # 4) Macro: average AP across classes equally
+        # 4) Per-class: compute AP for each class and return vector
+        if reduction == "per_class":
+            ap_per_class = np.zeros(num_classes, dtype=float)
+            for c in range(num_classes):
+                ap_per_class[c] = compute_average_precision_score(
+                    tp_per_class[c],
+                    fp_per_class[c],
+                    fn_per_class[c]
+                )
+            return ap_per_class
+
+        # 5) Macro: average AP across classes equally
         if reduction == "macro":
             ap_per_class = np.zeros(num_classes, dtype=float)
             for c in range(num_classes):
                 ap_per_class[c] = compute_average_precision_score(
                     tp_per_class[c],
                     fp_per_class[c],
                     fn_per_class[c]
                 )
             return float(ap_per_class.mean())

-        # 5) Weighted: class-wise AP weighted by support = TP + FN
+        # 6) Weighted: class-wise AP weighted by support = TP + FN
         if reduction == "weighted":
             ap_per_class = np.zeros(num_classes, dtype=float)
             support = np.zeros(num_classes, dtype=float)
             for c in range(num_classes):
                 tp_c = tp_per_class[c]
                 fp_c = fp_per_class[c]
                 fn_c = fn_per_class[c]
-                ap_per_class[c] = compute_average_precision_score(tp_c, fp_c, fn_c)
+                ap_per_class[c] = compute_average_precision_score(
+                    tp_c,
+                    fp_c,
+                    fn_c
+                )
                 support[c] = tp_c + fn_c
             total_support = support.sum()
             if total_support == 0:
@@ -1215,87 +1345,159 @@
         sample: Dict[str, Any],
         predicted_mask: Union[np.ndarray, torch.Tensor],
         start_index: int = 0,
+        only_masks: bool = False,
+        masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
     ) -> None:
         """
-        Save multi-channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders.
+        Save multi-channel predicted masks as TIFFs and
+        corresponding visualizations as PNGs in separate folders.

         Args:
-            sample (Dict[str, Any]): Batch sample from MONAI LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
+            sample (Dict[str, Any]): Batch sample from MONAI
+                LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
             predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
             start_index (int): Starting index for naming when metadata is missing.
+            only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations.
+            masks (Tuple[np.ndarray, np.ndarray, np.ndarray] | None):
+                A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None.
         """
-        # Determine base paths
+        # Base directories (created once per call)
         base_output_dir = self._dataset_setup.common.predictions_dir
         masks_dir = base_output_dir
         plots_dir = os.path.join(base_output_dir, "plots")
+        evaluate_dir = os.path.join(plots_dir, "evaluate")
         os.makedirs(masks_dir, exist_ok=True)
         os.makedirs(plots_dir, exist_ok=True)
+        os.makedirs(evaluate_dir, exist_ok=True)

-        # Extract image (C, H, W) or batch of images (B, C, H, W), and metadata
-        image_obj = sample.get("image")  # Expected shape: (C, H, W) or (B, C, H, W)
-        mask_obj = sample.get("mask")    # Expected shape: (C, H, W) or (B, C, H, W)
-        image_meta = sample.get("image_meta_dict")
-
-        # Convert tensors to numpy
+        # Convert tensors to numpy if necessary
         def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
-            if isinstance(x, torch.Tensor):
-                return x.cpu().numpy()
-            return x
-
-        image_array = to_numpy(image_obj) if image_obj is not None else None
-        mask_array = to_numpy(mask_obj) if mask_obj is not None else None
-        pred_array = to_numpy(predicted_mask)
-
-        # Handle batch dimension: (B, C, H, W)
-        if pred_array.ndim == 4:
-            for idx in range(pred_array.shape[0]):
-                batch_sample: Dict[str, Any] = {}
-                if image_array is not None and image_array.ndim == 4:
-                    batch_sample["image"] = image_array[idx]
-                if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
-                    batch_sample["image_meta_dict"] = image_meta["filename_or_obj"][idx]
-                if mask_array is not None and mask_array.ndim == 4:
-                    batch_sample["mask"] = mask_array[idx]
-                self.__save_prediction_masks(
-                    batch_sample,
-                    pred_array[idx],
-                    start_index=start_index+idx
-                )
-            return
+            return x.cpu().numpy() if isinstance(x, torch.Tensor) else x
+
+        pred_array = to_numpy(predicted_mask).astype(np.uint16)
+        # Ensure a batch dimension so a single (C, H, W) prediction is handled too
+        if pred_array.ndim == 3:
+            pred_array = pred_array[np.newaxis, ...]
+
+        # Handle batch dimension
+        for idx in range(pred_array.shape[0]):
+            batch_sample: Dict[str, Any] = {}
+            # copy per-sample image and meta
+            img = to_numpy(sample["image"])
+            batch_sample["image"] = img[idx] if img.ndim == 4 else img
+            if "mask" in sample:
+                msk = to_numpy(sample["mask"]).astype(np.uint16)
+                batch_sample["mask"] = msk[idx] if msk.ndim == 4 else msk
+
+            image_meta = sample.get("image_meta_dict")
+            if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
+                fname = image_meta["filename_or_obj"][idx]
+                batch_sample["image_name"] = fname
+
+            single_masks = (
+                (masks[0][idx], masks[1][idx], masks[2][idx]) if masks is not None else None
+            )
+            self.__save_single_prediction_mask(
+                sample=batch_sample,
+                pred_array=pred_array[idx],
+                start_index=start_index + idx,
+                masks_dir=masks_dir,
+                plots_dir=plots_dir,
+                evaluate_dir=evaluate_dir,
+                only_masks=only_masks,
+                masks=single_masks,
+            )


    def __save_single_prediction_mask(
        self,
        sample: Dict[str, Any],
        pred_array: np.ndarray,
+        start_index: int,
+        masks_dir: str,
+        plots_dir: str,
+        evaluate_dir: str,
+        only_masks: bool = False,
+        masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
+    ) -> None:
+        """
+        Save a single sample's predicted mask, plus optional TP/FP/FN error masks
+        and visualizations. Assumes output directories already exist.

+        Args:
+            sample (Dict[str, Any]): Dictionary containing 'image', 'mask',
+                and optional 'image_name' for filename metadata.
+            pred_array (np.ndarray): Predicted mask array of shape (C,H,W).
+            start_index (int): Base index for generating filenames when metadata is missing.
+            masks_dir (str): Directory for saving TIFF masks.
+            plots_dir (str): Directory for saving PNG visualizations.
+            evaluate_dir (str): Directory for saving PNG visualizations of evaluation results.
+            only_masks (bool): If True, saves only TIFF mask files; skips PNG plots.
+            masks (Tuple[np.ndarray, np.ndarray, np.ndarray], optional): A tuple of
+                true-positive, false-positive, and false-negative mask arrays,
+                each of shape (C,H,W). Defaults to None.
+        """
+        if pred_array.ndim == 2:
+            pred_array = np.expand_dims(pred_array, axis=0)
+        elif pred_array.ndim != 3:
+            raise ValueError(
+                f"Unsupported predicted_mask dimensions: {pred_array.ndim}. "
+                "Expected 2D (H,W) or 3D (C,H,W)."
+            )
+
+        # Handle image array if present
+        image_array: np.ndarray = sample["image"]
+        if image_array.ndim == 2:
+            image_array = np.expand_dims(image_array, axis=0)
+        elif image_array.ndim != 3:
+            raise ValueError(
+                f"Unsupported image dimensions: {image_array.ndim}. "
+                "Expected 2D (H,W) or 3D (C,H,W)."
+            )
+
+        true_mask_array: Optional[np.ndarray] = sample.get("mask")
+        if isinstance(true_mask_array, np.ndarray):
+            if true_mask_array.ndim == 2:
+                true_mask_array = np.expand_dims(true_mask_array, axis=0)
+            elif true_mask_array.ndim != 3:
+                raise ValueError(
+                    f"Unsupported true_mask_array dimensions: {true_mask_array.ndim}. "
+                    "Expected 2D (H,W) or 3D (C,H,W)."
+                )
+
+        # Determine filename base
+        image_meta = sample.get("image_name")
         if isinstance(image_meta, (str, os.PathLike)):
             base_name = os.path.splitext(os.path.basename(image_meta))[0]
         else:
-            # Use provided start_index when metadata missing
             base_name = f"prediction_{start_index:04d}"

-        # Save mask TIFF (16-bit)
-        mask_filename = f"{base_name}_mask.tif"
-        mask_path = os.path.join(masks_dir, mask_filename)
+        # Save main mask TIFF
+        mask_path = os.path.join(masks_dir, f"{base_name}_mask.tif")
         tiff.imwrite(mask_path, pred_array.astype(np.uint16), compression="zlib")

-        # Now pred_array shape is (C, H, W)
-        num_channels = pred_array.shape[0]
-        for channel_idx in range(num_channels):
-            channel_mask = pred_array[channel_idx]
-
-            # File names
-            plot_filename = f"{base_name}_ch{channel_idx:01d}.png"
-            plot_path = os.path.join(plots_dir, plot_filename)
-
-            # Extract corresponding true mask channel if exists
-            true_mask = None
-            if mask_array is not None and mask_array.ndim == 3:
-                true_mask = mask_array[channel_idx]
+        if only_masks:
+            return

-            # Generate and save visualization
+        # Save channel-wise plots
+        num_channels = pred_array.shape[0]
+        for ch in range(num_channels):
+            true_ch = true_mask_array[ch] if true_mask_array is not None else None
+
             self.__plot_mask(
-                file_path=plot_path,
-                image_data=image_array,  # type: ignore
-                predicted_mask=channel_mask,
-                true_mask=true_mask,
+                file_path=os.path.join(plots_dir, f"{base_name}_ch{ch}.png"),
+                image_data=image_array,
+                predicted_mask=pred_array[ch],
+                true_mask=true_ch,
             )
+
+            if masks is not None and true_ch is not None:
+                self.__save_mask_comparison_visuals(
+                    gt=true_ch,
+                    pred=pred_array[ch],
+                    tp_mask=masks[0][ch],
+                    fp_mask=masks[1][ch],
+                    fn_mask=masks[2][ch],
+                    file_path=os.path.join(evaluate_dir, f"{base_name}_ch{ch}.png")
+                )


     def __plot_mask(
@@ -1307,6 +1509,16 @@
     ) -> None:
         """
         Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided.
+
+        Args:
+            file_path (str): Path where the visualization image will be saved.
+            image_data (np.ndarray): The original input image array, expected shape (C, H, W).
+            predicted_mask (np.ndarray): The predicted mask array of shape (H, W).
+            true_mask (Optional[np.ndarray], optional): The ground-truth mask array.
+                If provided, an additional row with the true mask and an overlap
+                visualization is added to the plot. Default is None.
+
         """
         img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data
@@ -1317,7 +1529,7 @@
                              ('Original Image','Predicted Mask','Predicted Contours'))
         else:
             fig, axs = plt.subplots(2,3,figsize=(15,10))
-            plt.subplots_adjust(wspace=0.02, hspace=0.1)
+            plt.subplots_adjust(wspace=0.02, hspace=0.02)
             # row 0: predicted
             self.__plot_panels(axs[0], img, predicted_mask, 'red',
                                ('Original Image','Predicted Mask','Predicted Contours'))
@@ -1370,6 +1582,69 @@
         ax2.contour(boundaries, colors=contour_color, linewidths=0.5)
         ax2.set_title(titles[2])
         ax2.axis('off')
+
+
+    def __save_mask_comparison_visuals(
+        self,
+        gt: np.ndarray,
+        pred: np.ndarray,
+        tp_mask: np.ndarray,
+        fp_mask: np.ndarray,
+        fn_mask: np.ndarray,
+        file_path: str
+    ) -> None:
+        """
+        Create and save a 1x3 subplot figure showing:
+            1) True mask with boundaries
+            2) Predicted mask without boundaries
+            3) Overlay mask combining FP (R), TP (G), FN (B)
+
+        Args:
+            gt (np.ndarray): Ground truth mask (H, W).
+            pred (np.ndarray): Predicted mask (H, W).
+            tp_mask (np.ndarray): True positive mask (H, W).
+            fp_mask (np.ndarray): False positive mask (H, W).
+            fn_mask (np.ndarray): False negative mask (H, W).
+            file_path (str): Path where the visualization image will be saved.
+        """
+        # Prepare overlay mask
+        overlap_mask = np.zeros((*gt.shape[:2], 3), dtype=np.uint8)
+        overlap_mask[..., 0] = np.where(fp_mask, 255, 0)
+        overlap_mask[..., 1] = np.where(tp_mask, 255, 0)
+        overlap_mask[..., 2] = np.where(fn_mask, 255, 0)
+
+        # Set up figure
+        fig, axes = plt.subplots(1, 3, figsize=(15, 5),
+                                 gridspec_kw={'width_ratios': [1, 1, 1]})
+        plt.subplots_adjust(wspace=0.02, hspace=0.0,
+                            left=0.05, right=0.95, top=0.95, bottom=0.05)
+
+        # Colormap for instances (floor of 1 so empty masks do not yield an empty colormap)
+        num_instances = max(int(np.max(gt)), int(np.max(pred)), 1)
+        cmap = plt.get_cmap("gist_ncar")
+        colors = [cmap(i / num_instances) for i in range(num_instances)]
+        cmap = mcolors.ListedColormap(colors)
+
+        # Plot true mask
+        axes[0].imshow(gt, cmap=cmap)
+        axes[0].contour(find_boundaries(gt, mode="thick"), colors="black", linewidths=0.5)
+        axes[0].set_title("True Mask")
+        axes[0].axis("off")
+
+        # Plot predicted mask
+        axes[1].imshow(pred, cmap=cmap)
+        axes[1].contour(find_boundaries(pred, mode="thick"), colors="black", linewidths=0.5)
+        axes[1].set_title("Predicted Mask")
+        axes[1].axis("off")
+
+        # Plot overlay
+        axes[2].imshow(overlap_mask)
+        axes[2].set_title("Overlay Mask (R-FP; G-TP; B-FN)")
+        axes[2].axis("off")
+
+        # Save
+        plt.savefig(file_path, bbox_inches="tight", dpi=300)
+        plt.close()


     def __compute_flows_from_masks(
@@ -1522,7 +1797,7 @@
             flow_output = np.zeros((2, height, width), dtype=np.float32)
             ys_np = y.cpu().numpy() - 1
             xs_np = x.cpu().numpy() - 1
-            flow_output[:, ys_np, xs_np] = mu
+            flow_output[:, ys_np, xs_np] = mu.reshape(2, -1)

             flows[2*channel: 2*channel + 2] = flow_output

         return flows
diff --git a/core/utils/measures.py b/core/utils/measures.py
index 53135ea..ffe1fba 100644
--- a/core/utils/measures.py
+++ b/core/utils/measures.py
@@ -21,7 +21,7 @@ __all__ = [
     "compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score"
 ]

-logger = get_logger()
+logger = get_logger(__name__)


 def compute_f1_score(
diff --git a/main.py b/main.py
index c6e1970..aa9894f 100644
--- a/main.py
+++ b/main.py
@@ -23,6 +23,20 @@ def main():
         default='train',
         help='Run mode: train, test or predict'
     )
+    parser.add_argument(
+        '-s', '--save-masks',
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help='Save predicted masks and test metrics (default: enabled); pass --no-save-masks to disable'
+    )
+    parser.add_argument(
+        '--only-masks',
+        action='store_true',
+        default=False,
+        help=('If set (and saving is enabled), save only the raw predicted'
+              ' masks, without additional visualizations')
+    )
+
     args = parser.parse_args()
     mode = args.mode
@@ -56,22 +70,25 @@ def main():
     if config.wandb_config.use_wandb:
         wandb.watch(segmentator._model, log="all", log_graph=True)

-    # Run training (or prediction, if implemented)
-    segmentator.run()
+    try:
+        segmentator.run(save_results=args.save_masks, only_masks=args.only_masks)
+    finally:
+        # Save the (best) weights even if the run is interrupted by an exception
+        if config.dataset_config.is_training:
+            # Prepare saving path
+            weights_dir = (
+                wandb.run.dir if config.wandb_config.use_wandb else "weights"  # type: ignore
+            )
+            saving_path = os.path.join(
+                weights_dir,
+                os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
+            )
+            segmentator.save_checkpoint(saving_path)

-    if config.dataset_config.is_training:
-        # Prepare saving path
-        weights_dir = (
-            wandb.run.dir if config.wandb_config.use_wandb else "weights"  # type: ignore
-        )
-        saving_path = os.path.join(
-            weights_dir,
-            os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
-        )
-        segmentator.save_checkpoint(saving_path)
+            if config.wandb_config.use_wandb:
+                wandb.save(saving_path)

-        if config.wandb_config.use_wandb:
-            wandb.save(saving_path)


 if __name__ == "__main__":
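Usage sketch for the resulting CLI (hypothetical invocations; -m/--mode,
--save-masks, and --only-masks as defined in main.py above — the
--no-save-masks form is the negative flag generated by
argparse.BooleanOptionalAction):

    python main.py -m test --only-masks      # raw predicted masks only, no PNG overlays or error-mask plots
    python main.py -m test --no-save-masks   # run evaluation and log metrics without writing predictions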