Added creation of per-class error masks (TP/FP/FN);

Added arguments for saving only the prediction masks, as well as for skipping saving of outputs entirely;
Added execution time display;
The pretrained_weights parameter has been moved to the common config;
Parameters related to ensembles have been removed.
parent b201367596
commit 28f978956c

@@ -8,11 +8,12 @@ class DatasetCommonConfig(BaseModel):
Common configuration fields shared by both training and testing.
"""
seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_tta: bool = False # Flag to use Test-Time Augmentation (TTA)
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars'
predictions_dir: str = "." # Directory to save predictions
pretrained_weights: str = "" # Path to pretrained weights
@model_validator(mode="after")
def validate_common(self) -> "DatasetCommonConfig":
@@ -39,7 +40,7 @@ class TrainingPreSplitInfo(BaseModel):
Contains:
- train_dir, valid_dir, test_dir: Directories for training, validation, and testing data.
"""
train_dir: str = "." # Directory for training data if data is pre-split
train_dir: str = "." # Directory for training data if data is pre-split
valid_dir: str = "" # Directory for validation data if data is pre-split
test_dir: str = "" # Directory for testing data if data is pre-split
@@ -72,7 +73,6 @@ class DatasetTrainingConfig(BaseModel):
batch_size: int = 1 # Batch size for training
num_epochs: int = 100 # Number of training epochs
val_freq: int = 1 # Frequency of validation during training
pretrained_weights: str = "" # Path to pretrained weights for training
@field_validator("train_size", "valid_size", "test_size", mode="before")
@@ -137,15 +137,6 @@ class DatasetTrainingConfig(BaseModel):
raise ValueError("offsets must be >= 0")
return self
@model_validator(mode="after")
def validate_pretrained(self) -> "DatasetTrainingConfig":
"""
Validates that pretrained_weights is provided and exists.
"""
if self.pretrained_weights and not os.path.exists(self.pretrained_weights):
raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
return self
class DatasetTestingConfig(BaseModel):
"""
@@ -155,11 +146,6 @@ class DatasetTestingConfig(BaseModel):
test_size: Union[int, float] = 1.0 # Testing data size (int for static, float in (0,1] for dynamic)
test_offset: int = 0 # Offset for testing data
shuffle: bool = True # Shuffle data
use_ensemble: bool = False # Flag to use ensemble mode in testing
ensemble_pretrained_weights1: str = "."
ensemble_pretrained_weights2: str = "."
pretrained_weights: str = "."
@field_validator("test_size", mode="before")
def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]:
@@ -191,25 +177,11 @@
"""
Validates the testing configuration:
- test_dir must be non-empty and exist.
- If use_ensemble is True, both ensemble_pretrained_weights1 and ensemble_pretrained_weights2 must be provided and exist.
- If use_ensemble is False, pretrained_weights must be provided and exist.
"""
if not self.test_dir:
raise ValueError("In testing configuration, test_dir must be provided and non-empty")
if not os.path.exists(self.test_dir):
raise ValueError(f"Path for test_dir does not exist: {self.test_dir}")
if self.use_ensemble:
for field in ["ensemble_pretrained_weights1", "ensemble_pretrained_weights2"]:
value = getattr(self, field)
if not value:
raise ValueError(f"When use_ensemble is True, {field} must be provided and non-empty")
if not os.path.exists(value):
raise ValueError(f"Path for {field} does not exist: {value}")
else:
if not self.pretrained_weights:
raise ValueError("When use_ensemble is False, pretrained_weights must be provided and non-empty")
if not os.path.exists(self.pretrained_weights):
raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
if self.test_offset < 0:
raise ValueError("test_offset must be >= 0")
return self
@@ -237,11 +209,17 @@ class DatasetConfig(BaseModel):
if (self.training.is_split and self.training.pre_split.test_dir) or (not self.training.is_split):
if self.training.test_size > 0 and not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when test_size is non-zero")
if self.common.pretrained_weights and not os.path.exists(self.common.pretrained_weights):
raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}")
else:
if self.testing is None:
raise ValueError("Testing configuration must be provided when is_training is False")
if self.testing.test_size > 0 and not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when test_size is non-zero")
if not self.common.pretrained_weights:
raise ValueError("When testing pretrained_weights must be provided and non-empty")
if not os.path.exists(self.common.pretrained_weights):
raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}")
return self
def model_dump(self, **kwargs) -> Dict[str, Any]:

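For illustration, a minimal sketch of the relocated field (field names as in this diff; the concrete values and validator behavior are assumptions):

    common = DatasetCommonConfig(
        seed=0,
        device="cuda:0",
        predictions_dir="./predictions",             # hypothetical output directory
        pretrained_weights="weights/best_model.pth", # hypothetical path; must exist when testing
    )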
@@ -10,7 +10,7 @@ from core.logger import get_logger
__all__ = ["BoundaryExclusion", "IntensityDiversification"]
logger = get_logger("cell_aware")
logger = get_logger(__name__)
class BoundaryExclusion(MapTransform):

@@ -0,0 +1,193 @@
import torch
import numpy as np
from typing import Hashable, List, Sequence, Optional, Tuple
from monai.utils.misc import fall_back_tuple
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms import Randomizable, RandCropd, Crop # type: ignore
from core.logger import get_logger
logger = get_logger(__name__)
def _compute_multilabel_bbox(
mask: np.ndarray
) -> Optional[Tuple[List[int], List[int], List[int], List[int]]]:
"""
Compute per-channel bounding-box constraints and return lists of limits for each axis.
Args:
mask: multi-channel instance mask of shape (C, H, W).
Returns:
A tuple of four lists, one entry per non-empty channel:
- top_mins: r_max values (lower bounds for the crop's top offset)
- top_maxs: r_min values (upper bounds for the crop's top offset)
- left_mins: c_max values (lower bounds for the crop's left offset)
- left_maxs: c_min values (upper bounds for the crop's left offset)
Or None if the mask contains no positive labels.
"""
channels, rows, cols = np.nonzero(mask)
if channels.size == 0:
return None
top_mins: List[int] = []
top_maxs: List[int] = []
left_mins: List[int] = []
left_maxs: List[int] = []
C = mask.shape[0]
for ch in range(C):
rs, cs = np.nonzero(mask[ch])
if rs.size == 0:
continue
r_min, r_max = int(rs.min()), int(rs.max())
c_min, c_max = int(cs.min()), int(cs.max())
# For each channel, record the row/col extents
top_mins.append(r_max)
top_maxs.append(r_min)
left_mins.append(c_max)
left_maxs.append(c_min)
return top_mins, top_maxs, left_mins, left_maxs
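A toy check of this helper (shapes follow the docstring; expected values hand-computed):

    import numpy as np

    mask = np.zeros((2, 8, 8), dtype=np.uint8)
    mask[0, 1:3, 1:3] = 1  # channel 0: rows 1-2, cols 1-2
    mask[1, 5:7, 4:6] = 1  # channel 1: rows 5-6, cols 4-5

    top_mins, top_maxs, left_mins, left_maxs = _compute_multilabel_bbox(mask)
    # top_mins == [2, 6], top_maxs == [1, 5], left_mins == [2, 5], left_maxs == [1, 4]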
class SpatialCropAllClasses(Randomizable, Crop):
"""
Cropper for multi-label instance masks and images: ensures each label-channel's
instances lie within the crop if possible.
Must be called on a mask tensor first to compute the crop, then on the image.
Args:
roi_size: desired crop size (height, width).
num_candidates: fallback samples when no single crop fits all instances.
lazy: defer actual cropping.
"""
def __init__(
self,
roi_size: Sequence[int],
num_candidates: int = 10,
lazy: bool = False,
) -> None:
super().__init__(lazy=lazy)
self.roi_size = tuple(roi_size)
self.num_candidates = num_candidates
self._slices: Optional[Tuple[slice, ...]] = None
def randomize(self, img_size: Sequence[int]) -> None: # type: ignore
"""
Choose crop offsets so that each non-empty channel is included if possible.
"""
height, width = img_size
crop_h, crop_w = self.roi_size
max_top = max(0, height - crop_h)
max_left = max(0, width - crop_w)
# Compute per-channel bbox constraints
mask = self._img
bboxes = _compute_multilabel_bbox(mask)
if bboxes is None:
# no labels: random patch using MONAI utils
logger.warning("No labels found; using random patch.")
# determine actual patch size (fallback)
self._size = fall_back_tuple(self.roi_size, img_size)
# compute valid size for random patch
valid_size = get_valid_patch_size(img_size, self._size)
# directly get random patch slices
self._slices = get_random_patch(img_size, valid_size, self.R)
return
else:
top_mins, top_maxs, left_mins, left_maxs = bboxes
# Convert to allowable windows
# top_min_global = max(r_max - crop_h +1 for each channel)
global_top_min = max(0, max(r_max - crop_h + 1 for r_max in top_mins))
# top_max_global = min(r_min for each channel)
global_top_max = min(min(top_maxs), max_top)
# same for left
global_left_min = max(0, max(c_max - crop_w + 1 for c_max in left_mins))
global_left_max = min(min(left_maxs), max_left)
if global_top_min <= global_top_max and global_left_min <= global_left_max:
# there is a window covering all channels fully
top = self.R.randint(global_top_min, global_top_max + 1)
left = self.R.randint(global_left_min, global_left_max + 1)
else:
# fallback: sample candidates to maximize channel coverage
logger.warning(
f"Cannot fit all instances; sampling {self.num_candidates} candidates."
)
best_cover = -1
best_top = best_left = 0
C = mask.shape[0]
for _ in range(self.num_candidates):
cand_top = self.R.randint(0, max_top + 1)
cand_left = self.R.randint(0, max_left + 1)
window = mask[:, cand_top : cand_top + crop_h, cand_left : cand_left + crop_w]
cover = sum(int(window[ch].any()) for ch in range(C))
if cover > best_cover:
best_cover = cover
best_top, best_left = cand_top, cand_left
logger.info(f"Selected crop covering {best_cover}/{C} channels.")
top, left = best_top, best_left
# store slices for use on both mask and image
self._slices = (
slice(None),
slice(top, top + crop_h),
slice(left, left + crop_w),
)
def __call__(self, img: torch.Tensor, lazy: Optional[bool] = None) -> torch.Tensor: # type: ignore
"""
On first call (mask), computes crop. On subsequent (image), just applies.
Raises if mask not provided first.
"""
# Determine tensor shape
img_size = (
img.peek_pending_shape()[1:]
if isinstance(img, MetaTensor)
else img.shape[1:]
)
# First call must be mask to compute slices
if self._slices is None:
if not torch.is_floating_point(img) and img.dtype in (torch.uint8, torch.int16, torch.int32, torch.int64):
# assume integer mask
self._img = img.cpu().numpy()
self.randomize(img_size)
else:
raise RuntimeError(
"Mask tensor must be passed first for computing crop bounds."
)
# Now apply stored slice
if self._slices is None:
raise RuntimeError("Crop slices not computed; call on mask first.")
lazy_exec = self.lazy if lazy is None else lazy
return super().__call__(img=img, slices=self._slices, lazy=lazy_exec)
class RandSpatialCropAllClassesd(RandCropd):
"""
Dict-based wrapper: applies SpatialCropAllClasses to mask then image.
Requires mask present or raises.
"""
def __init__(
self,
keys: Sequence,
roi_size: Sequence[int],
num_candidates: int = 10,
allow_missing_keys: bool = False,
lazy: bool = False,
):
cropper = SpatialCropAllClasses(
roi_size=roi_size,
num_candidates=num_candidates,
lazy=lazy,
)
super().__init__(
keys=keys,
cropper=cropper,
allow_missing_keys=allow_missing_keys,
lazy=lazy,
)
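A hypothetical usage sketch (it assumes the dict wrapper applies the cropper to keys in the order given, so the mask key must come first, as the cropper's docstring requires):

    from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd

    train_transforms = Compose([
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        RandSpatialCropAllClassesd(
            keys=["mask", "image"],  # mask first: it determines the crop window
            roi_size=(256, 256),
            num_candidates=10,
        ),
    ])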

@@ -1,3 +1,4 @@
import time
import random
import numpy as np
from numba import njit, prange
@@ -69,6 +70,8 @@ class CellSegmentator:
self._valid_dataloader: Optional[DataLoader] = None
self._test_dataloader: Optional[DataLoader] = None
self._predict_dataloader: Optional[DataLoader] = None
self._best_weights = None
def create_dataloaders(
@@ -315,9 +318,14 @@ class CellSegmentator:
logger.info(l)
def train(self) -> None:
def train(self, save_results: bool = True, only_masks: bool = False) -> None:
"""
Train the model over multiple epochs, including validation and test.
Args:
save_results (bool): If True, the predicted masks and test metrics will be saved.
only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
without visualization overlays.
"""
# Ensure training is enabled in dataset setup
if not self._dataset_setup.is_training:
@@ -335,7 +343,6 @@
logger.info(f"\n{'=' * 50}")
best_f1_score = 0.0
best_weights = None
for epoch in range(1, self._dataset_setup.training.num_epochs + 1):
train_metrics = self.__run_epoch("train", epoch)
@@ -356,29 +363,38 @@
if f1 > best_f1_score:
best_f1_score = f1
# Deep copy weights to avoid reference issues
best_weights = copy.deepcopy(self._model.state_dict())
self._best_weights = copy.deepcopy(self._model.state_dict())
logger.info(f"Updated best model weights with F1 score: {f1:.4f}")
# Restore best model weights if available
if best_weights is not None:
self._model.load_state_dict(best_weights)
if self._best_weights is not None:
self._model.load_state_dict(self._best_weights)
if self._test_dataloader is not None:
test_metrics = self.__run_epoch("test")
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
self.__print_with_logging(test_metrics, 0)
def evaluate(self) -> None:
def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None:
"""
Run a full test epoch and display/log the resulting metrics.
Args:
save_results (bool): If True, the predicted masks and test metrics will be saved.
only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
without visualization overlays.
"""
test_metrics = self.__run_epoch("test")
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
self.__print_with_logging(test_metrics, 0)
def predict(self) -> None:
def predict(self, only_masks: bool = False) -> None:
"""
Run inference on the predict set and save the resulting instance masks.
Args:
only_masks (bool): If True, only raw predicted masks are saved,
without visualization overlays.
"""
# Ensure the predict DataLoader has been set
if self._predict_dataloader is None:
@@ -404,43 +420,62 @@ class CellSegmentator:
preds, _ = self.__post_process_predictions(raw_output)
# Save out the predicted masks, using batch_counter to index files
self.__save_prediction_masks(batch, preds, batch_counter)
self.__save_prediction_masks(
sample=batch,
predicted_mask=preds,
start_index=batch_counter,
only_masks=only_masks
)
# Increment counter by batch size for unique file naming
batch_counter += inputs.shape[0]
def run(self) -> None:
def run(self, save_results: bool = True, only_masks: bool = False) -> None:
"""
Orchestrate the full workflow:
Orchestrate the full workflow and report execution time:
- If training is enabled in the dataset setup, start training.
- Otherwise, if a test DataLoader is provided, run evaluation.
- Else if a prediction DataLoader is provided, run inference/prediction.
- If neither loader is available in non-training mode, raise an error.
Args:
save_results (bool): If True, the predicted masks and test metrics will be saved.
only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
without visualization overlays.
"""
start_time = time.time()
# 1) TRAINING PATH
if self._dataset_setup.is_training:
# Launch the full training loop (with validation, scheduler steps, etc.)
self.train()
return
# 2) NON-TRAINING PATH (TEST or PREDICT)
# Prefer test if available
if self._test_dataloader is not None:
# Run a single evaluation epoch on the test set and log metrics
self.evaluate()
return
# If no test loader, fall back to prediction if available
if self._predict_dataloader is not None:
# Run inference on the predict set and save outputs
self.predict()
return
self.train(save_results=save_results, only_masks=only_masks)
else:
# 2) NON-TRAINING PATH (TEST or PREDICT)
if self._test_dataloader is not None:
# Run a single evaluation epoch on the test set and log metrics
self.evaluate(save_results=save_results, only_masks=only_masks)
elif self._predict_dataloader is not None:
# Run inference on the predict set and save outputs
self.predict(only_masks=only_masks)
else:
# 3) ERROR: no appropriate loader found
raise RuntimeError(
"Neither test nor predict DataLoader is set for nontraining mode."
)
# 3) ERROR: no appropriate loader found
raise RuntimeError(
"Neither test nor predict DataLoader is set for nontraining mode."
)
elapsed = time.time() - start_time
if elapsed < 60:
logger.info(f"Total execution time: {elapsed:.2f} seconds")
elif elapsed < 3600:
minutes = int(elapsed // 60)
seconds = elapsed % 60
logger.info(f"Total execution time: {minutes} min {seconds:.2f} sec")
else:
hours = int(elapsed // 3600)
minutes = int((elapsed % 3600) // 60)
seconds = elapsed % 60
logger.info(f"Total execution time: {hours} h {minutes} min {seconds:.2f} sec")
def load_from_checkpoint(self, checkpoint_path: str) -> None:
@@ -490,7 +525,12 @@ class CellSegmentator:
"""
# Write the checkpoint to disk
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(self._model.state_dict(), checkpoint_path)
torch.save(
    self._model.state_dict() if self._best_weights is None else self._best_weights,
    checkpoint_path
)
def __parse_config(self, config: Config) -> None:
@@ -523,11 +563,7 @@ class CellSegmentator:
self._model = ModelRegistry.get_model_class(model.name)(model.params)
# Loads model weights from a specified checkpoint
pretrained_weights = (
config.dataset_config.training.pretrained_weights
if config.dataset_config.is_training
else config.dataset_config.testing.pretrained_weights
)
pretrained_weights = config.dataset_config.common.pretrained_weights
if pretrained_weights:
self.load_from_checkpoint(pretrained_weights)
logger.info(f"Loaded pre-trained weights from: {pretrained_weights}")
@@ -589,6 +625,7 @@ class CellSegmentator:
logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
logger.info(f"├─ Masks subdirectory: {common.masks_subdir}")
logger.info(f"└─ Predictions output dir: {common.predictions_dir}")
logger.info(f"├─ Pretrained weights: {common.pretrained_weights or 'None'}")
if config.dataset_config.is_training:
training = config.dataset_config.training
@@ -596,7 +633,6 @@ class CellSegmentator:
logger.info(f"├─ Batch size: {training.batch_size}")
logger.info(f"├─ Epochs: {training.num_epochs}")
logger.info(f"├─ Validation frequency: {training.val_freq}")
logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}")
if training.is_split:
logger.info(f"├─ Using pre-split directories:")
@@ -619,11 +655,6 @@ class CellSegmentator:
logger.info(f"├─ Test dir: {testing.test_dir}")
logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})")
logger.info(f"├─ Shuffle: {'yes' if testing.shuffle else 'no'}")
logger.info(f"├─ Use ensemble: {'yes' if testing.use_ensemble else 'no'}")
logger.info(f"└─ Pretrained weights:")
logger.info(f" ├─ Single model: {testing.pretrained_weights}")
logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")
self._wandb_config = config.wandb_config
if self._wandb_config.use_wandb:
@@ -747,38 +778,62 @@ class CellSegmentator:
return Dataset(data, transforms)
def __print_with_logging(self, results: Dict[str, float], step: int) -> None:
def __print_with_logging(self, results: Dict[str, Union[float, np.ndarray]], step: int) -> None:
"""
Print metrics in a tabular format and log to W&B.
Args:
results (Dict[str, float]): results dictionary.
results (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
to either a float or a ND numpy array.
step (int): epoch index.
"""
rows: list[tuple[str, str]] = []
for key, val in results.items():
if isinstance(val, np.ndarray):
# Convert array to string, e.g. '[0.2, 0.8, 0.5]'
val_str = np.array2string(val, separator=', ')
else:
# Format scalar with 4 decimal places
val_str = f"{val:.4f}"
rows.append((key, val_str))
table = tabulate(
tabular_data=results.items(),
tabular_data=rows,
headers=["Metric", "Value"],
floatfmt=".4f",
tablefmt="fancy_grid"
)
print(table, "\n")
if self._wandb_config.use_wandb:
wandb.log(results, step=step)
# Keep only scalar values
scalar_results: dict[str, float] = {}
for key, val in results.items():
if isinstance(val, np.ndarray):
continue
# Ensure float type
scalar_results[key] = float(val)
wandb.log(scalar_results, step=step)
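A toy illustration of the filtering above: array-valued metrics appear in the printed table but are dropped from the W&B payload (the metric names here are hypothetical):

    import numpy as np

    results = {
        "test_f1_score": 0.8123,                            # scalar -> printed and logged
        "test_f1_score_per_class": np.array([0.91, 0.67]),  # array  -> printed only
    }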
def __run_epoch(self,
mode: Literal["train", "valid", "test"],
epoch: Optional[int] = None
) -> Dict[str, float]:
epoch: Optional[int] = None,
save_results: bool = True,
only_masks: bool = False
) -> Dict[str, Union[float, np.ndarray]]:
"""
Execute one epoch of training, validation, or testing.
Args:
mode (str): One of 'train', 'valid', or 'test'.
epoch (int, optional): Current epoch number for logging.
save_results (bool): If True, the predicted masks and test metrics will be saved.
only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
without visualization overlays.
Returns:
Dict[str, float]: Loss metrics and F1 score for valid/test.
Dict[str, Union[float, np.ndarray]]: Metrics for valid/test.
"""
# Ensure required components are available
if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
@@ -841,18 +896,23 @@
)
# Collecting statistics on the batch
tp, fp, fn = self.__compute_stats(
tp, fp, fn, tp_masks, fp_masks, fn_masks = self.__compute_stats(
predicted_masks=preds,
ground_truth_masks=labels_post, # type: ignore
iou_threshold=0.5
iou_threshold=0.5,
return_error_masks=(mode == "test")
)
all_tp.append(tp)
all_fp.append(fp)
all_fn.append(fn)
if mode == "test":
if mode == "test" and save_results is True:
self.__save_prediction_masks(
batch, preds, batch_counter
sample=batch,
predicted_mask=preds,
start_index=batch_counter,
only_masks=only_masks,
masks=(tp_masks, fp_masks, fn_masks) # type: ignore
)
# Backpropagation and optimizer step in training
@@ -871,7 +931,9 @@
if self._criterion is not None:
# Collect loss metrics
epoch_metrics = {f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()}
epoch_metrics: Dict[str, Union[float, np.ndarray]] = {
f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()
}
# Reset internal loss metrics accumulator
self._criterion.reset_metrics()
else:
@@ -884,13 +946,13 @@
fp_array = np.vstack(all_fp)
fn_array = np.vstack(all_fn)
epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( # type: ignore
epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(
tp_array, fp_array, fn_array, reduction="micro"
)
epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric( # type: ignore
epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric(
tp_array, fp_array, fn_array, reduction="imagewise"
)
epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( # type: ignore
epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(
tp_array, fp_array, fn_array, reduction="macro"
)
@@ -976,8 +1038,10 @@
self,
predicted_masks: np.ndarray,
ground_truth_masks: np.ndarray,
iou_threshold: float = 0.5
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
iou_threshold: float = 0.5,
return_error_masks: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray,
Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
"""
Compute batch-wise true positives, false positives, and false negatives
for instance segmentation, using a configurable IoU threshold.
@@ -987,23 +1051,33 @@
ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W).
iou_threshold (float): Intersection-over-Union threshold for matching predictions
to ground truths (default: 0.5).
return_error_masks (bool): Whether to also return binary error masks.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]:
Tuple[np.ndarray, np.ndarray, np.ndarray,
Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
- tp: True positives per batch and class, shape (B, C)
- fp: False positives per batch and class, shape (B, C)
- fn: False negatives per batch and class, shape (B, C)
- tp_mask: True-positive mask per batch and class, shape (B, C, H, W), or None
- fp_mask: False-positive mask per batch and class, shape (B, C, H, W), or None
- fn_mask: False-negative mask per batch and class, shape (B, C, H, W), or None
"""
stats = compute_batch_segmentation_tp_fp_fn(
batch_ground_truth=ground_truth_masks,
batch_prediction=predicted_masks,
iou_threshold=iou_threshold,
return_error_masks=return_error_masks,
remove_boundary_objects=True
)
tp = stats["tp"]
fp = stats["fp"]
fn = stats["fn"]
return tp, fp, fn
tp_mask = stats["tp_mask"] if return_error_masks else None
fp_mask = stats["fp_mask"] if return_error_masks else None
fn_mask = stats["fn_mask"] if return_error_masks else None
return tp, fp, fn, tp_mask, fp_mask, fn_mask
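A sketch of the underlying call (argument names as used above; that compute_batch_segmentation_tp_fp_fn is importable from the project's metrics module is an assumption):

    import numpy as np

    gt = np.zeros((1, 1, 64, 64), dtype=np.uint16)
    pr = np.zeros((1, 1, 64, 64), dtype=np.uint16)
    gt[0, 0, 10:20, 10:20] = 1  # one ground-truth instance
    pr[0, 0, 11:21, 11:21] = 1  # one overlapping prediction (IoU ~ 0.68)

    stats = compute_batch_segmentation_tp_fp_fn(
        batch_ground_truth=gt,
        batch_prediction=pr,
        iou_threshold=0.5,
        return_error_masks=True,  # adds 'tp_mask'/'fp_mask'/'fn_mask' to the dict
        remove_boundary_objects=True,
    )
    tp, fp, fn = stats["tp"], stats["fp"], stats["fn"]  # each of shape (B, C)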
def __compute_f1_metric(
@@ -1011,7 +1085,7 @@
true_positives: np.ndarray,
false_positives: np.ndarray,
false_negatives: np.ndarray,
reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
) -> Union[float, np.ndarray]:
"""
Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes.
@@ -1023,8 +1097,9 @@
reduction:
- 'none': return F1 for each sample and class, shape (batch_size, num_classes)
- 'micro': global F1 over all samples & classes
- 'imagewise': F1 per sample (summing over classes), then average over samples
- 'macro': average class-wise F1 (classes summed over batch)
- 'imagewise': F1 per sample (summing over classes), then average over samples
- 'per_class': F1 per class (summing over batch), return vector of shape (num_classes,)
- 'weighted': class-wise F1 weighted by support (TP+FN)
Returns:
float for reductions 'micro', 'imagewise', 'macro', 'weighted';
@@ -1040,7 +1115,11 @@
tp_val = int(true_positives[i, c])
fp_val = int(false_positives[i, c])
fn_val = int(false_negatives[i, c])
_, _, f1_val = compute_f1_score(tp_val, fp_val, fn_val)
_, _, f1_val = compute_f1_score(
tp_val,
fp_val,
fn_val
)
f1_matrix[i, c] = f1_val
return f1_matrix
@@ -1049,7 +1128,11 @@
tp_total = int(true_positives.sum())
fp_total = int(false_positives.sum())
fn_total = int(false_negatives.sum())
_, _, f1_global = compute_f1_score(tp_total, fp_total, fn_total)
_, _, f1_global = compute_f1_score(
tp_total,
fp_total,
fn_total
)
return f1_global
# 3) Imagewise: compute per-sample F1 (sum over classes), then average
@@ -1059,16 +1142,31 @@
tp_i = int(true_positives[i].sum())
fp_i = int(false_positives[i].sum())
fn_i = int(false_negatives[i].sum())
_, _, f1_i = compute_f1_score(tp_i, fp_i, fn_i)
_, _, f1_i = compute_f1_score(
tp_i,
fp_i,
fn_i
)
f1_per_image[i] = f1_i
return float(f1_per_image.mean())
# For macro/weighted, first aggregate per class across the batch
# Aggregate per class across the batch for per_class, macro, weighted
tp_per_class = true_positives.sum(axis=0).astype(int) # shape (num_classes,)
fp_per_class = false_positives.sum(axis=0).astype(int)
fn_per_class = false_negatives.sum(axis=0).astype(int)
# 4) Macro: average F1 across classes equally
# 4) Per-class: compute F1 for each class and return vector
if reduction == "per_class":
f1_per_class = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
_, _, f1_per_class[c] = compute_f1_score(
tp_per_class[c],
fp_per_class[c],
fn_per_class[c]
)
return f1_per_class
# 5) Macro: average F1 across classes equally
if reduction == "macro":
f1_per_class = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
@@ -1080,7 +1178,7 @@
f1_per_class[c] = f1_c
return float(f1_per_class.mean())
# 5) Weighted: class-wise F1 weighted by support = TP + FN
# 6) Weighted: class-wise F1 weighted by support = TP + FN
if reduction == "weighted":
f1_per_class = np.zeros(num_classes, dtype=float)
support = np.zeros(num_classes, dtype=float)
@@ -1088,7 +1186,11 @@
tp_c = tp_per_class[c]
fp_c = fp_per_class[c]
fn_c = fn_per_class[c]
_, _, f1_c = compute_f1_score(tp_c, fp_c, fn_c)
_, _, f1_c = compute_f1_score(
tp_c,
fp_c,
fn_c
)
f1_per_class[c] = f1_c
support[c] = tp_c + fn_c
total_support = support.sum()
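A worked toy check of the new 'per_class' reduction, using the standard identity F1 = 2*TP / (2*TP + FP + FN):

    import numpy as np

    tp = np.array([[3, 1], [2, 0]])  # (batch=2, classes=2)
    fp = np.array([[1, 0], [0, 2]])
    fn = np.array([[0, 1], [1, 1]])

    tp_c, fp_c, fn_c = tp.sum(0), fp.sum(0), fn.sum(0)  # sum over batch: [5,1], [1,2], [1,2]
    f1_per_class = 2 * tp_c / (2 * tp_c + fp_c + fn_c)  # -> [0.8333..., 0.3333...]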
@@ -1106,7 +1208,7 @@
true_positives: np.ndarray,
false_positives: np.ndarray,
false_negatives: np.ndarray,
reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
reduction: Literal["micro", "macro", "weighted", "imagewise", 'per_class', "none"] = "micro"
) -> Union[float, np.ndarray]:
"""
Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes.
@@ -1121,8 +1223,9 @@
reduction:
- 'none': return AP for each sample and class, shape (batch_size, num_classes)
- 'micro': global AP over all samples & classes
- 'imagewise': AP per sample (summing stats over classes), then average over batch
- 'macro': average class-wise AP (each class summed over batch)
- 'imagewise': AP per sample (summing stats over classes), then average over batch
- 'per_class': AP per class (summing over batch), return vector of shape (num_classes,)
- 'weighted': class-wise AP weighted by support (TP+FN)
Returns:
@@ -1139,7 +1242,11 @@
tp_val = int(true_positives[i, c])
fp_val = int(false_positives[i, c])
fn_val = int(false_negatives[i, c])
ap_val = compute_average_precision_score(tp_val, fp_val, fn_val)
ap_val = compute_average_precision_score(
tp_val,
fp_val,
fn_val
)
ap_matrix[i, c] = ap_val
return ap_matrix
@@ -1148,7 +1255,11 @@
tp_total = int(true_positives.sum())
fp_total = int(false_positives.sum())
fn_total = int(false_negatives.sum())
return compute_average_precision_score(tp_total, fp_total, fn_total)
return compute_average_precision_score(
tp_total,
fp_total,
fn_total
)
# 3) Imagewise: compute per-sample AP (sum over classes), then mean
if reduction == "imagewise":
@@ -1157,7 +1268,11 @@
tp_i = int(true_positives[i].sum())
fp_i = int(false_positives[i].sum())
fn_i = int(false_negatives[i].sum())
ap_per_image[i] = compute_average_precision_score(tp_i, fp_i, fn_i)
ap_per_image[i] = compute_average_precision_score(
tp_i,
fp_i,
fn_i
)
return float(ap_per_image.mean())
# For macro and weighted: first aggregate per class across batch
@@ -1165,7 +1280,18 @@
fp_per_class = false_positives.sum(axis=0).astype(int)
fn_per_class = false_negatives.sum(axis=0).astype(int)
# 4) Macro: average AP across classes equally
# 4) Per-class: compute AP for each class and return vector
if reduction == "per_class":
ap_per_class = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
ap_per_class[c] = compute_average_precision_score(
tp_per_class[c],
fp_per_class[c],
fn_per_class[c]
)
return ap_per_class
# 5) Macro: average AP across classes equally
if reduction == "macro":
ap_per_class = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
@@ -1176,7 +1302,7 @@
)
return float(ap_per_class.mean())
# 5) Weighted: class-wise AP weighted by support = TP + FN
# 6) Weighted: class-wise AP weighted by support = TP + FN
if reduction == "weighted":
ap_per_class = np.zeros(num_classes, dtype=float)
support = np.zeros(num_classes, dtype=float)
@@ -1184,7 +1310,11 @@
tp_c = tp_per_class[c]
fp_c = fp_per_class[c]
fn_c = fn_per_class[c]
ap_per_class[c] = compute_average_precision_score(tp_c, fp_c, fn_c)
ap_per_class[c] = compute_average_precision_score(
tp_c,
fp_c,
fn_c
)
support[c] = tp_c + fn_c
total_support = support.sum()
if total_support == 0:
@@ -1215,87 +1345,159 @@
sample: Dict[str, Any],
predicted_mask: Union[np.ndarray, torch.Tensor],
start_index: int = 0,
only_masks: bool = False,
masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
) -> None:
"""
Save multi-channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders.
Save multi-channel predicted masks as TIFFs and
corresponding visualizations as PNGs in separate folders.
Args:
sample (Dict[str, Any]): Batch sample from MONAI LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
sample (Dict[str, Any]): Batch sample from MONAI
LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
start_index (int): Starting index for naming when metadata is missing.
only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations.
masks (Tuple[np.ndarray, np.ndarray, np.ndarray] | None):
A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None.
"""
# Determine base paths
# Base directories (created once per call)
base_output_dir = self._dataset_setup.common.predictions_dir
masks_dir = base_output_dir
plots_dir = os.path.join(base_output_dir, "plots")
evaluate_dir = os.path.join(plots_dir, "evaluate")
os.makedirs(masks_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)
os.makedirs(evaluate_dir, exist_ok=True)
# Extract image (C, H, W) or batch of images (B, C, H, W), and metadata
image_obj = sample.get("image") # Expected shape: (C, H, W) or (B, C, H, W)
mask_obj = sample.get("mask") # Expected shape: (C, H, W) or (B, C, H, W)
image_meta = sample.get("image_meta_dict")
# Convert tensors to numpy
# Convert tensors to numpy if necessary
def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
if isinstance(x, torch.Tensor):
return x.cpu().numpy()
return x
image_array = to_numpy(image_obj) if image_obj is not None else None
mask_array = to_numpy(mask_obj) if mask_obj is not None else None
pred_array = to_numpy(predicted_mask)
# Handle batch dimension: (B, C, H, W)
if pred_array.ndim == 4:
for idx in range(pred_array.shape[0]):
batch_sample: Dict[str, Any] = {}
if image_array is not None and image_array.ndim == 4:
batch_sample["image"] = image_array[idx]
if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
batch_sample["image_meta_dict"] = image_meta["filename_or_obj"][idx]
if mask_array is not None and mask_array.ndim == 4:
batch_sample["mask"] = mask_array[idx]
self.__save_prediction_masks(
batch_sample,
pred_array[idx],
start_index=start_index+idx
)
return
return x.cpu().numpy() if isinstance(x, torch.Tensor) else x
pred_array = to_numpy(predicted_mask).astype(np.uint16)
# Handle batch dimension
for idx in range(pred_array.shape[0]):
batch_sample: Dict[str, Any] = {}
# copy per-sample image and meta
img = to_numpy(sample["image"])
if img.ndim == 4:
batch_sample["image"] = img[idx]
if "mask" in sample:
msk = to_numpy(sample["mask"]).astype(np.uint16)
if msk.ndim == 4:
batch_sample["mask"] = msk[idx]
image_meta = sample.get("image_meta_dict")
if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
fname = image_meta["filename_or_obj"][idx]
batch_sample["image_name"] = fname
single_masks = (
(masks[0][idx], masks[1][idx], masks[2][idx]) if masks is not None else None
)
self.__save_single_prediction_mask(
sample=batch_sample,
pred_array=pred_array[idx],
start_index=start_index + idx,
masks_dir=masks_dir,
plots_dir=plots_dir,
evaluate_dir=evaluate_dir,
only_masks=only_masks,
masks=single_masks,
)
def __save_single_prediction_mask(
self,
sample: Dict[str, Any],
pred_array: np.ndarray,
start_index: int,
masks_dir: str,
plots_dir: str,
evaluate_dir: str,
only_masks: bool = False,
masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
) -> None:
"""
Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations.
Assumes output directories already exist.
Args:
sample (Dict[str, Any]): Dictionary containing 'image', 'mask',
and optional 'image_meta_dict' for metadata.
pred_array (np.ndarray): Predicted mask array of shape (C,H,W).
start_index (int): Base index for generating filenames when metadata is missing.
masks_dir (str): Directory for saving TIFF masks.
plots_dir (str): Directory for saving PNG visualizations.
evaluate_dir (str): Directory for saving PNG visualizations of evaluation results.
only_masks (bool): If True, saves only TIFF mask files; skips PNG plots.
masks (Tuple[np.ndarray, np.ndarray, np.ndarray], optional): A tuple of
true-positive, false-positive, and false-negative mask arrays,
each of shape (C,H,W). Defaults to None.
"""
if pred_array.ndim == 2:
pred_array = np.expand_dims(pred_array, axis=0)
elif pred_array.ndim != 3:
raise ValueError(
f"Unsupported predicted_mask dimensions: {pred_array.ndim}."
"Expected 2D (H,W) or 3D (C,H,W)."
)
# Handle image array if present
image_array: np.ndarray = sample["image"]
if image_array.ndim == 2:
image_array = np.expand_dims(image_array, axis=0)
elif image_array.ndim != 3:
raise ValueError(
f"Unsupported image dimensions: {image_array.ndim}."
"Expected 2D (H,W) or 3D (C,H,W)."
)
true_mask_array: Optional[np.ndarray] = sample.get("mask")
if isinstance(true_mask_array, np.ndarray):
if true_mask_array.ndim == 2:
true_mask_array = np.expand_dims(true_mask_array, axis=0)
elif true_mask_array.ndim != 3:
raise ValueError(
f"Unsupported true_mask_array dimensions: {true_mask_array.ndim}."
"Expected 2D (H,W) or 3D (C,H,W)."
)
# Determine filename base
image_meta = sample.get("image_name")
if isinstance(image_meta, (str, os.PathLike)):
base_name = os.path.splitext(os.path.basename(image_meta))[0]
else:
# Use provided start_index when metadata missing
base_name = f"prediction_{start_index:04d}"
# Save mask TIFF (16-bit)
mask_filename = f"{base_name}_mask.tif"
mask_path = os.path.join(masks_dir, mask_filename)
# Save main mask TIFF
mask_path = os.path.join(masks_dir, f"{base_name}_mask.tif")
tiff.imwrite(mask_path, pred_array.astype(np.uint16), compression="zlib")
# Now pred_array shape is (C, H, W)
num_channels = pred_array.shape[0]
for channel_idx in range(num_channels):
channel_mask = pred_array[channel_idx]
# File names
plot_filename = f"{base_name}_ch{channel_idx:01d}.png"
plot_path = os.path.join(plots_dir, plot_filename)
# Extract corresponding true mask channel if exists
true_mask = None
if mask_array is not None and mask_array.ndim == 3:
true_mask = mask_array[channel_idx]
if only_masks:
return
# Generate and save visualization
# Save channel-wise plots
num_channels = pred_array.shape[0]
for ch in range(num_channels):
true_ch = true_mask_array[ch] if true_mask_array is not None else None
self.__plot_mask(
file_path=plot_path,
image_data=image_array, # type: ignore
predicted_mask=channel_mask,
true_mask=true_mask,
file_path=os.path.join(plots_dir, f"{base_name}_ch{ch}.png"),
image_data=image_array,
predicted_mask=pred_array[ch],
true_mask=true_ch,
)
if masks is not None and true_ch is not None:
self.__save_mask_comparison_visuals(
gt=true_ch,
pred=pred_array[ch],
tp_mask=masks[0][ch],
fp_mask=masks[1][ch],
fn_mask=masks[2][ch],
file_path=os.path.join(evaluate_dir, f"{base_name}_ch{ch}.png")
)
def __plot_mask(
@@ -1307,6 +1509,16 @@ class CellSegmentator:
) -> None:
"""
Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided.
Args:
file_path (str): Path where the visualization image will be saved.
image_data (np.ndarray): The original input image array, expected shape (C, H, W).
predicted_mask (np.ndarray): The predicted mask array of shape (H, W).
true_mask (Optional[np.ndarray], optional): The ground-truth mask array.
If provided, an additional row with true mask and overlap visualization
will be added to the plot. Default is None.
"""
img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data
@@ -1317,7 +1529,7 @@
('Original Image','Predicted Mask','Predicted Contours'))
else:
fig, axs = plt.subplots(2,3,figsize=(15,10))
plt.subplots_adjust(wspace=0.02, hspace=0.1)
plt.subplots_adjust(wspace=0.02, hspace=0.02)
# row 0: predicted
self.__plot_panels(axs[0], img, predicted_mask, 'red',
('Original Image','Predicted Mask','Predicted Contours'))
@@ -1370,6 +1582,69 @@ class CellSegmentator:
ax2.contour(boundaries, colors=contour_color, linewidths=0.5)
ax2.set_title(titles[2])
ax2.axis('off')
def __save_mask_comparison_visuals(
self,
gt: np.ndarray,
pred: np.ndarray,
tp_mask: np.ndarray,
fp_mask: np.ndarray,
fn_mask: np.ndarray,
file_path: str
) -> None:
"""
Creates and saves a 1x3 subplot figure showing:
1) True mask with boundaries
2) Predicted mask without boundaries
3) Overlay mask combining FP (R), TP (G), FN (B)
Args:
gt (np.ndarray): Ground truth mask (H, W).
pred (np.ndarray): Predicted mask (H, W).
tp_mask (np.ndarray): True positive mask (H, W).
fp_mask (np.ndarray): False positive mask (H, W).
fn_mask (np.ndarray): False negative mask (H, W).
file_path (str): Path where the visualization image will be saved.
"""
# Prepare overlay mask
overlap_mask = np.zeros((*gt.shape[:2], 3), dtype=np.uint8)
overlap_mask[..., 0] = np.where(fp_mask, 255, 0)
overlap_mask[..., 1] = np.where(tp_mask, 255, 0)
overlap_mask[..., 2] = np.where(fn_mask, 255, 0)
# Set up figure
fig, axes = plt.subplots(1, 3, figsize=(15, 5),
gridspec_kw={'width_ratios': [1, 1, 1]})
plt.subplots_adjust(wspace=0.02, hspace=0.0,
left=0.05, right=0.95, top=0.95, bottom=0.05)
# Colormap for instances
num_instances = int(max(np.max(gt), np.max(pred), 1))  # at least 1 to avoid an empty colormap
cmap = plt.get_cmap("gist_ncar")
colors = [cmap(i / num_instances) for i in range(num_instances)]
cmap = mcolors.ListedColormap(colors)
# Plot true mask
axes[0].imshow(gt, cmap=cmap)
axes[0].contour(find_boundaries(gt, mode="thick"), colors="black", linewidths=0.5)
axes[0].set_title("True Mask")
axes[0].axis("off")
# Plot predicted mask
axes[1].imshow(pred, cmap=cmap)
axes[1].contour(find_boundaries(pred, mode="thick"), colors="black", linewidths=0.5)
axes[1].set_title("Predicted Mask")
axes[1].axis("off")
# Plot overlay
axes[2].imshow(overlap_mask)
axes[2].set_title("Overlay Mask (R-FP; G-TP; B-FN)")
axes[2].axis("off")
# Save
plt.savefig(file_path, bbox_inches="tight", dpi=300)
plt.close()
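The saved TIFFs can be read back with tifffile (imported above as tiff; the path here is hypothetical, base names follow the scheme above):

    import tifffile as tiff

    mask = tiff.imread("predictions/prediction_0000_mask.tif")
    # mask has shape (C, H, W); each channel holds uint16 instance ids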
def __compute_flows_from_masks(
@@ -1522,7 +1797,7 @@ class CellSegmentator:
flow_output = np.zeros((2, height, width), dtype=np.float32)
ys_np = y.cpu().numpy() - 1
xs_np = x.cpu().numpy() - 1
flow_output[:, ys_np, xs_np] = mu
flow_output[:, ys_np, xs_np] = mu.reshape(2, -1)
flows[2*channel: 2*channel + 2] = flow_output
return flows

@@ -21,7 +21,7 @@ __all__ = [
"compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score"
]
logger = get_logger()
logger = get_logger(__name__)
def compute_f1_score(

@@ -23,6 +23,20 @@ def main():
default='train',
help='Run mode: train, test or predict'
)
parser.add_argument(
'--no-save-masks',
dest='save_masks',  # args.save_masks stays True unless this flag is passed
action='store_false',
help='Disable saving of predicted masks; saving is enabled by default'
)
parser.add_argument(
'--only-masks',
action='store_true',
default=False,
help=('If set and saving is enabled, save only the raw predicted'
' masks, without additional visualizations or metrics')
)
args = parser.parse_args()
mode = args.mode
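With this parser, a hypothetical invocation such as python main.py --mode test --only-masks (script name and --mode flag spelling assumed) saves only the raw TIFF masks, while --no-save-masks skips saving outputs entirely.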
@@ -56,22 +70,25 @@ def main():
if config.wandb_config.use_wandb:
wandb.watch(segmentator._model, log="all", log_graph=True)
# Run the full workflow (training, evaluation, or prediction)
segmentator.run()
try:
segmentator.run(save_results=args.save_masks, only_masks=args.only_masks)
finally:
if config.dataset_config.is_training:
# Prepare saving path
weights_dir = (
wandb.run.dir if config.wandb_config.use_wandb else "weights" # type: ignore
)
saving_path = os.path.join(
weights_dir,
os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
)
segmentator.save_checkpoint(saving_path)
if config.dataset_config.is_training:
# Prepare saving path
weights_dir = (
wandb.run.dir if config.wandb_config.use_wandb else "weights" # type: ignore
)
saving_path = os.path.join(
weights_dir,
os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
)
segmentator.save_checkpoint(saving_path)
if config.wandb_config.use_wandb:
wandb.save(saving_path)
if config.wandb_config.use_wandb:
wandb.save(saving_path)
if __name__ == "__main__":
