@@ -19,15 +19,14 @@ from torch.utils.data import DataLoader
 import fastremap
 import fill_voids
-from skimage import morphology
+# from skimage import morphology
 from skimage.segmentation import find_boundaries
 from scipy.special import expit
 from scipy.ndimage import mean, find_objects
 from monai.data.dataset import Dataset
-from monai.transforms import *  # type: ignore
+from monai.transforms.compose import Compose
 from monai.inferers.utils import sliding_window_inference
 from monai.metrics.cumulative_average import CumulativeAverage
 import matplotlib.pyplot as plt
 import matplotlib.colors as mcolors
@@ -42,16 +41,16 @@ from itertools import chain
 from pprint import pformat
 from tabulate import tabulate
-from typing import Any, Dict, Literal, Optional, Tuple, List, Union
+from typing import Any, Literal
 from tqdm import tqdm
 import wandb

 from config import Config

-from core.models import *
-from core.losses import *
-from core.optimizers import *
-from core.schedulers import *
+from core.models import ModelRegistry
+from core.losses import CriterionRegistry
+from core.optimizers import OptimizerRegistry
+from core.schedulers import SchedulerRegistry
 from core.utils import (
     compute_batch_segmentation_tp_fp_fn,
     compute_f1_score,
@@ -78,30 +77,30 @@ class CellSegmentator:
             else None
         )

-        self._train_dataloader: Optional[DataLoader] = None
-        self._valid_dataloader: Optional[DataLoader] = None
-        self._test_dataloader: Optional[DataLoader] = None
-        self._predict_dataloader: Optional[DataLoader] = None
+        self._train_dataloader: DataLoader | None = None
+        self._valid_dataloader: DataLoader | None = None
+        self._test_dataloader: DataLoader | None = None
+        self._predict_dataloader: DataLoader | None = None

         self._best_weights = None

     def create_dataloaders(
         self,
-        train_transforms: Optional[Compose] = None,
-        valid_transforms: Optional[Compose] = None,
-        test_transforms: Optional[Compose] = None,
-        predict_transforms: Optional[Compose] = None
+        train_transforms: Compose | None = None,
+        valid_transforms: Compose | None = None,
+        test_transforms: Compose | None = None,
+        predict_transforms: Compose | None = None
     ) -> None:
         """
         Creates train, validation, test, and prediction dataloaders based on dataset configuration
         and provided transforms.

         Args:
-            train_transforms (Optional[Compose]): Transformations for training data.
-            valid_transforms (Optional[Compose]): Transformations for validation data.
-            test_transforms (Optional[Compose]): Transformations for testing data.
-            predict_transforms (Optional[Compose]): Transformations for prediction data.
+            train_transforms (Compose | None): Transformations for training data.
+            valid_transforms (Compose | None): Transformations for validation data.
+            test_transforms (Compose | None): Transformations for testing data.
+            predict_transforms (Compose | None): Transformations for prediction data.

         Raises:
             ValueError: If required transforms are missing.
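Most of this diff is a mechanical migration from typing.Optional/Union/Dict/Tuple to the PEP 604/585 spellings. For readers less familiar with the newer syntax, the two forms are interchangeable at runtime on Python 3.10+, as this small self-contained check illustrates:

    from typing import Optional, Union

    # PEP 604 unions compare equal to their typing-module spellings (Python 3.10+).
    assert Optional[int] == int | None
    assert Union[int, float] == int | float
    # They are also accepted directly by isinstance().
    assert isinstance(None, int | None)
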
@@ -257,7 +256,7 @@ class CellSegmentator:
     def print_data_info(
         self,
         loader_type: Literal["train", "valid", "test", "predict"],
-        index: Optional[int] = None
+        index: int | None = None
     ) -> None:
         """
         Prints statistics for a single sample from the specified dataloader.
@@ -267,7 +266,7 @@ class CellSegmentator:
             index: The sample index; if None, a random index is selected.
         """
         # Retrieve the dataloader attribute, e.g., self._train_dataloader
-        loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None)
+        loader: DataLoader | None = getattr(self, f"_{loader_type}_dataloader", None)
         if loader is None:
             logger.error(f"Dataloader '{loader_type}' is not initialized.")
             return
@@ -326,8 +325,8 @@ class CellSegmentator:
         lines.append("=" * 40)

         # Output via logger
-        for l in lines:
-            logger.info(l)
+        for line in lines:
+            logger.info(line)

     def train(self, save_results: bool = True, only_masks: bool = False) -> None:
@@ -661,16 +660,16 @@ class CellSegmentator:
         logger.info(f"├─ Validation frequency: {training.val_freq}")

         if training.is_split:
-            logger.info(f"├─ Using pre-split directories:")
+            logger.info("├─ Using pre-split directories:")
             logger.info(f"│   ├─ Train dir: {training.pre_split.train_dir}")
             logger.info(f"│   ├─ Valid dir: {training.pre_split.valid_dir}")
             logger.info(f"│   └─ Test dir: {training.pre_split.test_dir}")
         else:
-            logger.info(f"├─ Using unified dataset with splits:")
+            logger.info("├─ Using unified dataset with splits:")
             logger.info(f"│   ├─ All data dir: {training.split.all_data_dir}")
             logger.info(f"│   └─ Shuffle: {'yes' if training.split.shuffle else 'no'}")

-        logger.info(f"└─ Dataset split:")
+        logger.info("└─ Dataset split:")
         logger.info(f"    ├─ Train size: {training.train_size}, offset: {training.train_offset}")
         logger.info(f"    ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}")
         logger.info(f"    └─ Test size: {training.test_size}, offset: {training.test_offset}")
@@ -703,12 +702,12 @@ class CellSegmentator:
         logger.info("===================================")

-    def __set_seed(self, seed: Optional[int]) -> None:
+    def __set_seed(self, seed: int | None) -> None:
         """
         Sets the random seed for reproducibility across Python, NumPy, and PyTorch.

         Args:
-            seed (Optional[int]): Seed value. If None, no seeding is performed.
+            seed (int | None): Seed value. If None, no seeding is performed.
         """
         if seed is not None:
             random.seed(seed)
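For orientation, seeding all three libraries named in the docstring usually amounts to the short sketch below (standalone and illustrative; the CUDA call is included for completeness and may or may not appear in the actual method body, which this hunk truncates):

    import random
    import numpy as np
    import torch

    def set_seed(seed: int | None) -> None:
        # Seed Python's, NumPy's and PyTorch's RNGs; skip entirely when seed is None.
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
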
@@ -724,9 +723,9 @@ class CellSegmentator:
     def __get_dataset(
         self,
         images_dir: str,
-        masks_dir: Optional[str],
+        masks_dir: str | None,
         transforms: Compose,
-        size: Union[int, float],
+        size: int | float,
         offset: int,
         shuffle: bool
     ) -> Dataset:
@@ -735,9 +734,9 @@ class CellSegmentator:
         Args:
             images_dir (str): Path to directory or glob pattern for input images.
-            masks_dir (Optional[str]): Path to directory or glob pattern for masks.
+            masks_dir (str | None): Path to directory or glob pattern for masks.
             transforms (Compose): Transformations to apply to each image or pair.
-            size (Union[int, float]): Either an integer or a fraction of the dataset.
+            size (int | float): Either an integer or a fraction of the dataset.
             offset (int): Number of images to skip from the start.
             shuffle (bool): Whether to shuffle the dataset before slicing.
@@ -806,12 +805,12 @@ class CellSegmentator:
         return Dataset(data, transforms)

-    def __print_with_logging(self, metrics: Dict[str, Union[float, np.ndarray]], step: int) -> None:
+    def __print_with_logging(self, metrics: dict[str, float | np.ndarray], step: int) -> None:
         """
         Print metrics in a tabular format and log to W&B.

         Args:
-            metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
+            metrics (dict(str, float | np.ndarray)): Mapping from metric names
                 to either a float or an N-D numpy array.
             step (int): Epoch index.
         """
@@ -846,14 +845,14 @@ class CellSegmentator:
     def __save_metrics_to_csv(
         self,
-        metrics: Dict[str, Union[float, np.ndarray]],
+        metrics: dict[str, float | np.ndarray],
         output_path: str
     ) -> None:
         """
         Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'.

         Args:
-            metrics (Dict[str, Union[float, np.ndarray]]):
+            metrics (dict(str, float | np.ndarray)):
                 Mapping from metric names to scalar values or numpy arrays.
             output_path (str):
                 Path to the output CSV file.
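Again the implementation is outside the hunk; writing such a 'Metric'/'Value' file while flattening array-valued entries one row per element could be sketched as follows (illustrative helper, not the class method itself):

    import csv
    import numpy as np

    def save_metrics_to_csv(metrics: dict[str, float | np.ndarray], output_path: str) -> None:
        with open(output_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["Metric", "Value"])
            for name, value in metrics.items():
                if isinstance(value, np.ndarray):
                    # e.g. per-class scores: one row per element
                    for i, v in enumerate(value.ravel()):
                        writer.writerow([f"{name}_{i}", float(v)])
                else:
                    writer.writerow([name, float(value)])
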
@@ -874,22 +873,22 @@ class CellSegmentator:
     def __run_epoch(self,
                     mode: Literal["train", "valid", "test"],
-                    epoch: Optional[int] = None,
+                    epoch: int | None = None,
                     save_results: bool = True,
                     only_masks: bool = False
-    ) -> Dict[str, Union[float, np.ndarray]]:
+    ) -> dict[str, float | np.ndarray]:
         """
         Execute one epoch of training, validation, or testing.

         Args:
             mode (str): One of 'train', 'valid', or 'test'.
-            epoch (int, optional): Current epoch number for logging.
+            epoch (int | None): Current epoch number for logging.
             save_results (bool): If True, the predicted masks and test metrics will be saved.
             only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
                 without visualization overlays.

         Returns:
-            Dict[str, Union[float, np.ndarray]]: Metrics for valid/test.
+            dict(str, float | np.ndarray): Metrics for valid/test.
         """
         # Ensure required components are available
         if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
@@ -988,7 +987,7 @@ class CellSegmentator:
         if self._criterion is not None:
             # Collect loss metrics
-            epoch_metrics: Dict[str, Union[float, np.ndarray]] = {
+            epoch_metrics: dict[str, float | np.ndarray] = {
                 f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()
             }
             # Reset internal loss metrics accumulator
@@ -1051,17 +1050,17 @@ class CellSegmentator:
     def __post_process_predictions(
         self,
         raw_outputs: torch.Tensor,
-        ground_truth: Optional[torch.Tensor] = None
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        ground_truth: torch.Tensor | None = None
+    ) -> tuple[np.ndarray, np.ndarray | None]:
         """
         Post-process raw network outputs to extract instance segmentation masks.

         Args:
             raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W).
-            ground_truth (torch.Tensor): Ground truth masks of shape (B, C, H, W).
+            ground_truth (torch.Tensor | None): Ground truth masks of shape (B, C, H, W).

         Returns:
-            Tuple[np.ndarray, Optional[np.ndarray]]:
+            tuple(np.ndarray, np.ndarray | None):
                 - instance_masks: Instance-wise masks array of shape (B, C, H, W).
                 - labels_np: Converted ground truth of shape (B, C, H, W) or None if
                   ground_truth was not provided.
@@ -1097,8 +1096,8 @@ class CellSegmentator:
         ground_truth_masks: np.ndarray,
         iou_threshold: float = 0.5,
         return_error_masks: bool = False
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray,
-               Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray,
+               np.ndarray | None, np.ndarray | None, np.ndarray | None]:
         """
         Compute batch-wise true positives, false positives, and false negatives
         for instance segmentation, using a configurable IoU threshold.
@@ -1111,7 +1110,7 @@ class CellSegmentator:
             return_error_masks (bool): Whether to also return binary error masks.

         Returns:
-            Tuple(np.ndarray, np.ndarray, np.ndarray,
+            tuple(np.ndarray, np.ndarray, np.ndarray,
                 np.ndarray | None, np.ndarray | None, np.ndarray | None):
                 - tp: True positives per batch and class, shape (B, C)
                 - fp: False positives per batch and class, shape (B, C)
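The actual matching lives in core.utils.compute_batch_segmentation_tp_fp_fn, which this diff does not show. For a single image and channel, greedy IoU matching at a fixed threshold reduces to something like the following brute-force sketch (quadratic in the number of instances, for illustration only):

    import numpy as np

    def tp_fp_fn_single_image(pred: np.ndarray, gt: np.ndarray, iou_threshold: float = 0.5) -> tuple[int, int, int]:
        # pred, gt: (H, W) instance-label maps with 0 as background.
        pred_ids = [i for i in np.unique(pred) if i != 0]
        gt_ids = [i for i in np.unique(gt) if i != 0]
        matched_gt: set[int] = set()
        tp = 0
        for p in pred_ids:
            p_mask = pred == p
            best_iou, best_gt = 0.0, None
            for g in gt_ids:
                if g in matched_gt:
                    continue
                g_mask = gt == g
                inter = np.logical_and(p_mask, g_mask).sum()
                if inter == 0:
                    continue
                iou = inter / np.logical_or(p_mask, g_mask).sum()
                if iou > best_iou:
                    best_iou, best_gt = iou, g
            if best_gt is not None and best_iou >= iou_threshold:
                tp += 1
                matched_gt.add(best_gt)
        fp = len(pred_ids) - tp          # unmatched predictions
        fn = len(gt_ids) - tp            # unmatched ground-truth instances
        return tp, fp, fn
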
@@ -1143,7 +1142,7 @@ class CellSegmentator:
         false_positives: np.ndarray,
         false_negatives: np.ndarray,
         reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
-    ) -> Union[float, np.ndarray]:
+    ) -> float | np.ndarray:
         """
         Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes.
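Given TP/FP/FN arrays of shape (B, C), the "micro" and "macro" reductions differ only in when the counts are pooled; a compact reference sketch (the other reductions follow the same pattern with different pooling axes or weights):

    import numpy as np

    def f1_micro_macro(tp: np.ndarray, fp: np.ndarray, fn: np.ndarray) -> tuple[float, float]:
        # micro: pool every count first, then compute a single F1 = 2TP / (2TP + FP + FN).
        tp_sum, fp_sum, fn_sum = tp.sum(), fp.sum(), fn.sum()
        micro = 2 * tp_sum / max(2 * tp_sum + fp_sum + fn_sum, 1e-8)
        # macro: one F1 per class from class-pooled counts, then an unweighted mean.
        tp_c, fp_c, fn_c = tp.sum(axis=0), fp.sum(axis=0), fn.sum(axis=0)
        per_class = 2 * tp_c / np.maximum(2 * tp_c + fp_c + fn_c, 1e-8)
        return float(micro), float(per_class.mean())
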
@@ -1266,7 +1265,7 @@ class CellSegmentator:
         false_positives: np.ndarray,
         false_negatives: np.ndarray,
         reduction: Literal["micro", "macro", "weighted", "imagewise", 'per_class', "none"] = "micro"
-    ) -> Union[float, np.ndarray]:
+    ) -> float | np.ndarray:
         """
         Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes.
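Since only TP/FP/FN counts are available here, AP presumably follows the instance-segmentation convention AP = TP / (TP + FP + FN) at a fixed IoU threshold; micro-averaged, that is simply:

    import numpy as np

    def average_precision_micro(tp: np.ndarray, fp: np.ndarray, fn: np.ndarray) -> float:
        # AP at one IoU threshold, pooled over all images and classes.
        tp_sum, fp_sum, fn_sum = tp.sum(), fp.sum(), fn.sum()
        return float(tp_sum / max(tp_sum + fp_sum + fn_sum, 1e-8))
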
@@ -1399,23 +1398,23 @@ class CellSegmentator:
     def __save_prediction_masks(
         self,
-        sample: Dict[str, Any],
-        predicted_mask: Union[np.ndarray, torch.Tensor],
+        sample: dict[str, Any],
+        predicted_mask: np.ndarray | torch.Tensor,
         start_index: int = 0,
         only_masks: bool = False,
-        masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
+        masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
     ) -> None:
         """
         Save multi-channel predicted masks as TIFFs and
         corresponding visualizations as PNGs in separate folders.

         Args:
-            sample (Dict[str, Any]): Batch sample from MONAI
+            sample (dict(str, Any)): Batch sample from MONAI
                 LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
-            predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
+            predicted_mask (np.ndarray | torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
             start_index (int): Starting index for naming when metadata is missing.
             only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations.
-            masks (Tuple[np.ndarray, np.ndarray, np.ndarray] | None):
+            masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None):
                 A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None.
         """
         # Base directories (created once per call)
@@ -1428,14 +1427,14 @@ class CellSegmentator:
         os.makedirs(evaluate_dir, exist_ok=True)

         # Convert tensors to numpy if necessary
-        def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+        def to_numpy(x: np.ndarray | torch.Tensor) -> np.ndarray:
             return x.cpu().numpy() if isinstance(x, torch.Tensor) else x

         pred_array = to_numpy(predicted_mask).astype(np.uint16)

         # Handle batch dimension
         for idx in range(pred_array.shape[0]):
-            batch_sample: Dict[str, Any] = {}
+            batch_sample: dict[str, Any] = {}
             # copy per-sample image and meta
             img = to_numpy(sample["image"])
             if img.ndim == 4:
@@ -1467,21 +1466,21 @@ class CellSegmentator:
     def __save_single_prediction_mask(
         self,
-        sample: Dict[str, Any],
+        sample: dict[str, Any],
         pred_array: np.ndarray,
         start_index: int,
         masks_dir: str,
         plots_dir: str,
         evaluate_dir: str,
         only_masks: bool = False,
-        masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
+        masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
     ) -> None:
         """
         Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations.
         Assumes output directories already exist.

         Args:
-            sample (Dict[str, Any]): Dictionary containing 'image', 'mask',
+            sample (dict(str, Any)): Dictionary containing 'image', 'mask',
                 and optional 'image_meta_dict' for metadata.
             pred_array (np.ndarray): Predicted mask array of shape (C, H, W).
             start_index (int): Base index for generating filenames when metadata is missing.
@@ -1489,7 +1488,7 @@ class CellSegmentator:
             plots_dir (str): Directory for saving PNG visualizations.
             evaluate_dir (str): Directory for saving PNG visualizations of evaluation results.
             only_masks (bool): If True, saves only TIFF mask files; skips PNG plots.
-            masks (Tuple[np.ndarray, np.ndarray, np.ndarray], optional): A tuple of
+            masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None): A tuple of
                 true-positive, false-positive, and false-negative mask arrays,
                 each of shape (C, H, W). Defaults to None.
         """
@@ -1510,7 +1509,7 @@ class CellSegmentator:
                 "Expected 2D (H,W) or 3D (C,H,W)."
             )

-        true_mask_array: Optional[np.ndarray] = sample.get("mask")
+        true_mask_array: np.ndarray | None = sample.get("mask")
        if isinstance(true_mask_array, np.ndarray):
            if true_mask_array.ndim == 2:
                true_mask_array = np.expand_dims(true_mask_array, axis=0)
@@ -1562,7 +1561,7 @@ class CellSegmentator:
         file_path: str,
         image_data: np.ndarray,
         predicted_mask: np.ndarray,
-        true_mask: Optional[np.ndarray] = None,
+        true_mask: np.ndarray | None = None,
     ) -> None:
         """
         Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided.
@@ -1572,7 +1571,7 @@ class CellSegmentator:
             image_data (np.ndarray): The original input image array, expected shape (C, H, W).
             predicted_mask (np.ndarray): The predicted mask array, shape (H, W),
                 depending on the task.
-            true_mask (Optional[np.ndarray], optional): The ground-truth mask array.
+            true_mask (np.ndarray | None): The ground-truth mask array.
                 If provided, an additional row with true mask and overlap visualization
                 will be added to the plot. Default is None.
@@ -1603,7 +1602,7 @@ class CellSegmentator:
         img: np.ndarray,
         mask: np.ndarray,
         contour_color: str,
-        titles: Tuple[str, ...]
+        titles: tuple[str, ...]
     ):
         """
         Plot a row of three panels: original image, mask, and mask boundaries on image.
@@ -1618,7 +1617,8 @@ class CellSegmentator:
         # Panel 1: Original image
         ax0, ax1, ax2 = axes
         ax0.imshow(img, cmap='gray' if img.ndim == 2 else None)
-        ax0.set_title(titles[0]); ax0.axis('off')
+        ax0.set_title(titles[0])
+        ax0.axis('off')

         # Compute boundaries once
         boundaries = find_boundaries(mask, mode='thick')
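For the panels that follow (outside this hunk), overlaying the thick boundaries on the image is typically just a masked imshow on top of the base image; a minimal standalone sketch using the same find_boundaries call:

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors
    from skimage.segmentation import find_boundaries

    def show_mask_boundaries(img: np.ndarray, mask: np.ndarray, color: str = "red") -> None:
        # img: (H, W) grayscale image; mask: (H, W) integer instance labels.
        boundaries = find_boundaries(mask, mode="thick")
        fig, ax = plt.subplots()
        ax.imshow(img, cmap="gray")
        # Draw only the boundary pixels; everything else stays transparent.
        overlay = np.ma.masked_where(~boundaries, boundaries)
        ax.imshow(overlay, cmap=mcolors.ListedColormap([color]), alpha=0.9)
        ax.axis("off")
        plt.show()
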
@@ -1793,7 +1793,8 @@ class CellSegmentator:
         # Get coordinates of all non-zero pixels in the padded mask
         y, x = torch.nonzero(masks_padded, as_tuple=True)
-        y = y.int(); x = x.int()  # ensure integer type
+        y = y.int()
+        x = x.int()  # ensure integer type

         # Generate 8-connected neighbors (including center) via broadcasted offsets
         offsets = torch.tensor([
@@ -1830,9 +1831,12 @@ class CellSegmentator:
         ], dtype=np.int16)

         # Compute centers (pixel indices) and extents via the provided helper
-        centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr)
+        centers, ext = self.__get_mask_centers_and_extents(
+            mask_channel, slices_arr
+        )

         # Move centers to GPU and shift by +1 for padding
-        meds_p = torch.from_numpy(centers).to(self._device).long() + 1  # (M, 2); +1 for padding
+        # (M, 2); +1 for padding
+        meds_p = torch.from_numpy(centers).to(self._device).long() + 1

         # Determine number of diffusion iterations
         n_iter = 2 * ext.max()
@@ -1865,7 +1869,7 @@ class CellSegmentator:
     def __get_mask_centers_and_extents(
         label_map: np.ndarray,
         slices_arr: np.ndarray
-    ) -> Tuple[np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Compute the centroids and extents of labeled regions in a 2D mask array.
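The helper's body is not in this diff; one plausible way to derive per-label centroids and extents from a label map (reusing the scipy.ndimage.find_objects import seen at the top of the file) is sketched below — the median-based centre and the "sum of bounding-box sides" extent are assumptions about the intent:

    import numpy as np
    from scipy.ndimage import find_objects

    def mask_centers_and_extents(label_map: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # label_map: (H, W) array with labels 1..M, 0 = background.
        centers, extents = [], []
        for label, slc in enumerate(find_objects(label_map), start=1):
            if slc is None:                      # label id not present in the map
                continue
            sr, sc = slc
            ys, xs = np.nonzero(label_map[sr, sc] == label)
            centers.append((sr.start + int(np.median(ys)), sc.start + int(np.median(xs))))
            # rough object size, e.g. to scale the number of diffusion iterations
            extents.append((sr.stop - sr.start) + (sc.stop - sc.start))
        return np.asarray(centers, dtype=np.int64), np.asarray(extents, dtype=np.int64)
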
@@ -1923,7 +1927,7 @@ class CellSegmentator:
         neighbor_indices: torch.Tensor,
         center_indices: torch.Tensor,
         valid_neighbor_mask: torch.Tensor,
-        output_shape: Tuple[int, int],
+        output_shape: tuple[int, int],
         num_iterations: int = 200
     ) -> np.ndarray:
         """
@@ -1933,7 +1937,7 @@ class CellSegmentator:
             neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel.
             center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers.
             valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid.
-            output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W).
+            output_shape (tuple(int, int)): Desired 2D shape of the diffusion tensor, e.g., (H, W).
             num_iterations (int, optional): Number of diffusion iterations. Defaults to 200.

         Returns:
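Conceptually the diffusion step keeps injecting heat at every mask centre and repeatedly averages over the 9-neighbourhood inside the mask; a dense NumPy sketch of that idea (the method itself works on the sparse (2, 9, N) index tensors on the GPU, so this is for intuition only):

    import numpy as np

    def diffuse_from_centers(mask: np.ndarray, centers: np.ndarray, num_iterations: int = 200) -> np.ndarray:
        # mask: (H, W) boolean foreground; centers: (M, 2) integer (row, col) seed pixels.
        H, W = mask.shape
        heat = np.zeros((H + 2, W + 2), dtype=np.float64)   # pad by 1 so 3x3 windows stay in bounds
        padded_mask = np.pad(mask, 1)
        rows, cols = centers[:, 0] + 1, centers[:, 1] + 1    # +1 accounts for the padding
        for _ in range(num_iterations):
            heat[rows, cols] += 1.0                          # re-inject heat at each centre
            neighbourhood = sum(
                np.roll(np.roll(heat, dy, axis=0), dx, axis=1)
                for dy in (-1, 0, 1) for dx in (-1, 0, 1)
            ) / 9.0
            heat = np.where(padded_mask, neighbourhood, 0.0)  # diffusion stays confined to the mask
        return heat[1:-1, 1:-1]
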
@@ -2242,7 +2246,7 @@ class CellSegmentator:
         flow_field: np.ndarray,
         initial_coords: np.ndarray,
         num_iters: int = 200
-    ) -> Union[np.ndarray, torch.Tensor]:
+    ) -> np.ndarray | torch.Tensor:
         """
         Trace pixel positions through a flow field via iterative interpolation.
@@ -2252,7 +2256,7 @@ class CellSegmentator:
             num_iters (int): Number of integration steps.

         Returns:
-            np.ndarray or torch.Tensor: Final (y, x) positions of each point.
+            (np.ndarray | torch.Tensor): Final (y, x) positions of each point.
         """
         dims = 2
         # Extract spatial dimensions
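The integration loop itself sits outside this hunk; a CPU-only sketch of "trace positions through the flow by iterative interpolation" (Euler steps with bilinear sampling; the step size and clipping policy are assumptions):

    import numpy as np
    from scipy.ndimage import map_coordinates

    def follow_flows(flow_field: np.ndarray, initial_coords: np.ndarray, num_iters: int = 200) -> np.ndarray:
        # flow_field: (2, H, W) holding (dy, dx) per pixel; initial_coords: (2, N) float (y, x) positions.
        _, H, W = flow_field.shape
        pos = initial_coords.astype(np.float64).copy()
        for _ in range(num_iters):
            # Bilinearly sample the flow at the current sub-pixel positions.
            dy = map_coordinates(flow_field[0], pos, order=1, mode="nearest")
            dx = map_coordinates(flow_field[1], pos, order=1, mode="nearest")
            pos[0] = np.clip(pos[0] + dy, 0, H - 1)
            pos[1] = np.clip(pos[1] + dx, 0, W - 1)
        return pos
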
@@ -2383,7 +2387,7 @@ class CellSegmentator:
         self,
         pixel_positions: torch.Tensor,
         valid_indices: np.ndarray,
-        original_shape: Tuple[int, ...],
+        original_shape: tuple[int, ...],
         pad_radius: int = 20,
         max_size_fraction: float = 0.4
     ) -> np.ndarray:
@@ -2534,7 +2538,7 @@ class CellSegmentator:
         input_tensor: Tensor,
         kernel_size: int = 5,
         axis: int = 1,
-        output_tensor: Optional[Tensor] = None
+        output_tensor: Tensor | None = None
     ) -> Tensor:
         """
         Memory-efficient 1D max pooling along a specified axis using in-place updates.
@@ -2547,7 +2551,7 @@ class CellSegmentator:
             input_tensor (Tensor): Source tensor for pooling.
             kernel_size (int): Size of the pooling window (must be odd and >= 3).
             axis (int): Axis along which to compute 1D max pooling.
-            output_tensor (Optional[Tensor]): Tensor to store the result.
+            output_tensor (Tensor | None): Tensor to store the result.
                 If None, a clone of input_tensor is used.

         Returns:
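For intuition: a same-shape sliding maximum of odd width k along one axis can be built from pairwise maxima against shifted views, which needs only one full-size copy plus small temporaries; an illustrative standalone version (not the class method itself):

    import torch

    def sliding_max_1d(x: torch.Tensor, kernel_size: int = 5, axis: int = 1) -> torch.Tensor:
        assert kernel_size >= 3 and kernel_size % 2 == 1
        half = kernel_size // 2
        out = x.clone()                      # out[i] will become max(x[i-half : i+half+1])
        n = x.shape[axis]
        for shift in range(1, half + 1):
            length = n - shift
            if length <= 0:
                break
            left = out.narrow(axis, 0, length)            # out[..., :n-shift]
            left.copy_(torch.maximum(left, x.narrow(axis, shift, length)))
            right = out.narrow(axis, shift, length)       # out[..., shift:]
            right.copy_(torch.maximum(right, x.narrow(axis, 0, length)))
        return out

Windows are simply truncated at the borders, matching the usual "no padding beyond the tensor" behaviour.
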
@@ -2691,7 +2695,7 @@ class CellSegmentator:
         self,
         mask: np.ndarray,
         flow_network: np.ndarray
-    ) -> Tuple[np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Compute mean squared error between network-predicted flows and flows derived from masks.
@@ -2700,7 +2704,7 @@ class CellSegmentator:
             flow_network (np.ndarray): Network predicted flows of shape [axis, ...].

         Returns:
-            Tuple[np.ndarray, np.ndarray]:
+            tuple(np.ndarray, np.ndarray):
                 - flow_errors: 1D array (length = max label) of mean squared error per label.
                 - computed_flows: Array of flows derived from the mask, same shape as flow_network.
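The flow recomputation itself happens elsewhere in the class; given both flow arrays, the per-label mean squared error can be aggregated with the scipy.ndimage.mean already imported at the top of the file — roughly:

    import numpy as np
    from scipy.ndimage import mean as ndi_mean

    def flow_mse_per_label(mask: np.ndarray, flow_network: np.ndarray, computed_flows: np.ndarray) -> np.ndarray:
        # mask: (H, W) integer labels 1..M; both flow arrays assumed (2, H, W) here.
        n_labels = int(mask.max())
        errors = np.zeros(n_labels, dtype=np.float64)
        for axis in range(flow_network.shape[0]):
            sq_err = (flow_network[axis] - computed_flows[axis]) ** 2
            errors += ndi_mean(sq_err, labels=mask, index=np.arange(1, n_labels + 1))
        return errors
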