@ -1,3 +1,4 @@
import time
import random
import numpy as np
from numba import njit , prange
@ -69,6 +70,8 @@ class CellSegmentator:
self . _valid_dataloader : Optional [ DataLoader ] = None
self . _test_dataloader : Optional [ DataLoader ] = None
self . _predict_dataloader : Optional [ DataLoader ] = None
self . _best_weights = None
def create_dataloaders (
@ -315,9 +318,14 @@ class CellSegmentator:
logger . info ( l )
def train ( self ) - > None :
def train ( self , save_results : bool = True , only_masks : bool = False ) - > None :
"""
Train the model over multiple epochs , including validation and test .
Args :
save_results ( bool ) : If True , the predicted masks and test metrics will be saved .
only_masks ( bool ) : If True and save_results is True , only raw predicted masks are saved ,
without visualization overlays .
"""
# Ensure training is enabled in dataset setup
if not self . _dataset_setup . is_training :
@ -335,7 +343,6 @@ class CellSegmentator:
logger . info ( f " \n { ' = ' * 50 } " )
best_f1_score = 0.0
best_weights = None
for epoch in range ( 1 , self . _dataset_setup . training . num_epochs + 1 ) :
train_metrics = self . __run_epoch ( " train " , epoch )
@ -356,29 +363,38 @@ class CellSegmentator:
if f1 > best_f1_score :
best_f1_score = f1
# Deep copy weights to avoid reference issues
best_weights = copy . deepcopy ( self . _model . state_dict ( ) )
self . _ best_weights = copy . deepcopy ( self . _model . state_dict ( ) )
logger . info ( f " Updated best model weights with F1 score: { f1 : .4f } " )
# Restore best model weights if available
if best_weights is not None :
self . _model . load_state_dict ( best_weights)
if self . _ best_weights is not None :
self . _model . load_state_dict ( self . _ best_weights)
if self . _test_dataloader is not None :
test_metrics = self . __run_epoch ( " test " )
test_metrics = self . __run_epoch ( " test " , save_results = save_results , only_masks = only_masks )
self . __print_with_logging ( test_metrics , 0 )
def evaluate ( self ) - > None :
def evaluate ( self , save_results : bool = True , only_masks : bool = False ) - > None :
"""
Run a full test epoch and display / log the resulting metrics .
Args :
save_results ( bool ) : If True , the predicted masks and test metrics will be saved .
only_masks ( bool ) : If True and save_results is True , only raw predicted masks are saved ,
without visualization overlays .
"""
test_metrics = self . __run_epoch ( " test " )
test_metrics = self . __run_epoch ( " test " , save_results = save_results , only_masks = only_masks )
self . __print_with_logging ( test_metrics , 0 )
def predict ( self ) - > None :
def predict ( self , only_masks : bool = False ) - > None :
"""
Run inference on the predict set and save the resulting instance masks .
Args :
only_masks ( bool ) : If True , only raw predicted masks are saved ,
without visualization overlays .
"""
# Ensure the predict DataLoader has been set
if self . _predict_dataloader is None :
@ -404,43 +420,62 @@ class CellSegmentator:
preds , _ = self . __post_process_predictions ( raw_output )
# Save out the predicted masks, using batch_counter to index files
self . __save_prediction_masks ( batch , preds , batch_counter )
self . __save_prediction_masks (
sample = batch ,
predicted_mask = preds ,
start_index = batch_counter ,
only_masks = only_masks
)
# Increment counter by batch size for unique file naming
batch_counter + = inputs . shape [ 0 ]
def run ( self ) - > None :
def run ( self , save_results : bool = True , only_masks : bool = False ) - > None :
"""
Orchestrate the full workflow :
Orchestrate the full workflow and report execution time :
- If training is enabled in the dataset setup , start training .
- Otherwise , if a test DataLoader is provided , run evaluation .
- Else if a prediction DataLoader is provided , run inference / prediction .
- If neither loader is available in non ‐ training mode , raise an error .
Args :
save_results ( bool ) : If True , the predicted masks and test metrics will be saved .
only_masks ( bool ) : If True and save_results is True , only raw predicted masks are saved ,
without visualization overlays .
"""
start_time = time . time ( )
# 1) TRAINING PATH
if self . _dataset_setup . is_training :
# Launch the full training loop (with validation, scheduler steps, etc.)
self . train ( )
return
# 2) NON-TRAINING PATH (TEST or PREDICT)
# Prefer test if available
if self . _test_dataloader is not None :
# Run a single evaluation epoch on the test set and log metrics
self . evaluate ( )
return
# If no test loader, fall back to prediction if available
if self . _predict_dataloader is not None :
# Run inference on the predict set and save outputs
self . predict ( )
return
self . train ( save_results = save_results , only_masks = only_masks )
else :
# 2) NON-TRAINING PATH (TEST or PREDICT)
if self . _test_dataloader is not None :
# Run a single evaluation epoch on the test set and log metrics
self . evaluate ( save_results = save_results , only_masks = only_masks )
elif self . _predict_dataloader is not None :
# Run inference on the predict set and save outputs
self . predict ( only_masks = only_masks )
else :
# 3) ERROR: no appropriate loader found
raise RuntimeError (
" Neither test nor predict DataLoader is set for non‐ training mode. "
)
# 3) ERROR: no appropriate loader found
raise RuntimeError (
" Neither test nor predict DataLoader is set for non‐ training mode. "
)
elapsed = time . time ( ) - start_time
if elapsed < 60 :
logger . info ( f " Total execution time: { elapsed : .2f } seconds " )
elif elapsed < 3600 :
minutes = int ( elapsed / / 60 )
seconds = elapsed % 60
logger . info ( f " Total execution time: { minutes } min { seconds : .2f } sec " )
else :
hours = int ( elapsed / / 3600 )
minutes = int ( ( elapsed % 3600 ) / / 60 )
seconds = elapsed % 60
logger . info ( f " Total execution time: { hours } h { minutes } min { seconds : .2f } sec " )
def load_from_checkpoint ( self , checkpoint_path : str ) - > None :
@ -490,7 +525,12 @@ class CellSegmentator:
"""
# Write the checkpoint to disk
os . makedirs ( os . path . dirname ( checkpoint_path ) , exist_ok = True )
torch . save ( self . _model . state_dict ( ) , checkpoint_path )
torch . save ( (
self . _model . state_dict ( )
if self . _best_weights is None
else self . _best_weights ) ,
checkpoint_path
)
def __parse_config ( self , config : Config ) - > None :
@ -523,11 +563,7 @@ class CellSegmentator:
self . _model = ModelRegistry . get_model_class ( model . name ) ( model . params )
# Loads model weights from a specified checkpoint
pretrained_weights = (
config . dataset_config . training . pretrained_weights
if config . dataset_config . is_training
else config . dataset_config . testing . pretrained_weights
)
pretrained_weights = config . dataset_config . common . pretrained_weights
if pretrained_weights :
self . load_from_checkpoint ( pretrained_weights )
logger . info ( f " Loaded pre-trained weights from: { pretrained_weights } " )
@ -589,6 +625,7 @@ class CellSegmentator:
logger . info ( f " ├─ Use AMP: { ' yes ' if common . use_amp else ' no ' } " )
logger . info ( f " ├─ Masks subdirectory: { common . masks_subdir } " )
logger . info ( f " └─ Predictions output dir: { common . predictions_dir } " )
logger . info ( f " ├─ Pretrained weights: { common . pretrained_weights or ' None ' } " )
if config . dataset_config . is_training :
training = config . dataset_config . training
@ -596,7 +633,6 @@ class CellSegmentator:
logger . info ( f " ├─ Batch size: { training . batch_size } " )
logger . info ( f " ├─ Epochs: { training . num_epochs } " )
logger . info ( f " ├─ Validation frequency: { training . val_freq } " )
logger . info ( f " ├─ Pretrained weights: { training . pretrained_weights or ' None ' } " )
if training . is_split :
logger . info ( f " ├─ Using pre-split directories: " )
@ -619,11 +655,6 @@ class CellSegmentator:
logger . info ( f " ├─ Test dir: { testing . test_dir } " )
logger . info ( f " ├─ Test size: { testing . test_size } (offset: { testing . test_offset } ) " )
logger . info ( f " ├─ Shuffle: { ' yes ' if testing . shuffle else ' no ' } " )
logger . info ( f " ├─ Use ensemble: { ' yes ' if testing . use_ensemble else ' no ' } " )
logger . info ( f " └─ Pretrained weights: " )
logger . info ( f " ├─ Single model: { testing . pretrained_weights } " )
logger . info ( f " ├─ Ensemble model 1: { testing . ensemble_pretrained_weights1 } " )
logger . info ( f " └─ Ensemble model 2: { testing . ensemble_pretrained_weights2 } " )
self . _wandb_config = config . wandb_config
if self . _wandb_config . use_wandb :
@ -747,38 +778,62 @@ class CellSegmentator:
return Dataset ( data , transforms )
def __print_with_logging ( self , results : Dict [ str , float ] , step : int ) - > None :
def __print_with_logging ( self , results : Dict [ str , Union [ float , np . ndarray ] ] , step : int ) - > None :
"""
Print metrics in a tabular format and log to W & B .
Args :
results ( Dict [ str , float ] ) : results dictionary .
results ( Dict [ str , Union [ float , np . ndarray ] ] ) : Mapping from metric names
to either a float or a ND numpy array .
step ( int ) : epoch index .
"""
rows : list [ tuple [ str , str ] ] = [ ]
for key , val in results . items ( ) :
if isinstance ( val , np . ndarray ) :
# Convert array to string, e.g. '[0.2, 0.8, 0.5]'
val_str = np . array2string ( val , separator = ' , ' )
else :
# Format scalar with 4 decimal places
val_str = f " { val : .4f } "
rows . append ( ( key , val_str ) )
table = tabulate (
tabular_data = results . items ( ) ,
tabular_data = r ows ,
headers = [ " Metric " , " Value " ] ,
floatfmt = " .4f " ,
tablefmt = " fancy_grid "
)
print ( table , " \n " )
if self . _wandb_config . use_wandb :
wandb . log ( results , step = step )
# Keep only scalar values
scalar_results : dict [ str , float ] = { }
for key , val in results . items ( ) :
if isinstance ( val , np . ndarray ) :
continue
# Ensure float type
scalar_results [ key ] = float ( val )
wandb . log ( scalar_results , step = step )
def __run_epoch ( self ,
mode : Literal [ " train " , " valid " , " test " ] ,
epoch : Optional [ int ] = None
) - > Dict [ str , float ] :
epoch : Optional [ int ] = None ,
save_results : bool = True ,
only_masks : bool = False
) - > Dict [ str , Union [ float , np . ndarray ] ] :
"""
Execute one epoch of training , validation , or testing .
Args :
mode ( str ) : One of ' train ' , ' valid ' , or ' test ' .
epoch ( int , optional ) : Current epoch number for logging .
save_results ( bool ) : If True , the predicted masks and test metrics will be saved .
only_masks ( bool ) : If True and save_results is True , only raw predicted masks are saved ,
without visualization overlays .
Returns :
Dict [ str , float ] : Loss metrics and F1 score for valid / test .
Dict [ str , Union [ float , np . ndarray ] ] : Metrics for valid / test .
"""
# Ensure required components are available
if mode in ( " train " , " valid " ) and ( self . _optimizer is None or self . _criterion is None ) :
@ -841,18 +896,23 @@ class CellSegmentator:
)
# Collecting statistics on the batch
tp , fp , fn = self . __compute_stats (
tp , fp , fn , tp_masks , fp_masks , fn_masks = self . __compute_stats (
predicted_masks = preds ,
ground_truth_masks = labels_post , # type: ignore
iou_threshold = 0.5
iou_threshold = 0.5 ,
return_error_masks = ( mode == " test " )
)
all_tp . append ( tp )
all_fp . append ( fp )
all_fn . append ( fn )
if mode == " test " :
if mode == " test " and save_results is True :
self . __save_prediction_masks (
batch , preds , batch_counter
sample = batch ,
predicted_mask = preds ,
start_index = batch_counter ,
only_masks = only_masks ,
masks = ( tp_masks , fp_masks , fn_masks ) # type: ignore
)
# Backpropagation and optimizer step in training
@ -871,7 +931,9 @@ class CellSegmentator:
if self . _criterion is not None :
# Collect loss metrics
epoch_metrics = { f " { mode } _ { name } " : value for name , value in self . _criterion . get_loss_metrics ( ) . items ( ) }
epoch_metrics : Dict [ str , Union [ float , np . ndarray ] ] = {
f " { mode } _ { name } " : value for name , value in self . _criterion . get_loss_metrics ( ) . items ( )
}
# Reset internal loss metrics accumulator
self . _criterion . reset_metrics ( )
else :
@ -884,13 +946,13 @@ class CellSegmentator:
fp_array = np . vstack ( all_fp )
fn_array = np . vstack ( all_fn )
epoch_metrics [ f " { mode } _f1_score " ] = self . __compute_f1_metric ( # type: ignore
epoch_metrics [ f " { mode } _f1_score " ] = self . __compute_f1_metric (
tp_array , fp_array , fn_array , reduction = " micro "
)
epoch_metrics [ f " { mode } _f1_score_iw " ] = self . __compute_f1_metric ( # type: ignore
epoch_metrics [ f " { mode } _f1_score_iw " ] = self . __compute_f1_metric (
tp_array , fp_array , fn_array , reduction = " imagewise "
)
epoch_metrics [ f " { mode } _mAP " ] = self . __compute_average_precision_metric ( # type: ignore
epoch_metrics [ f " { mode } _mAP " ] = self . __compute_average_precision_metric (
tp_array , fp_array , fn_array , reduction = " macro "
)
@ -976,8 +1038,10 @@ class CellSegmentator:
self ,
predicted_masks : np . ndarray ,
ground_truth_masks : np . ndarray ,
iou_threshold : float = 0.5
) - > Tuple [ np . ndarray , np . ndarray , np . ndarray ] :
iou_threshold : float = 0.5 ,
return_error_masks : bool = False
) - > Tuple [ np . ndarray , np . ndarray , np . ndarray ,
Optional [ np . ndarray ] , Optional [ np . ndarray ] , Optional [ np . ndarray ] ] :
"""
Compute batch - wise true positives , false positives , and false negatives
for instance segmentation , using a configurable IoU threshold .
@ -987,23 +1051,33 @@ class CellSegmentator:
ground_truth_masks ( np . ndarray ) : Ground truth instance masks of shape ( B , C , H , W ) .
iou_threshold ( float ) : Intersection - over - Union threshold for matching predictions
to ground truths ( default : 0.5 ) .
return_error_masks ( bool ) : Whether to also return binary error masks .
Returns :
Tuple [ np . ndarray , np . ndarray , np . ndarray ] :
Tuple ( np . ndarray , np . ndarray , np . ndarray ,
np . ndarray | None , np . ndarray | None , np . ndarray | None ) :
- tp : True positives per batch and class , shape ( B , C )
- fp : False positives per batch and class , shape ( B , C )
- fn : False negatives per batch and class , shape ( B , C )
- tp_maks : True positives mask per batch and class , shape ( B , C , H , W )
- fp_maks : False positives mask per batch and class , shape ( B , C , H , W )
- fn_maks : False negatives mask per batch and class , shape ( B , C , H , W )
"""
stats = compute_batch_segmentation_tp_fp_fn (
batch_ground_truth = ground_truth_masks ,
batch_prediction = predicted_masks ,
iou_threshold = iou_threshold ,
return_error_masks = return_error_masks ,
remove_boundary_objects = True
)
tp = stats [ " tp " ]
fp = stats [ " fp " ]
fn = stats [ " fn " ]
return tp , fp , fn
tp_mask = stats [ " tp_mask " ] if return_error_masks else None
fp_mask = stats [ " fp_mask " ] if return_error_masks else None
fn_mask = stats [ " fn_mask " ] if return_error_masks else None
return tp , fp , fn , tp_mask , fp_mask , fn_mask
def __compute_f1_metric (
@ -1011,7 +1085,7 @@ class CellSegmentator:
true_positives : np . ndarray ,
false_positives : np . ndarray ,
false_negatives : np . ndarray ,
reduction : Literal [ " micro " , " macro " , " weighted " , " imagewise " , " none" ] = " micro "
reduction : Literal [ " micro " , " macro " , " weighted " , " imagewise " , " per_class" , " none" ] = " micro "
) - > Union [ float , np . ndarray ] :
"""
Compute F1 - score from batch - wise TP / FP / FN using various aggregation schemes .
@ -1023,8 +1097,9 @@ class CellSegmentator:
reduction :
- ' none ' : return F1 for each sample , class → shape ( batch_size , num_classes )
- ' micro ' : global F1 over all samples & classes
- ' imagewise ' : F1 per sample ( summing over classes ) , then average over samples
- ' macro ' : average class - wise F1 ( classes summed over batch )
- ' imagewise ' : F1 per sample ( summing over classes ) , then average over samples
- ' per_class ' : F1 per class ( summing over batch ) , return vector of shape ( num_classes , )
- ' weighted ' : class - wise F1 weighted by support ( TP + FN )
Returns :
float for reductions ' micro ' , ' imagewise ' , ' macro ' , ' weighted ' ;
@ -1040,7 +1115,11 @@ class CellSegmentator:
tp_val = int ( true_positives [ i , c ] )
fp_val = int ( false_positives [ i , c ] )
fn_val = int ( false_negatives [ i , c ] )
_ , _ , f1_val = compute_f1_score ( tp_val , fp_val , fn_val )
_ , _ , f1_val = compute_f1_score (
tp_val ,
fp_val ,
fn_val
)
f1_matrix [ i , c ] = f1_val
return f1_matrix
@ -1049,7 +1128,11 @@ class CellSegmentator:
tp_total = int ( true_positives . sum ( ) )
fp_total = int ( false_positives . sum ( ) )
fn_total = int ( false_negatives . sum ( ) )
_ , _ , f1_global = compute_f1_score ( tp_total , fp_total , fn_total )
_ , _ , f1_global = compute_f1_score (
tp_total ,
fp_total ,
fn_total
)
return f1_global
# 3) Imagewise: compute per-sample F1 (sum over classes), then average
@ -1059,16 +1142,31 @@ class CellSegmentator:
tp_i = int ( true_positives [ i ] . sum ( ) )
fp_i = int ( false_positives [ i ] . sum ( ) )
fn_i = int ( false_negatives [ i ] . sum ( ) )
_ , _ , f1_i = compute_f1_score ( tp_i , fp_i , fn_i )
_ , _ , f1_i = compute_f1_score (
tp_i ,
fp_i ,
fn_i
)
f1_per_image [ i ] = f1_i
return float ( f1_per_image . mean ( ) )
# For macro/weighted, first aggregate per class across the batch
# Aggregate per class across the batch for per_class, macro, weighted
tp_per_class = true_positives . sum ( axis = 0 ) . astype ( int ) # shape (num_classes,)
fp_per_class = false_positives . sum ( axis = 0 ) . astype ( int )
fn_per_class = false_negatives . sum ( axis = 0 ) . astype ( int )
# 4) Macro: average F1 across classes equally
# 4) Per-class: compute F1 for each class and return vector
if reduction == " per_class " :
f1_per_class = np . zeros ( num_classes , dtype = float )
for c in range ( num_classes ) :
_ , _ , f1_per_class [ c ] = compute_f1_score (
tp_per_class [ c ] ,
fp_per_class [ c ] ,
fn_per_class [ c ]
)
return f1_per_class
# 5) Macro: average F1 across classes equally
if reduction == " macro " :
f1_per_class = np . zeros ( num_classes , dtype = float )
for c in range ( num_classes ) :
@ -1080,7 +1178,7 @@ class CellSegmentator:
f1_per_class [ c ] = f1_c
return float ( f1_per_class . mean ( ) )
# 5 ) Weighted: class-wise F1 weighted by support = TP + FN
# 6 ) Weighted: class-wise F1 weighted by support = TP + FN
if reduction == " weighted " :
f1_per_class = np . zeros ( num_classes , dtype = float )
support = np . zeros ( num_classes , dtype = float )
@ -1088,7 +1186,11 @@ class CellSegmentator:
tp_c = tp_per_class [ c ]
fp_c = fp_per_class [ c ]
fn_c = fn_per_class [ c ]
_ , _ , f1_c = compute_f1_score ( tp_c , fp_c , fn_c )
_ , _ , f1_c = compute_f1_score (
tp_c ,
fp_c ,
fn_c
)
f1_per_class [ c ] = f1_c
support [ c ] = tp_c + fn_c
total_support = support . sum ( )
@ -1106,7 +1208,7 @@ class CellSegmentator:
true_positives : np . ndarray ,
false_positives : np . ndarray ,
false_negatives : np . ndarray ,
reduction : Literal [ " micro " , " macro " , " weighted " , " imagewise " , " none " ] = " micro "
reduction : Literal [ " micro " , " macro " , " weighted " , " imagewise " , ' per_class ' , " none " ] = " micro "
) - > Union [ float , np . ndarray ] :
"""
Compute Average Precision ( AP ) from batch - wise TP / FP / FN using various aggregation schemes .
@ -1121,8 +1223,9 @@ class CellSegmentator:
reduction :
- ' none ' : return AP for each sample and class → shape ( batch_size , num_classes )
- ' micro ' : global AP over all samples & classes
- ' imagewise ' : AP per sample ( summing stats over classes ) , then average over batch
- ' macro ' : average class - wise AP ( each class summed over batch )
- ' imagewise ' : AP per sample ( summing stats over classes ) , then average over batch
- ' per_class ' : AP per class ( summing over batch ) , return vector of shape ( num_classes , )
- ' weighted ' : class - wise AP weighted by support ( TP + FN )
Returns :
@ -1139,7 +1242,11 @@ class CellSegmentator:
tp_val = int ( true_positives [ i , c ] )
fp_val = int ( false_positives [ i , c ] )
fn_val = int ( false_negatives [ i , c ] )
ap_val = compute_average_precision_score ( tp_val , fp_val , fn_val )
ap_val = compute_average_precision_score (
tp_val ,
fp_val ,
fn_val
)
ap_matrix [ i , c ] = ap_val
return ap_matrix
@ -1148,7 +1255,11 @@ class CellSegmentator:
tp_total = int ( true_positives . sum ( ) )
fp_total = int ( false_positives . sum ( ) )
fn_total = int ( false_negatives . sum ( ) )
return compute_average_precision_score ( tp_total , fp_total , fn_total )
return compute_average_precision_score (
tp_total ,
fp_total ,
fn_total
)
# 3) Imagewise: compute per-sample AP (sum over classes), then mean
if reduction == " imagewise " :
@ -1157,7 +1268,11 @@ class CellSegmentator:
tp_i = int ( true_positives [ i ] . sum ( ) )
fp_i = int ( false_positives [ i ] . sum ( ) )
fn_i = int ( false_negatives [ i ] . sum ( ) )
ap_per_image [ i ] = compute_average_precision_score ( tp_i , fp_i , fn_i )
ap_per_image [ i ] = compute_average_precision_score (
tp_i ,
fp_i ,
fn_i
)
return float ( ap_per_image . mean ( ) )
# For macro and weighted: first aggregate per class across batch
@ -1165,7 +1280,18 @@ class CellSegmentator:
fp_per_class = false_positives . sum ( axis = 0 ) . astype ( int )
fn_per_class = false_negatives . sum ( axis = 0 ) . astype ( int )
# 4) Macro: average AP across classes equally
# 4) Per-class: compute F1 for each class and return vector
if reduction == " per_class " :
ap_per_class = np . zeros ( num_classes , dtype = float )
for c in range ( num_classes ) :
ap_per_class [ c ] = compute_average_precision_score (
tp_per_class [ c ] ,
fp_per_class [ c ] ,
fn_per_class [ c ]
)
return ap_per_class
# 5) Macro: average AP across classes equally
if reduction == " macro " :
ap_per_class = np . zeros ( num_classes , dtype = float )
for c in range ( num_classes ) :
@ -1176,7 +1302,7 @@ class CellSegmentator:
)
return float ( ap_per_class . mean ( ) )
# 5 ) Weighted: class-wise AP weighted by support = TP + FN
# 6 ) Weighted: class-wise AP weighted by support = TP + FN
if reduction == " weighted " :
ap_per_class = np . zeros ( num_classes , dtype = float )
support = np . zeros ( num_classes , dtype = float )
@ -1184,7 +1310,11 @@ class CellSegmentator:
tp_c = tp_per_class [ c ]
fp_c = fp_per_class [ c ]
fn_c = fn_per_class [ c ]
ap_per_class [ c ] = compute_average_precision_score ( tp_c , fp_c , fn_c )
ap_per_class [ c ] = compute_average_precision_score (
tp_c ,
fp_c ,
fn_c
)
support [ c ] = tp_c + fn_c
total_support = support . sum ( )
if total_support == 0 :
@ -1215,87 +1345,159 @@ class CellSegmentator:
sample : Dict [ str , Any ] ,
predicted_mask : Union [ np . ndarray , torch . Tensor ] ,
start_index : int = 0 ,
only_masks : bool = False ,
masks : Optional [ Tuple [ np . ndarray , np . ndarray , np . ndarray ] ] = None
) - > None :
"""
Save multi - channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders .
Save multi - channel predicted masks as TIFFs and
corresponding visualizations as PNGs in separate folders .
Args :
sample ( Dict [ str , Any ] ) : Batch sample from MONAI LoadImaged ( contains ' image ' , optional ' mask ' , and ' image_meta_dict ' ) .
sample ( Dict [ str , Any ] ) : Batch sample from MONAI
LoadImaged ( contains ' image ' , optional ' mask ' , and ' image_meta_dict ' ) .
predicted_mask ( np . ndarray or torch . Tensor ) : Array of shape ( C , H , W ) or ( B , C , H , W ) .
start_index ( int ) : Starting index for naming when metadata is missing .
only_masks ( bool ) : If True , save only the raw predicted mask TIFFs and skip PNG visualizations .
masks ( Tuple [ np . ndarray , np . ndarray , np . ndarray ] | None ) :
A tuple ( tp_masks , fp_masks , fn_masks ) , each of shape ( B , C , H , W ) . Defaults to None .
"""
# Determine base paths
# Base directories (created once per call)
base_output_dir = self . _dataset_setup . common . predictions_dir
masks_dir = base_output_dir
plots_dir = os . path . join ( base_output_dir , " plots " )
evaluate_dir = os . path . join ( plots_dir , " evaluate " )
os . makedirs ( masks_dir , exist_ok = True )
os . makedirs ( plots_dir , exist_ok = True )
os . makedirs ( evaluate_dir , exist_ok = True )
# Extract image (C, H, W) or batch of images (B, C, H, W), and metadata
image_obj = sample . get ( " image " ) # Expected shape: (C, H, W) or (B, C, H, W)
mask_obj = sample . get ( " mask " ) # Expected shape: (C, H, W) or (B, C, H, W)
image_meta = sample . get ( " image_meta_dict " )
# Convert tensors to numpy
# Convert tensors to numpy if necessary
def to_numpy ( x : Union [ np . ndarray , torch . Tensor ] ) - > np . ndarray :
if isinstance ( x , torch . Tensor ) :
return x . cpu ( ) . numpy ( )
return x
image_array = to_numpy ( image_obj ) if image_obj is not None else None
mask_array = to_numpy ( mask_obj ) if mask_obj is not None else None
pred_array = to_numpy ( predicted_mask )
# Handle batch dimension: (B, C, H, W)
if pred_array . ndim == 4 :
for idx in range ( pred_array . shape [ 0 ] ) :
batch_sample : Dict [ str , Any ] = { }
if image_array is not None and image_array . ndim == 4 :
batch_sample [ " image " ] = image_array [ idx ]
if isinstance ( image_meta , dict ) and " filename_or_obj " in image_meta :
batch_sample [ " image_meta_dict " ] = image_meta [ " filename_or_obj " ] [ idx ]
if mask_array is not None and mask_array . ndim == 4 :
batch_sample [ " mask " ] = mask_array [ idx ]
self . __save_prediction_masks (
batch_sample ,
pred_array [ idx ] ,
start_index = start_index + idx
)
return
return x . cpu ( ) . numpy ( ) if isinstance ( x , torch . Tensor ) else x
pred_array = to_numpy ( predicted_mask ) . astype ( np . uint16 )
# Handle batch dimension
for idx in range ( pred_array . shape [ 0 ] ) :
batch_sample : Dict [ str , Any ] = { }
# copy per-sample image and meta
img = to_numpy ( sample [ " image " ] )
if img . ndim == 4 :
batch_sample [ " image " ] = img [ idx ]
if " mask " in sample :
msk = to_numpy ( sample [ " mask " ] ) . astype ( np . uint16 )
if msk . ndim == 4 :
batch_sample [ " mask " ] = msk [ idx ]
image_meta = sample . get ( " image_meta_dict " )
if isinstance ( image_meta , dict ) and " filename_or_obj " in image_meta :
fname = image_meta [ " filename_or_obj " ] [ idx ]
batch_sample [ " image_name " ] = fname
single_masks = (
( masks [ 0 ] [ idx ] , masks [ 1 ] [ idx ] , masks [ 2 ] [ idx ] ) if masks is not None else None
)
self . __save_single_prediction_mask (
sample = batch_sample ,
pred_array = pred_array [ idx ] ,
start_index = start_index + idx ,
masks_dir = masks_dir ,
plots_dir = plots_dir ,
evaluate_dir = evaluate_dir ,
only_masks = only_masks ,
masks = single_masks ,
)
def __save_single_prediction_mask (
self ,
sample : Dict [ str , Any ] ,
pred_array : np . ndarray ,
start_index : int ,
masks_dir : str ,
plots_dir : str ,
evaluate_dir : str ,
only_masks : bool = False ,
masks : Optional [ Tuple [ np . ndarray , np . ndarray , np . ndarray ] ] = None
) - > None :
"""
Save a single sample ' s predicted mask and optional TP/FP/FN masks and visualizations.
Assumes output directories already exist .
# Determine base filename
Args :
sample ( Dict [ str , Any ] ) : Dictionary containing ' image ' , ' mask ' ,
and optional ' image_meta_dict ' for metadata .
pred_array ( np . ndarray ) : Predicted mask array of shape ( C , H , W ) .
start_index ( int ) : Base index for generating filenames when metadata is missing .
masks_dir ( str ) : Directory for saving TIFF masks .
plots_dir ( str ) : Directory for saving PNG visualizations .
evaluate_dir ( str ) : Directory for saving PNG visualizations of evaluation results .
only_masks ( bool ) : If True , saves only TIFF mask files ; skips PNG plots .
masks ( Tuple [ np . ndarray , np . ndarray , np . ndarray ] , optional ) : A tuple of
true - positive , false - positive , and false - negative mask arrays ,
each of shape ( C , H , W ) . Defaults to None .
"""
if pred_array . ndim == 2 :
pred_array = np . expand_dims ( pred_array , axis = 0 )
elif pred_array . ndim != 3 :
raise ValueError (
f " Unsupported predicted_mask dimensions: { pred_array . ndim } . "
" Expected 2D (H,W) or 3D (C,H,W). "
)
# Handle image array if present
image_array : np . ndarray = sample [ " image " ]
if image_array . ndim == 2 :
image_array = np . expand_dims ( image_array , axis = 0 )
elif image_array . ndim != 3 :
raise ValueError (
f " Unsupported image dimensions: { image_array . ndim } . "
" Expected 2D (H,W) or 3D (C,H,W). "
)
true_mask_array : Optional [ np . ndarray ] = sample . get ( " mask " )
if isinstance ( true_mask_array , np . ndarray ) :
if true_mask_array . ndim == 2 :
true_mask_array = np . expand_dims ( true_mask_array , axis = 0 )
elif true_mask_array . ndim != 3 :
raise ValueError (
f " Unsupported true_mask_array dimensions: { true_mask_array . ndim } . "
" Expected 2D (H,W) or 3D (C,H,W). "
)
# Determine filename base
image_meta = sample . get ( " image_name " )
if isinstance ( image_meta , ( str , os . PathLike ) ) :
base_name = os . path . splitext ( os . path . basename ( image_meta ) ) [ 0 ]
else :
# Use provided start_index when metadata missing
base_name = f " prediction_ { start_index : 04d } "
# Save mask TIFF (16-bit)
mask_filename = f " { base_name } _mask.tif "
mask_path = os . path . join ( masks_dir , mask_filename )
# Save main mask TIFF
mask_path = os . path . join ( masks_dir , f " { base_name } _mask.tif " )
tiff . imwrite ( mask_path , pred_array . astype ( np . uint16 ) , compression = " zlib " )
# Now pred_array shape is (C, H, W)
num_channels = pred_array . shape [ 0 ]
for channel_idx in range ( num_channels ) :
channel_mask = pred_array [ channel_idx ]
# File names
plot_filename = f " { base_name } _ch { channel_idx : 01d } .png "
plot_path = os . path . join ( plots_dir , plot_filename )
# Extract corresponding true mask channel if exists
true_mask = None
if mask_array is not None and mask_array . ndim == 3 :
true_mask = mask_array [ channel_idx ]
if only_masks :
return
# Generate and save visualization
# Save channel-wise plots
num_channels = pred_array . shape [ 0 ]
for ch in range ( num_channels ) :
true_ch = true_mask_array [ ch ] if true_mask_array is not None else None
self . __plot_mask (
file_path = plot_path ,
image_data = image_array , # type: ignore
predicted_mask = channel_mask ,
true_mask = true_mask ,
file_path = os . path . join ( plots_dir , f " { base_name } _ch { ch } .png " ) ,
image_data = image_array ,
predicted_mask = pred_array [ ch ] ,
true_mask = true_ch ,
)
if masks is not None and true_ch is not None :
self . __save_mask_comparison_visuals (
gt = true_ch ,
pred = pred_array [ ch ] ,
tp_mask = masks [ 0 ] [ ch ] ,
fp_mask = masks [ 1 ] [ ch ] ,
fn_mask = masks [ 2 ] [ ch ] ,
file_path = os . path . join ( evaluate_dir , f " { base_name } _ch { ch } .png " )
)
def __plot_mask (
@ -1307,6 +1509,16 @@ class CellSegmentator:
) - > None :
"""
Create and save grid visualization : 1 x3 if no true mask , or 2 x3 if true mask provided .
Args :
file_path ( str ) : Path where the visualization image will be saved .
image_data ( np . ndarray ) : The original input image array , expected shape ( C , H , W ) .
predicted_mask ( np . ndarray ) : The predicted mask array , shape ( H , W ) ,
depending on the task .
true_mask ( Optional [ np . ndarray ] , optional ) : The ground - truth mask array .
If provided , an additional row with true mask and overlap visualization
will be added to the plot . Default is None .
"""
img = np . moveaxis ( image_data , 0 , - 1 ) if image_data . ndim == 3 else image_data
@ -1317,7 +1529,7 @@ class CellSegmentator:
( ' Original Image ' , ' Predicted Mask ' , ' Predicted Contours ' ) )
else :
fig , axs = plt . subplots ( 2 , 3 , figsize = ( 15 , 10 ) )
plt . subplots_adjust ( wspace = 0.02 , hspace = 0. 1 )
plt . subplots_adjust ( wspace = 0.02 , hspace = 0. 02 )
# row 0: predicted
self . __plot_panels ( axs [ 0 ] , img , predicted_mask , ' red ' ,
( ' Original Image ' , ' Predicted Mask ' , ' Predicted Contours ' ) )
@ -1370,6 +1582,69 @@ class CellSegmentator:
ax2 . contour ( boundaries , colors = contour_color , linewidths = 0.5 )
ax2 . set_title ( titles [ 2 ] )
ax2 . axis ( ' off ' )
def __save_mask_comparison_visuals (
self ,
gt : np . ndarray ,
pred : np . ndarray ,
tp_mask : np . ndarray ,
fp_mask : np . ndarray ,
fn_mask : np . ndarray ,
file_path : str
) - > None :
"""
Creates and saves a 1 x3 subplot figure showing :
1 ) True mask with boundaries
2 ) Predicted mask without boundaries
3 ) Overlay mask combining FP ( R ) , TP ( G ) , FN ( B )
Args :
gt ( np . ndarray ) : Ground truth mask ( H , W ) .
pred ( np . ndarray ) : Predicted mask ( H , W ) .
tp_mask ( np . ndarray ) : True positive mask ( H , W ) .
fp_mask ( np . ndarray ) : False positive mask ( H , W ) .
fn_mask ( np . ndarray ) : False negative mask ( H , W ) .
file_path ( str ) : Path where the visualization image will be saved .
"""
# Prepare overlay mask
overlap_mask = np . zeros ( ( * gt . shape [ : 2 ] , 3 ) , dtype = np . uint8 )
overlap_mask [ . . . , 0 ] = np . where ( fp_mask , 255 , 0 )
overlap_mask [ . . . , 1 ] = np . where ( tp_mask , 255 , 0 )
overlap_mask [ . . . , 2 ] = np . where ( fn_mask , 255 , 0 )
# Set up figure
fig , axes = plt . subplots ( 1 , 3 , figsize = ( 15 , 5 ) ,
gridspec_kw = { ' width_ratios ' : [ 1 , 1 , 1 ] } )
plt . subplots_adjust ( wspace = 0.02 , hspace = 0.0 ,
left = 0.05 , right = 0.95 , top = 0.95 , bottom = 0.05 )
# Colormap for instances
num_instances = max ( np . max ( gt ) , np . max ( pred ) )
cmap = plt . get_cmap ( " gist_ncar " )
colors = [ cmap ( i / num_instances ) for i in range ( num_instances ) ]
cmap = mcolors . ListedColormap ( colors )
# Plot true mask
axes [ 0 ] . imshow ( gt , cmap = cmap )
axes [ 0 ] . contour ( find_boundaries ( gt , mode = " thick " ) , colors = " black " , linewidths = 0.5 )
axes [ 0 ] . set_title ( " True Mask " )
axes [ 0 ] . axis ( " off " )
# Plot predicted mask
axes [ 1 ] . imshow ( pred , cmap = cmap )
axes [ 1 ] . contour ( find_boundaries ( pred , mode = " thick " ) , colors = " black " , linewidths = 0.5 )
axes [ 1 ] . set_title ( " Predicted Mask " )
axes [ 1 ] . axis ( " off " )
# Plot overlay
axes [ 2 ] . imshow ( overlap_mask )
axes [ 2 ] . set_title ( " Overlay Mask (R-FP; G-TP; B-FN) " )
axes [ 2 ] . axis ( " off " )
# Save
plt . savefig ( file_path , bbox_inches = " tight " , dpi = 300 )
plt . close ( )
def __compute_flows_from_masks (
@ -1522,7 +1797,7 @@ class CellSegmentator:
flow_output = np . zeros ( ( 2 , height , width ) , dtype = np . float32 )
ys_np = y . cpu ( ) . numpy ( ) - 1
xs_np = x . cpu ( ) . numpy ( ) - 1
flow_output [ : , ys_np , xs_np ] = mu
flow_output [ : , ys_np , xs_np ] = mu . reshape ( 2 , - 1 )
flows [ 2 * channel : 2 * channel + 2 ] = flow_output
return flows