Added creation of error masks (TP/FP/FN) for each class;

Added arguments for saving only the raw prediction masks, as well as for skipping saving of outputs entirely;
Added execution time display;
Moved the pretrained_weights parameter to the common config section;
Removed ensemble-related parameters.
master
laynholt 2 months ago
parent b201367596
commit 28f978956c

@@ -13,6 +13,7 @@ class DatasetCommonConfig(BaseModel):
     use_amp: bool = False        # Flag to use Automatic Mixed Precision (AMP)
     masks_subdir: str = ""       # Subdirectory where the required masks are located, e.g. 'masks/cars'
     predictions_dir: str = "."   # Directory to save predictions
+    pretrained_weights: str = "" # Path to pretrained weights

     @model_validator(mode="after")
     def validate_common(self) -> "DatasetCommonConfig":
@@ -72,7 +73,6 @@ class DatasetTrainingConfig(BaseModel):
     batch_size: int = 1          # Batch size for training
     num_epochs: int = 100        # Number of training epochs
     val_freq: int = 1            # Frequency of validation during training
-    pretrained_weights: str = "" # Path to pretrained weights for training

     @field_validator("train_size", "valid_size", "test_size", mode="before")
@@ -137,15 +137,6 @@ class DatasetTrainingConfig(BaseModel):
             raise ValueError("offsets must be >= 0")
         return self

-    @model_validator(mode="after")
-    def validate_pretrained(self) -> "DatasetTrainingConfig":
-        """
-        Validates that pretrained_weights is provided and exists.
-        """
-        if self.pretrained_weights and not os.path.exists(self.pretrained_weights):
-            raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
-        return self
-
 class DatasetTestingConfig(BaseModel):
     """
@@ -156,11 +147,6 @@ class DatasetTestingConfig(BaseModel):
     test_offset: int = 0  # Offset for testing data
     shuffle: bool = True  # Shuffle data
-    use_ensemble: bool = False  # Flag to use ensemble mode in testing
-    ensemble_pretrained_weights1: str = "."
-    ensemble_pretrained_weights2: str = "."
-    pretrained_weights: str = "."

     @field_validator("test_size", mode="before")
     def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]:
         """
@@ -191,25 +177,11 @@
         """
         Validates the testing configuration:
         - test_dir must be non-empty and exist.
-        - If use_ensemble is True, both ensemble_pretrained_weights1 and ensemble_pretrained_weights2 must be provided and exist.
-        - If use_ensemble is False, pretrained_weights must be provided and exist.
         """
         if not self.test_dir:
             raise ValueError("In testing configuration, test_dir must be provided and non-empty")
         if not os.path.exists(self.test_dir):
             raise ValueError(f"Path for test_dir does not exist: {self.test_dir}")
-        if self.use_ensemble:
-            for field in ["ensemble_pretrained_weights1", "ensemble_pretrained_weights2"]:
-                value = getattr(self, field)
-                if not value:
-                    raise ValueError(f"When use_ensemble is True, {field} must be provided and non-empty")
-                if not os.path.exists(value):
-                    raise ValueError(f"Path for {field} does not exist: {value}")
-        else:
-            if not self.pretrained_weights:
-                raise ValueError("When use_ensemble is False, pretrained_weights must be provided and non-empty")
-            if not os.path.exists(self.pretrained_weights):
-                raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
         if self.test_offset < 0:
             raise ValueError("test_offset must be >= 0")
         return self
@@ -237,11 +209,17 @@ class DatasetConfig(BaseModel):
             if (self.training.is_split and self.training.pre_split.test_dir) or (not self.training.is_split):
                 if self.training.test_size > 0 and not self.common.predictions_dir:
                     raise ValueError("predictions_dir must be provided when test_size is non-zero")
+            if self.common.pretrained_weights and not os.path.exists(self.common.pretrained_weights):
+                raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}")
         else:
             if self.testing is None:
                 raise ValueError("Testing configuration must be provided when is_training is False")
             if self.testing.test_size > 0 and not self.common.predictions_dir:
                 raise ValueError("predictions_dir must be provided when test_size is non-zero")
+            if not self.common.pretrained_weights:
+                raise ValueError("When testing, pretrained_weights must be provided and non-empty")
+            if not os.path.exists(self.common.pretrained_weights):
+                raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}")
         return self

     def model_dump(self, **kwargs) -> Dict[str, Any]:
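
Note on the relocation: with pretrained_weights now living in DatasetCommonConfig, training and testing read the checkpoint path from one place, and only the testing branch makes it mandatory. A minimal sketch of the moved field in use (values are illustrative; assumes DatasetCommonConfig's other required fields are satisfied and that it is importable from this config module):

    common = DatasetCommonConfig(
        masks_subdir="masks/cars",
        predictions_dir="outputs/predictions",
        # Previously training.pretrained_weights / testing.pretrained_weights:
        pretrained_weights="weights/best_model.pth",
    )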

@@ -10,7 +10,7 @@ from core.logger import get_logger

 __all__ = ["BoundaryExclusion", "IntensityDiversification"]

-logger = get_logger("cell_aware")
+logger = get_logger(__name__)

 class BoundaryExclusion(MapTransform):

@@ -0,0 +1,193 @@
+import torch
+import numpy as np
+
+from typing import Hashable, List, Sequence, Optional, Tuple
+
+from monai.utils.misc import fall_back_tuple
+from monai.data.meta_tensor import MetaTensor
+from monai.data.utils import get_random_patch, get_valid_patch_size
+from monai.transforms import Randomizable, RandCropd, Crop  # type: ignore
+
+from core.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+def _compute_multilabel_bbox(
+    mask: np.ndarray
+) -> Optional[Tuple[List[int], List[int], List[int], List[int]]]:
+    """
+    Compute per-channel bounding-box constraints and return lists of limits for each axis.
+
+    Args:
+        mask: multi-channel instance mask of shape (C, H, W).
+
+    Returns:
+        A tuple of four lists:
+        - top_mins: list of r_max for each non-empty channel
+        - top_maxs: list of r_min for each non-empty channel
+        - left_mins: list of c_max for each non-empty channel
+        - left_maxs: list of c_min for each non-empty channel
+        Or None if mask contains no positive labels.
+    """
+    channels, rows, cols = np.nonzero(mask)
+    if channels.size == 0:
+        return None
+
+    top_mins: List[int] = []
+    top_maxs: List[int] = []
+    left_mins: List[int] = []
+    left_maxs: List[int] = []
+    C = mask.shape[0]
+    for ch in range(C):
+        rs, cs = np.nonzero(mask[ch])
+        if rs.size == 0:
+            continue
+        r_min, r_max = int(rs.min()), int(rs.max())
+        c_min, c_max = int(cs.min()), int(cs.max())
+        # For each channel, record the row/col extents
+        top_mins.append(r_max)
+        top_maxs.append(r_min)
+        left_mins.append(c_max)
+        left_maxs.append(c_min)
+    return top_mins, top_maxs, left_mins, left_maxs
+
+
+class SpatialCropAllClasses(Randomizable, Crop):
+    """
+    Cropper for multi-label instance masks and images: ensures each label-channel's
+    instances lie within the crop if possible.
+    Must be called on a mask tensor first to compute the crop, then on the image.
+
+    Args:
+        roi_size: desired crop size (height, width).
+        num_candidates: fallback samples when no single crop fits all instances.
+        lazy: defer actual cropping.
+    """
+    def __init__(
+        self,
+        roi_size: Sequence[int],
+        num_candidates: int = 10,
+        lazy: bool = False,
+    ) -> None:
+        super().__init__(lazy=lazy)
+        self.roi_size = tuple(roi_size)
+        self.num_candidates = num_candidates
+        self._slices: Optional[Tuple[slice, ...]] = None
+
+    def randomize(self, img_size: Sequence[int]) -> None:  # type: ignore
+        """
+        Choose crop offsets so that each non-empty channel is included if possible.
+        """
+        height, width = img_size
+        crop_h, crop_w = self.roi_size
+        max_top = max(0, height - crop_h)
+        max_left = max(0, width - crop_w)
+
+        # Compute per-channel bbox constraints
+        mask = self._img
+        bboxes = _compute_multilabel_bbox(mask)
+        if bboxes is None:
+            # no labels: random patch using MONAI utils
+            logger.warning("No labels found; using random patch.")
+            # determine actual patch size (fallback)
+            self._size = fall_back_tuple(self.roi_size, img_size)
+            # compute valid size for random patch
+            valid_size = get_valid_patch_size(img_size, self._size)
+            # directly get random patch slices
+            self._slices = get_random_patch(img_size, valid_size, self.R)
+            return
+        else:
+            top_mins, top_maxs, left_mins, left_maxs = bboxes
+            # Convert to allowable windows
+            # top_min_global = max(r_max - crop_h + 1 for each channel)
+            global_top_min = max(0, max(r_max - crop_h + 1 for r_max in top_mins))
+            # top_max_global = min(r_min for each channel)
+            global_top_max = min(min(top_maxs), max_top)
+            # same for left
+            global_left_min = max(0, max(c_max - crop_w + 1 for c_max in left_mins))
+            global_left_max = min(min(left_maxs), max_left)
+
+            if global_top_min <= global_top_max and global_left_min <= global_left_max:
+                # there is a window covering all channels fully
+                top = self.R.randint(global_top_min, global_top_max + 1)
+                left = self.R.randint(global_left_min, global_left_max + 1)
+            else:
+                # fallback: sample candidates to maximize channel coverage
+                logger.warning(
+                    f"Cannot fit all instances; sampling {self.num_candidates} candidates."
+                )
+                best_cover = -1
+                best_top = best_left = 0
+                C = mask.shape[0]
+                for _ in range(self.num_candidates):
+                    cand_top = self.R.randint(0, max_top + 1)
+                    cand_left = self.R.randint(0, max_left + 1)
+                    window = mask[:, cand_top : cand_top + crop_h, cand_left : cand_left + crop_w]
+                    cover = sum(int(window[ch].any()) for ch in range(C))
+                    if cover > best_cover:
+                        best_cover = cover
+                        best_top, best_left = cand_top, cand_left
+                logger.info(f"Selected crop covering {best_cover}/{C} channels.")
+                top, left = best_top, best_left
+
+            # store slices for use on both mask and image
+            self._slices = (
+                slice(None),
+                slice(top, top + crop_h),
+                slice(left, left + crop_w),
+            )
+
+    def __call__(self, img: torch.Tensor, lazy: Optional[bool] = None) -> torch.Tensor:  # type: ignore
+        """
+        On first call (mask), computes crop. On subsequent (image), just applies.
+        Raises if mask not provided first.
+        """
+        # Determine tensor shape
+        img_size = (
+            img.peek_pending_shape()[1:]
+            if isinstance(img, MetaTensor)
+            else img.shape[1:]
+        )
+        # First call must be mask to compute slices
+        if self._slices is None:
+            if not torch.is_floating_point(img) and img.dtype in (torch.uint8, torch.int16, torch.int32, torch.int64):
+                # assume integer mask
+                self._img = img.cpu().numpy()
+                self.randomize(img_size)
+            else:
+                raise RuntimeError(
+                    "Mask tensor must be passed first for computing crop bounds."
+                )
+        # Now apply stored slice
+        if self._slices is None:
+            raise RuntimeError("Crop slices not computed; call on mask first.")
+        lazy_exec = self.lazy if lazy is None else lazy
+        return super().__call__(img=img, slices=self._slices, lazy=lazy_exec)
+
+
+class RandSpatialCropAllClassesd(RandCropd):
+    """
+    Dict-based wrapper: applies SpatialCropAllClasses to mask then image.
+    Requires mask present or raises.
+    """
+    def __init__(
+        self,
+        keys: Sequence,
+        roi_size: Sequence[int],
+        num_candidates: int = 10,
+        allow_missing_keys: bool = False,
+        lazy: bool = False,
+    ):
+        cropper = SpatialCropAllClasses(
+            roi_size=roi_size,
+            num_candidates=num_candidates,
+            lazy=lazy,
+        )
+        super().__init__(
+            keys=keys,
+            cropper=cropper,
+            allow_missing_keys=allow_missing_keys,
+            lazy=lazy,
+        )
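
Usage note for the new transform: per the RandSpatialCropAllClassesd docstring, the mask key must come before the image key so the crop window is computed on the labels first. A minimal pipeline sketch (key names and the surrounding transforms are assumptions, not part of this commit):

    from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd

    train_transforms = Compose([
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        # "mask" first: the cropper must see the integer mask to compute bounds.
        RandSpatialCropAllClassesd(keys=["mask", "image"], roi_size=(256, 256)),
    ])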

@@ -1,3 +1,4 @@
+import time
 import random
 import numpy as np
 from numba import njit, prange
@@ -70,6 +71,8 @@ class CellSegmentator:
         self._test_dataloader: Optional[DataLoader] = None
         self._predict_dataloader: Optional[DataLoader] = None

+        self._best_weights = None
+
     def create_dataloaders(
         self,
@@ -315,9 +318,14 @@
             logger.info(l)

-    def train(self) -> None:
+    def train(self, save_results: bool = True, only_masks: bool = False) -> None:
         """
         Train the model over multiple epochs, including validation and test.
+
+        Args:
+            save_results (bool): If True, the predicted masks and test metrics will be saved.
+            only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
+                without visualization overlays.
         """
         # Ensure training is enabled in dataset setup
         if not self._dataset_setup.is_training:
@@ -335,7 +343,6 @@
         logger.info(f"\n{'=' * 50}")

         best_f1_score = 0.0
-        best_weights = None

         for epoch in range(1, self._dataset_setup.training.num_epochs + 1):
             train_metrics = self.__run_epoch("train", epoch)
@@ -356,29 +363,38 @@
             if f1 > best_f1_score:
                 best_f1_score = f1
                 # Deep copy weights to avoid reference issues
-                best_weights = copy.deepcopy(self._model.state_dict())
+                self._best_weights = copy.deepcopy(self._model.state_dict())
                 logger.info(f"Updated best model weights with F1 score: {f1:.4f}")

         # Restore best model weights if available
-        if best_weights is not None:
-            self._model.load_state_dict(best_weights)
+        if self._best_weights is not None:
+            self._model.load_state_dict(self._best_weights)

         if self._test_dataloader is not None:
-            test_metrics = self.__run_epoch("test")
+            test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
             self.__print_with_logging(test_metrics, 0)

-    def evaluate(self) -> None:
+    def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None:
         """
         Run a full test epoch and display/log the resulting metrics.
+
+        Args:
+            save_results (bool): If True, the predicted masks and test metrics will be saved.
+            only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
+                without visualization overlays.
         """
-        test_metrics = self.__run_epoch("test")
+        test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
         self.__print_with_logging(test_metrics, 0)

-    def predict(self) -> None:
+    def predict(self, only_masks: bool = False) -> None:
         """
         Run inference on the predict set and save the resulting instance masks.
+
+        Args:
+            only_masks (bool): If True, only raw predicted masks are saved,
+                without visualization overlays.
         """
         # Ensure the predict DataLoader has been set
         if self._predict_dataloader is None:
@@ -404,44 +420,63 @@
             preds, _ = self.__post_process_predictions(raw_output)

             # Save out the predicted masks, using batch_counter to index files
-            self.__save_prediction_masks(batch, preds, batch_counter)
+            self.__save_prediction_masks(
+                sample=batch,
+                predicted_mask=preds,
+                start_index=batch_counter,
+                only_masks=only_masks
+            )

             # Increment counter by batch size for unique file naming
             batch_counter += inputs.shape[0]

-    def run(self) -> None:
+    def run(self, save_results: bool = True, only_masks: bool = False) -> None:
         """
-        Orchestrate the full workflow:
+        Orchestrate the full workflow and report execution time:
         - If training is enabled in the dataset setup, start training.
         - Otherwise, if a test DataLoader is provided, run evaluation.
        - Else if a prediction DataLoader is provided, run inference/prediction.
         - If neither loader is available in non-training mode, raise an error.
+
+        Args:
+            save_results (bool): If True, the predicted masks and test metrics will be saved.
+            only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
+                without visualization overlays.
         """
+        start_time = time.time()
+
         # 1) TRAINING PATH
         if self._dataset_setup.is_training:
             # Launch the full training loop (with validation, scheduler steps, etc.)
-            self.train()
-            return
-
-        # 2) NON-TRAINING PATH (TEST or PREDICT)
-        # Prefer test if available
-        if self._test_dataloader is not None:
-            # Run a single evaluation epoch on the test set and log metrics
-            self.evaluate()
-            return
-
-        # If no test loader, fall back to prediction if available
-        if self._predict_dataloader is not None:
-            # Run inference on the predict set and save outputs
-            self.predict()
-            return
-
-        # 3) ERROR: no appropriate loader found
-        raise RuntimeError(
-            "Neither test nor predict DataLoader is set for nontraining mode."
-        )
+            self.train(save_results=save_results, only_masks=only_masks)
+        else:
+            # 2) NON-TRAINING PATH (TEST or PREDICT)
+            if self._test_dataloader is not None:
+                # Run a single evaluation epoch on the test set and log metrics
+                self.evaluate(save_results=save_results, only_masks=only_masks)
+            elif self._predict_dataloader is not None:
+                # Run inference on the predict set and save outputs
+                self.predict(only_masks=only_masks)
+            else:
+                # 3) ERROR: no appropriate loader found
+                raise RuntimeError(
+                    "Neither test nor predict DataLoader is set for non-training mode."
+                )
+
+        elapsed = time.time() - start_time
+        if elapsed < 60:
+            logger.info(f"Total execution time: {elapsed:.2f} seconds")
+        elif elapsed < 3600:
+            minutes = int(elapsed // 60)
+            seconds = elapsed % 60
+            logger.info(f"Total execution time: {minutes} min {seconds:.2f} sec")
+        else:
+            hours = int(elapsed // 3600)
+            minutes = int((elapsed % 3600) // 60)
+            seconds = elapsed % 60
+            logger.info(f"Total execution time: {hours} h {minutes} min {seconds:.2f} sec")

     def load_from_checkpoint(self, checkpoint_path: str) -> None:
         """
@@ -490,7 +525,12 @@
         """
         # Write the checkpoint to disk
         os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
-        torch.save(self._model.state_dict(), checkpoint_path)
+        torch.save(
+            (self._model.state_dict()
+             if self._best_weights is None
+             else self._best_weights),
+            checkpoint_path
+        )

     def __parse_config(self, config: Config) -> None:
@@ -523,11 +563,7 @@
         self._model = ModelRegistry.get_model_class(model.name)(model.params)

         # Loads model weights from a specified checkpoint
-        pretrained_weights = (
-            config.dataset_config.training.pretrained_weights
-            if config.dataset_config.is_training
-            else config.dataset_config.testing.pretrained_weights
-        )
+        pretrained_weights = config.dataset_config.common.pretrained_weights
         if pretrained_weights:
             self.load_from_checkpoint(pretrained_weights)
             logger.info(f"Loaded pre-trained weights from: {pretrained_weights}")
@@ -589,6 +625,7 @@
         logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
         logger.info(f"├─ Masks subdirectory: {common.masks_subdir}")
         logger.info(f"└─ Predictions output dir: {common.predictions_dir}")
+        logger.info(f"├─ Pretrained weights: {common.pretrained_weights or 'None'}")

         if config.dataset_config.is_training:
             training = config.dataset_config.training
@@ -596,7 +633,6 @@
             logger.info(f"├─ Batch size: {training.batch_size}")
             logger.info(f"├─ Epochs: {training.num_epochs}")
             logger.info(f"├─ Validation frequency: {training.val_freq}")
-            logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}")

             if training.is_split:
                 logger.info(f"├─ Using pre-split directories:")
@@ -619,11 +655,6 @@
             logger.info(f"├─ Test dir: {testing.test_dir}")
             logger.info(f"├─ Test size: {testing.test_size} (offset: {testing.test_offset})")
             logger.info(f"├─ Shuffle: {'yes' if testing.shuffle else 'no'}")
-            logger.info(f"├─ Use ensemble: {'yes' if testing.use_ensemble else 'no'}")
-            logger.info(f"└─ Pretrained weights:")
-            logger.info(f"   ├─ Single model: {testing.pretrained_weights}")
-            logger.info(f"   ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
-            logger.info(f"   └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")

         self._wandb_config = config.wandb_config
         if self._wandb_config.use_wandb:
@@ -747,38 +778,62 @@
         return Dataset(data, transforms)

-    def __print_with_logging(self, results: Dict[str, float], step: int) -> None:
+    def __print_with_logging(self, results: Dict[str, Union[float, np.ndarray]], step: int) -> None:
         """
         Print metrics in a tabular format and log to W&B.

         Args:
-            results (Dict[str, float]): results dictionary.
+            results (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
+                to either a float or an ND numpy array.
             step (int): epoch index.
         """
+        rows: list[tuple[str, str]] = []
+        for key, val in results.items():
+            if isinstance(val, np.ndarray):
+                # Convert array to string, e.g. '[0.2, 0.8, 0.5]'
+                val_str = np.array2string(val, separator=', ')
+            else:
+                # Format scalar with 4 decimal places
+                val_str = f"{val:.4f}"
+            rows.append((key, val_str))
+
         table = tabulate(
-            tabular_data=results.items(),
+            tabular_data=rows,
             headers=["Metric", "Value"],
             floatfmt=".4f",
             tablefmt="fancy_grid"
         )
         print(table, "\n")

         if self._wandb_config.use_wandb:
-            wandb.log(results, step=step)
+            # Keep only scalar values
+            scalar_results: dict[str, float] = {}
+            for key, val in results.items():
+                if isinstance(val, np.ndarray):
+                    continue
+                # Ensure float type
+                scalar_results[key] = float(val)
+            wandb.log(scalar_results, step=step)

     def __run_epoch(self,
                     mode: Literal["train", "valid", "test"],
-                    epoch: Optional[int] = None
-                    ) -> Dict[str, float]:
+                    epoch: Optional[int] = None,
+                    save_results: bool = True,
+                    only_masks: bool = False
+                    ) -> Dict[str, Union[float, np.ndarray]]:
         """
         Execute one epoch of training, validation, or testing.

         Args:
             mode (str): One of 'train', 'valid', or 'test'.
             epoch (int, optional): Current epoch number for logging.
+            save_results (bool): If True, the predicted masks and test metrics will be saved.
+            only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
+                without visualization overlays.

         Returns:
-            Dict[str, float]: Loss metrics and F1 score for valid/test.
+            Dict[str, Union[float, np.ndarray]]: Metrics for valid/test.
         """
         # Ensure required components are available
         if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
@@ -841,18 +896,23 @@
                 )

                 # Collecting statistics on the batch
-                tp, fp, fn = self.__compute_stats(
+                tp, fp, fn, tp_masks, fp_masks, fn_masks = self.__compute_stats(
                     predicted_masks=preds,
                     ground_truth_masks=labels_post,  # type: ignore
-                    iou_threshold=0.5
+                    iou_threshold=0.5,
+                    return_error_masks=(mode == "test")
                 )
                 all_tp.append(tp)
                 all_fp.append(fp)
                 all_fn.append(fn)

-                if mode == "test":
+                if mode == "test" and save_results is True:
                     self.__save_prediction_masks(
-                        batch, preds, batch_counter
+                        sample=batch,
+                        predicted_mask=preds,
+                        start_index=batch_counter,
+                        only_masks=only_masks,
+                        masks=(tp_masks, fp_masks, fn_masks)  # type: ignore
                     )

             # Backpropagation and optimizer step in training
@@ -871,7 +931,9 @@
         if self._criterion is not None:
             # Collect loss metrics
-            epoch_metrics = {f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()}
+            epoch_metrics: Dict[str, Union[float, np.ndarray]] = {
+                f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()
+            }
             # Reset internal loss metrics accumulator
             self._criterion.reset_metrics()
         else:
@@ -884,13 +946,13 @@
             fp_array = np.vstack(all_fp)
             fn_array = np.vstack(all_fn)

-            epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(  # type: ignore
+            epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(
                 tp_array, fp_array, fn_array, reduction="micro"
             )
-            epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric(  # type: ignore
+            epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric(
                 tp_array, fp_array, fn_array, reduction="imagewise"
             )
-            epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(  # type: ignore
+            epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(
                 tp_array, fp_array, fn_array, reduction="macro"
             )
@@ -976,8 +1038,10 @@
         self,
         predicted_masks: np.ndarray,
         ground_truth_masks: np.ndarray,
-        iou_threshold: float = 0.5
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        iou_threshold: float = 0.5,
+        return_error_masks: bool = False
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray,
+               Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
         """
         Compute batch-wise true positives, false positives, and false negatives
         for instance segmentation, using a configurable IoU threshold.
@@ -987,23 +1051,33 @@
             ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W).
             iou_threshold (float): Intersection-over-Union threshold for matching predictions
                 to ground truths (default: 0.5).
+            return_error_masks (bool): Whether to also return binary error masks.

         Returns:
-            Tuple[np.ndarray, np.ndarray, np.ndarray]:
+            Tuple[np.ndarray, np.ndarray, np.ndarray,
+                  np.ndarray | None, np.ndarray | None, np.ndarray | None]:
                 - tp: True positives per batch and class, shape (B, C)
                 - fp: False positives per batch and class, shape (B, C)
                 - fn: False negatives per batch and class, shape (B, C)
+                - tp_mask: True positives mask per batch and class, shape (B, C, H, W)
+                - fp_mask: False positives mask per batch and class, shape (B, C, H, W)
+                - fn_mask: False negatives mask per batch and class, shape (B, C, H, W)
         """
         stats = compute_batch_segmentation_tp_fp_fn(
             batch_ground_truth=ground_truth_masks,
             batch_prediction=predicted_masks,
             iou_threshold=iou_threshold,
+            return_error_masks=return_error_masks,
             remove_boundary_objects=True
         )
         tp = stats["tp"]
         fp = stats["fp"]
         fn = stats["fn"]
-        return tp, fp, fn
+        tp_mask = stats["tp_mask"] if return_error_masks else None
+        fp_mask = stats["fp_mask"] if return_error_masks else None
+        fn_mask = stats["fn_mask"] if return_error_masks else None
+        return tp, fp, fn, tp_mask, fp_mask, fn_mask

     def __compute_f1_metric(
@@ -1011,7 +1085,7 @@
         true_positives: np.ndarray,
         false_positives: np.ndarray,
         false_negatives: np.ndarray,
-        reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
+        reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
     ) -> Union[float, np.ndarray]:
         """
         Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes.
@@ -1023,8 +1097,9 @@
             reduction:
                 - 'none': return F1 for each sample and class, shape (batch_size, num_classes)
                 - 'micro': global F1 over all samples & classes
-                - 'imagewise': F1 per sample (summing over classes), then average over samples
                 - 'macro': average class-wise F1 (classes summed over batch)
+                - 'imagewise': F1 per sample (summing over classes), then average over samples
+                - 'per_class': F1 per class (summing over batch), return vector of shape (num_classes,)
                 - 'weighted': class-wise F1 weighted by support (TP+FN)

         Returns:
             float for reductions 'micro', 'imagewise', 'macro', 'weighted';
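
For a quick sanity check of these reductions, recall that F1 collapses to 2·TP / (2·TP + FP + FN); assuming compute_f1_score implements this standard definition, a worked micro example (numbers illustrative):

    tp, fp, fn = 8, 2, 4
    precision = tp / (tp + fp)        # 0.800
    recall = tp / (tp + fn)           # ≈ 0.667
    f1 = 2 * tp / (2 * tp + fp + fn)  # 16 / 22 ≈ 0.727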
@@ -1040,7 +1115,11 @@
                     tp_val = int(true_positives[i, c])
                     fp_val = int(false_positives[i, c])
                     fn_val = int(false_negatives[i, c])
-                    _, _, f1_val = compute_f1_score(tp_val, fp_val, fn_val)
+                    _, _, f1_val = compute_f1_score(
+                        tp_val,
+                        fp_val,
+                        fn_val
+                    )
                     f1_matrix[i, c] = f1_val
             return f1_matrix
@@ -1049,7 +1128,11 @@
             tp_total = int(true_positives.sum())
             fp_total = int(false_positives.sum())
             fn_total = int(false_negatives.sum())
-            _, _, f1_global = compute_f1_score(tp_total, fp_total, fn_total)
+            _, _, f1_global = compute_f1_score(
+                tp_total,
+                fp_total,
+                fn_total
+            )
             return f1_global

         # 3) Imagewise: compute per-sample F1 (sum over classes), then average
@@ -1059,16 +1142,31 @@
                 tp_i = int(true_positives[i].sum())
                 fp_i = int(false_positives[i].sum())
                 fn_i = int(false_negatives[i].sum())
-                _, _, f1_i = compute_f1_score(tp_i, fp_i, fn_i)
+                _, _, f1_i = compute_f1_score(
+                    tp_i,
+                    fp_i,
+                    fn_i
+                )
                 f1_per_image[i] = f1_i
             return float(f1_per_image.mean())

-        # For macro/weighted, first aggregate per class across the batch
+        # Aggregate per class across the batch for per_class, macro, weighted
         tp_per_class = true_positives.sum(axis=0).astype(int)  # shape (num_classes,)
         fp_per_class = false_positives.sum(axis=0).astype(int)
         fn_per_class = false_negatives.sum(axis=0).astype(int)

-        # 4) Macro: average F1 across classes equally
+        # 4) Per-class: compute F1 for each class and return vector
+        if reduction == "per_class":
+            f1_per_class = np.zeros(num_classes, dtype=float)
+            for c in range(num_classes):
+                _, _, f1_per_class[c] = compute_f1_score(
+                    tp_per_class[c],
+                    fp_per_class[c],
+                    fn_per_class[c]
+                )
+            return f1_per_class
+
+        # 5) Macro: average F1 across classes equally
         if reduction == "macro":
             f1_per_class = np.zeros(num_classes, dtype=float)
             for c in range(num_classes):
@@ -1080,7 +1178,7 @@
                 f1_per_class[c] = f1_c
             return float(f1_per_class.mean())

-        # 5) Weighted: class-wise F1 weighted by support = TP + FN
+        # 6) Weighted: class-wise F1 weighted by support = TP + FN
         if reduction == "weighted":
             f1_per_class = np.zeros(num_classes, dtype=float)
             support = np.zeros(num_classes, dtype=float)
@@ -1088,7 +1186,11 @@
                 tp_c = tp_per_class[c]
                 fp_c = fp_per_class[c]
                 fn_c = fn_per_class[c]
-                _, _, f1_c = compute_f1_score(tp_c, fp_c, fn_c)
+                _, _, f1_c = compute_f1_score(
+                    tp_c,
+                    fp_c,
+                    fn_c
+                )
                 f1_per_class[c] = f1_c
                 support[c] = tp_c + fn_c
             total_support = support.sum()
@@ -1106,7 +1208,7 @@
         true_positives: np.ndarray,
         false_positives: np.ndarray,
         false_negatives: np.ndarray,
-        reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
+        reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
     ) -> Union[float, np.ndarray]:
         """
         Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes.
@@ -1121,8 +1223,9 @@
             reduction:
                 - 'none': return AP for each sample and class, shape (batch_size, num_classes)
                 - 'micro': global AP over all samples & classes
-                - 'imagewise': AP per sample (summing stats over classes), then average over batch
                 - 'macro': average class-wise AP (each class summed over batch)
+                - 'imagewise': AP per sample (summing stats over classes), then average over batch
+                - 'per_class': AP per class (summing over batch), return vector of shape (num_classes,)
                 - 'weighted': class-wise AP weighted by support (TP+FN)

         Returns:
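
Analogously, assuming compute_average_precision_score uses the common instance-segmentation convention AP = TP / (TP + FP + FN) (as in Cellpose-style evaluation; this is an assumption about the helper, not stated in the diff), the micro reduction works out as:

    tp, fp, fn = 8, 2, 4
    ap = tp / (tp + fp + fn)  # 8 / 14 ≈ 0.571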
@@ -1139,7 +1242,11 @@
                     tp_val = int(true_positives[i, c])
                     fp_val = int(false_positives[i, c])
                     fn_val = int(false_negatives[i, c])
-                    ap_val = compute_average_precision_score(tp_val, fp_val, fn_val)
+                    ap_val = compute_average_precision_score(
+                        tp_val,
+                        fp_val,
+                        fn_val
+                    )
                     ap_matrix[i, c] = ap_val
             return ap_matrix
@@ -1148,7 +1255,11 @@
             tp_total = int(true_positives.sum())
             fp_total = int(false_positives.sum())
             fn_total = int(false_negatives.sum())
-            return compute_average_precision_score(tp_total, fp_total, fn_total)
+            return compute_average_precision_score(
+                tp_total,
+                fp_total,
+                fn_total
+            )

         # 3) Imagewise: compute per-sample AP (sum over classes), then mean
         if reduction == "imagewise":
@@ -1157,7 +1268,11 @@
                 tp_i = int(true_positives[i].sum())
                 fp_i = int(false_positives[i].sum())
                 fn_i = int(false_negatives[i].sum())
-                ap_per_image[i] = compute_average_precision_score(tp_i, fp_i, fn_i)
+                ap_per_image[i] = compute_average_precision_score(
+                    tp_i,
+                    fp_i,
+                    fn_i
+                )
             return float(ap_per_image.mean())

         # For macro and weighted: first aggregate per class across batch
@@ -1165,7 +1280,18 @@
         fp_per_class = false_positives.sum(axis=0).astype(int)
         fn_per_class = false_negatives.sum(axis=0).astype(int)

-        # 4) Macro: average AP across classes equally
+        # 4) Per-class: compute AP for each class and return vector
+        if reduction == "per_class":
+            ap_per_class = np.zeros(num_classes, dtype=float)
+            for c in range(num_classes):
+                ap_per_class[c] = compute_average_precision_score(
+                    tp_per_class[c],
+                    fp_per_class[c],
+                    fn_per_class[c]
+                )
+            return ap_per_class
+
+        # 5) Macro: average AP across classes equally
         if reduction == "macro":
             ap_per_class = np.zeros(num_classes, dtype=float)
             for c in range(num_classes):
@@ -1176,7 +1302,7 @@
             )
             return float(ap_per_class.mean())

-        # 5) Weighted: class-wise AP weighted by support = TP + FN
+        # 6) Weighted: class-wise AP weighted by support = TP + FN
         if reduction == "weighted":
             ap_per_class = np.zeros(num_classes, dtype=float)
             support = np.zeros(num_classes, dtype=float)
@@ -1184,7 +1310,11 @@
                 tp_c = tp_per_class[c]
                 fp_c = fp_per_class[c]
                 fn_c = fn_per_class[c]
-                ap_per_class[c] = compute_average_precision_score(tp_c, fp_c, fn_c)
+                ap_per_class[c] = compute_average_precision_score(
+                    tp_c,
+                    fp_c,
+                    fn_c
+                )
                 support[c] = tp_c + fn_c
             total_support = support.sum()
             if total_support == 0:
@@ -1215,86 +1345,158 @@
         sample: Dict[str, Any],
         predicted_mask: Union[np.ndarray, torch.Tensor],
         start_index: int = 0,
+        only_masks: bool = False,
+        masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
     ) -> None:
         """
-        Save multi-channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders.
+        Save multi-channel predicted masks as TIFFs and
+        corresponding visualizations as PNGs in separate folders.

         Args:
-            sample (Dict[str, Any]): Batch sample from MONAI LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
+            sample (Dict[str, Any]): Batch sample from MONAI
+                LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
             predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
             start_index (int): Starting index for naming when metadata is missing.
+            only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations.
+            masks (Tuple[np.ndarray, np.ndarray, np.ndarray] | None):
+                A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None.
         """
-        # Determine base paths
+        # Base directories (created once per call)
         base_output_dir = self._dataset_setup.common.predictions_dir
         masks_dir = base_output_dir
         plots_dir = os.path.join(base_output_dir, "plots")
+        evaluate_dir = os.path.join(plots_dir, "evaluate")
         os.makedirs(masks_dir, exist_ok=True)
         os.makedirs(plots_dir, exist_ok=True)
+        os.makedirs(evaluate_dir, exist_ok=True)

-        # Extract image (C, H, W) or batch of images (B, C, H, W), and metadata
-        image_obj = sample.get("image")  # Expected shape: (C, H, W) or (B, C, H, W)
-        mask_obj = sample.get("mask")    # Expected shape: (C, H, W) or (B, C, H, W)
-        image_meta = sample.get("image_meta_dict")
-
-        # Convert tensors to numpy
+        # Convert tensors to numpy if necessary
         def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
-            if isinstance(x, torch.Tensor):
-                return x.cpu().numpy()
-            return x
+            return x.cpu().numpy() if isinstance(x, torch.Tensor) else x

-        image_array = to_numpy(image_obj) if image_obj is not None else None
-        mask_array = to_numpy(mask_obj) if mask_obj is not None else None
-        pred_array = to_numpy(predicted_mask)
+        pred_array = to_numpy(predicted_mask).astype(np.uint16)

-        # Handle batch dimension: (B, C, H, W)
-        if pred_array.ndim == 4:
-            for idx in range(pred_array.shape[0]):
-                batch_sample: Dict[str, Any] = {}
-                if image_array is not None and image_array.ndim == 4:
-                    batch_sample["image"] = image_array[idx]
-                if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
-                    batch_sample["image_meta_dict"] = image_meta["filename_or_obj"][idx]
-                if mask_array is not None and mask_array.ndim == 4:
-                    batch_sample["mask"] = mask_array[idx]
-                self.__save_prediction_masks(
-                    batch_sample,
-                    pred_array[idx],
-                    start_index=start_index+idx
-                )
-            return
+        # Handle batch dimension
+        for idx in range(pred_array.shape[0]):
+            batch_sample: Dict[str, Any] = {}
+            # copy per-sample image and meta
+            img = to_numpy(sample["image"])
+            if img.ndim == 4:
+                batch_sample["image"] = img[idx]
+            if "mask" in sample:
+                msk = to_numpy(sample["mask"]).astype(np.uint16)
+                if msk.ndim == 4:
+                    batch_sample["mask"] = msk[idx]
+            image_meta = sample.get("image_meta_dict")
+            if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
+                fname = image_meta["filename_or_obj"][idx]
+                batch_sample["image_name"] = fname
+
+            single_masks = (
+                (masks[0][idx], masks[1][idx], masks[2][idx]) if masks is not None else None
+            )
+            self.__save_single_prediction_mask(
+                sample=batch_sample,
+                pred_array=pred_array[idx],
+                start_index=start_index + idx,
+                masks_dir=masks_dir,
+                plots_dir=plots_dir,
+                evaluate_dir=evaluate_dir,
+                only_masks=only_masks,
+                masks=single_masks,
+            )
+
+    def __save_single_prediction_mask(
+        self,
+        sample: Dict[str, Any],
+        pred_array: np.ndarray,
+        start_index: int,
+        masks_dir: str,
+        plots_dir: str,
+        evaluate_dir: str,
+        only_masks: bool = False,
+        masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
+    ) -> None:
+        """
+        Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations.
+        Assumes output directories already exist.
+
+        Args:
+            sample (Dict[str, Any]): Dictionary containing 'image', 'mask',
+                and optional 'image_meta_dict' for metadata.
+            pred_array (np.ndarray): Predicted mask array of shape (C, H, W).
+            start_index (int): Base index for generating filenames when metadata is missing.
+            masks_dir (str): Directory for saving TIFF masks.
+            plots_dir (str): Directory for saving PNG visualizations.
+            evaluate_dir (str): Directory for saving PNG visualizations of evaluation results.
+            only_masks (bool): If True, saves only TIFF mask files; skips PNG plots.
+            masks (Tuple[np.ndarray, np.ndarray, np.ndarray], optional): A tuple of
+                true-positive, false-positive, and false-negative mask arrays,
+                each of shape (C, H, W). Defaults to None.
+        """
+        if pred_array.ndim == 2:
+            pred_array = np.expand_dims(pred_array, axis=0)
+        elif pred_array.ndim != 3:
+            raise ValueError(
+                f"Unsupported predicted_mask dimensions: {pred_array.ndim}. "
+                "Expected 2D (H, W) or 3D (C, H, W)."
+            )
+
+        # Handle image array if present
+        image_array: np.ndarray = sample["image"]
+        if image_array.ndim == 2:
+            image_array = np.expand_dims(image_array, axis=0)
+        elif image_array.ndim != 3:
+            raise ValueError(
+                f"Unsupported image dimensions: {image_array.ndim}. "
+                "Expected 2D (H, W) or 3D (C, H, W)."
+            )
+
+        true_mask_array: Optional[np.ndarray] = sample.get("mask")
+        if isinstance(true_mask_array, np.ndarray):
+            if true_mask_array.ndim == 2:
+                true_mask_array = np.expand_dims(true_mask_array, axis=0)
+            elif true_mask_array.ndim != 3:
+                raise ValueError(
+                    f"Unsupported true_mask_array dimensions: {true_mask_array.ndim}. "
+                    "Expected 2D (H, W) or 3D (C, H, W)."
+                )

-        # Determine base filename
+        # Determine filename base
+        image_meta = sample.get("image_name")
         if isinstance(image_meta, (str, os.PathLike)):
             base_name = os.path.splitext(os.path.basename(image_meta))[0]
         else:
+            # Use provided start_index when metadata missing
             base_name = f"prediction_{start_index:04d}"

-        # Save mask TIFF (16-bit)
-        mask_filename = f"{base_name}_mask.tif"
-        mask_path = os.path.join(masks_dir, mask_filename)
+        # Save main mask TIFF
+        mask_path = os.path.join(masks_dir, f"{base_name}_mask.tif")
         tiff.imwrite(mask_path, pred_array.astype(np.uint16), compression="zlib")

-        # Now pred_array shape is (C, H, W)
-        num_channels = pred_array.shape[0]
-        for channel_idx in range(num_channels):
-            channel_mask = pred_array[channel_idx]
-            # File names
-            plot_filename = f"{base_name}_ch{channel_idx:01d}.png"
-            plot_path = os.path.join(plots_dir, plot_filename)
-            # Extract corresponding true mask channel if exists
-            true_mask = None
-            if mask_array is not None and mask_array.ndim == 3:
-                true_mask = mask_array[channel_idx]
-            # Generate and save visualization
+        if only_masks:
+            return
+
+        # Save channel-wise plots
+        num_channels = pred_array.shape[0]
+        for ch in range(num_channels):
+            true_ch = true_mask_array[ch] if true_mask_array is not None else None
             self.__plot_mask(
-                file_path=plot_path,
-                image_data=image_array,  # type: ignore
-                predicted_mask=channel_mask,
-                true_mask=true_mask,
+                file_path=os.path.join(plots_dir, f"{base_name}_ch{ch}.png"),
+                image_data=image_array,
+                predicted_mask=pred_array[ch],
+                true_mask=true_ch,
             )
+            if masks is not None and true_ch is not None:
+                self.__save_mask_comparison_visuals(
+                    gt=true_ch,
+                    pred=pred_array[ch],
+                    tp_mask=masks[0][ch],
+                    fp_mask=masks[1][ch],
+                    fn_mask=masks[2][ch],
+                    file_path=os.path.join(evaluate_dir, f"{base_name}_ch{ch}.png")
+                )
@@ -1307,6 +1509,16 @@
     ) -> None:
         """
         Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided.
+
+        Args:
+            file_path (str): Path where the visualization image will be saved.
+            image_data (np.ndarray): The original input image array, expected shape (C, H, W).
+            predicted_mask (np.ndarray): The predicted mask array, shape (H, W),
+                depending on the task.
+            true_mask (Optional[np.ndarray], optional): The ground-truth mask array.
+                If provided, an additional row with true mask and overlap visualization
+                will be added to the plot. Default is None.
         """
         img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data
@ -1317,7 +1529,7 @@ class CellSegmentator:
('Original Image','Predicted Mask','Predicted Contours')) ('Original Image','Predicted Mask','Predicted Contours'))
else: else:
fig, axs = plt.subplots(2,3,figsize=(15,10)) fig, axs = plt.subplots(2,3,figsize=(15,10))
plt.subplots_adjust(wspace=0.02, hspace=0.1) plt.subplots_adjust(wspace=0.02, hspace=0.02)
# row 0: predicted # row 0: predicted
self.__plot_panels(axs[0], img, predicted_mask, 'red', self.__plot_panels(axs[0], img, predicted_mask, 'red',
('Original Image','Predicted Mask','Predicted Contours')) ('Original Image','Predicted Mask','Predicted Contours'))
@@ -1372,6 +1584,69 @@
             ax2.axis('off')

+    def __save_mask_comparison_visuals(
+        self,
+        gt: np.ndarray,
+        pred: np.ndarray,
+        tp_mask: np.ndarray,
+        fp_mask: np.ndarray,
+        fn_mask: np.ndarray,
+        file_path: str
+    ) -> None:
+        """
+        Creates and saves a 1x3 subplot figure showing:
+          1) True mask with boundaries
+          2) Predicted mask without boundaries
+          3) Overlay mask combining FP (R), TP (G), FN (B)
+
+        Args:
+            gt (np.ndarray): Ground truth mask (H, W).
+            pred (np.ndarray): Predicted mask (H, W).
+            tp_mask (np.ndarray): True positive mask (H, W).
+            fp_mask (np.ndarray): False positive mask (H, W).
+            fn_mask (np.ndarray): False negative mask (H, W).
+            file_path (str): Path where the visualization image will be saved.
+        """
+        # Prepare overlay mask
+        overlap_mask = np.zeros((*gt.shape[:2], 3), dtype=np.uint8)
+        overlap_mask[..., 0] = np.where(fp_mask, 255, 0)
+        overlap_mask[..., 1] = np.where(tp_mask, 255, 0)
+        overlap_mask[..., 2] = np.where(fn_mask, 255, 0)
+
+        # Set up figure
+        fig, axes = plt.subplots(1, 3, figsize=(15, 5),
+                                 gridspec_kw={'width_ratios': [1, 1, 1]})
+        plt.subplots_adjust(wspace=0.02, hspace=0.0,
+                            left=0.05, right=0.95, top=0.95, bottom=0.05)
+
+        # Colormap for instances
+        num_instances = max(np.max(gt), np.max(pred))
+        cmap = plt.get_cmap("gist_ncar")
+        colors = [cmap(i / num_instances) for i in range(num_instances)]
+        cmap = mcolors.ListedColormap(colors)
+
+        # Plot true mask
+        axes[0].imshow(gt, cmap=cmap)
+        axes[0].contour(find_boundaries(gt, mode="thick"), colors="black", linewidths=0.5)
+        axes[0].set_title("True Mask")
+        axes[0].axis("off")
+
+        # Plot predicted mask
+        axes[1].imshow(pred, cmap=cmap)
+        axes[1].contour(find_boundaries(pred, mode="thick"), colors="black", linewidths=0.5)
+        axes[1].set_title("Predicted Mask")
+        axes[1].axis("off")
+
+        # Plot overlay
+        axes[2].imshow(overlap_mask)
+        axes[2].set_title("Overlay Mask (R-FP; G-TP; B-FN)")
+        axes[2].axis("off")
+
+        # Save
+        plt.savefig(file_path, bbox_inches="tight", dpi=300)
+        plt.close()
+
     def __compute_flows_from_masks(
         self,
         true_masks: Tensor
@@ -1522,7 +1797,7 @@
             flow_output = np.zeros((2, height, width), dtype=np.float32)
             ys_np = y.cpu().numpy() - 1
             xs_np = x.cpu().numpy() - 1
-            flow_output[:, ys_np, xs_np] = mu
+            flow_output[:, ys_np, xs_np] = mu.reshape(2, -1)
             flows[2*channel: 2*channel + 2] = flow_output

         return flows
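
The reshape fix above addresses NumPy fancy-indexing assignment: flow_output[:, ys_np, xs_np] selects an array of shape (2, N), so the right-hand side must broadcast to exactly that shape. A tiny repro sketch (shapes illustrative):

    import numpy as np

    flow = np.zeros((2, 4, 4), dtype=np.float32)
    ys = np.array([0, 1, 2])
    xs = np.array([0, 1, 2])
    mu = np.arange(6, dtype=np.float32)  # may arrive flattened or with an extra axis

    # flow[:, ys, xs] has shape (2, 3); reshaping makes the assignment line up.
    flow[:, ys, xs] = mu.reshape(2, -1)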

@@ -21,7 +21,7 @@ __all__ = [
     "compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score"
 ]

-logger = get_logger()
+logger = get_logger(__name__)

 def compute_f1_score(

@@ -23,6 +23,20 @@ def main():
         default='train',
         help='Run mode: train, test or predict'
     )
+    parser.add_argument(
+        '-s', '--save-masks',
+        action='store_true',
+        default=True,
+        help='If set to False, do not save predicted masks; by default, saving is enabled'
+    )
+    parser.add_argument(
+        '--only-masks',
+        action='store_true',
+        default=False,
+        help=('If set and save-masks set, save only the raw predicted'
+              ' masks without additional visualizations or metrics')
+    )
     args = parser.parse_args()

     mode = args.mode
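
One caveat on the new flags: --save-masks pairs action='store_true' with default=True, so passing the flag cannot actually turn saving off, despite what the help text says. If a real on/off switch is wanted, one option (an editorial suggestion, Python 3.9+, not part of this commit) is argparse.BooleanOptionalAction:

    import argparse

    parser = argparse.ArgumentParser()
    # Generates both --save-masks and --no-save-masks.
    parser.add_argument(
        '--save-masks',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Save predicted masks (default: enabled; use --no-save-masks to disable)'
    )
    args = parser.parse_args(['--no-save-masks'])
    assert args.save_masks is False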
@@ -56,9 +70,11 @@ def main():
     if config.wandb_config.use_wandb:
         wandb.watch(segmentator._model, log="all", log_graph=True)

-    # Run training (or prediction, if implemented)
-    segmentator.run()
+    try:
+        segmentator.run(save_results=args.save_masks, only_masks=args.only_masks)
+    finally:
         if config.dataset_config.is_training:
             # Prepare saving path
             weights_dir = (
@@ -74,5 +90,6 @@ def main():
             wandb.save(saving_path)

+
 if __name__ == "__main__":
     main()
