Fixed bugs that prevented the project from running. F1 metric values matching MEDIAR-Former were achieved.

master
laynholt 2 months ago
parent 4a501ea31a
commit c1ef9d20d5

.gitignore vendored

@ -3,4 +3,8 @@ __pycache__/
**/__pycache__/
.vscode/
*.json
*.json
outputs/
weights/
wandb/

@ -1,3 +1,5 @@
from .config import Config
from .config import Config, ComponentConfig
from .wandb_config import WandbConfig
from .dataset_config import DatasetConfig
__all__ = ["Config"]
__all__ = ["Config", "WandbConfig", "DatasetConfig", "ComponentConfig"]

@ -2,6 +2,7 @@ import json
from typing import Any, Dict, Optional
from pydantic import BaseModel
from .wandb_config import WandbConfig
from .dataset_config import DatasetConfig
@ -33,6 +34,7 @@ class ComponentConfig(BaseModel):
class Config(BaseModel):
model: ComponentConfig
dataset_config: DatasetConfig
wandb_config: WandbConfig
criterion: Optional[ComponentConfig] = None
optimizer: Optional[ComponentConfig] = None
scheduler: Optional[ComponentConfig] = None
@ -57,6 +59,7 @@ class Config(BaseModel):
data["optimizer"] = self.optimizer.dump()
if self.scheduler is not None:
data["scheduler"] = self.scheduler.dump()
data["wandb"] = self.wandb_config.model_dump()
return data
@ -88,8 +91,9 @@ class Config(BaseModel):
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Parse dataset_config using its Pydantic model.
# Parse dataset_config and wandb_config using their Pydantic models.
dataset_config = DatasetConfig(**data.get("dataset_config", {}))
wandb_config = WandbConfig(**data.get("wandb", {}))
# Helper function to parse registry fields.
def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]:
@ -119,5 +123,6 @@ class Config(BaseModel):
dataset_config=dataset_config,
criterion=parsed_criterion,
optimizer=parsed_optimizer,
scheduler=parsed_scheduler
scheduler=parsed_scheduler,
wandb_config=wandb_config
)
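For reference, a minimal JSON file this loader would accept might look like the following (a sketch; the nested fields are illustrative guesses based on the Pydantic models above, not a full schema):

{
  "model": { "name": "ModelV", "params": {} },
  "dataset_config": { "is_training": false, "common": { "device": "cuda:0" } },
  "wandb": { "use_wandb": false }
}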

@ -7,10 +7,11 @@ class DatasetCommonConfig(BaseModel):
"""
Common configuration fields shared by both training and testing.
"""
seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_tta: bool = False # Flag to use Test-Time Augmentation (TTA)
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars'
predictions_dir: str = "." # Directory to save predictions
@model_validator(mode="after")
@ -62,8 +63,8 @@ class DatasetTrainingConfig(BaseModel):
split: TrainingSplitInfo = TrainingSplitInfo()
train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic)
valid_size: Union[int, float] = 0.1 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: Union[int, float] = 0.2 # Testing data size (int for static, float in (0,1] for dynamic)
train_offset: int = 0 # Offset for training data
valid_offset: int = 0 # Offset for validation data
test_offset: int = 0 # Offset for testing data
@ -99,7 +100,7 @@ class DatasetTrainingConfig(BaseModel):
- If is_split is False, validates split (all_data_dir must be non-empty and exist).
"""
if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)):
if (self.train_size + self.valid_size + self.test_size) > 1:
if (self.train_size + self.valid_size + self.test_size) > 1 and not self.is_split:
raise ValueError("The total sample size with dynamically defined sizes must be <= 1")
if not self.is_split:
@ -214,34 +215,6 @@ class DatasetTestingConfig(BaseModel):
return self
class WandbConfig(BaseModel):
"""
Configuration for Weights & Biases logging.
"""
use_wandb: bool = False # Whether to enable WandB logging
project: Optional[str] = None # WandB project name
entity: Optional[str] = None # WandB entity (user or team)
name: Optional[str] = None # Name of the run
tags: Optional[list[str]] = None # List of tags for the run
notes: Optional[str] = None # Notes or description for the run
save_code: bool = True # Whether to save the code to WandB
@model_validator(mode="after")
def validate_wandb(cls) -> "WandbConfig":
if cls.use_wandb:
if not cls.project:
raise ValueError("When use_wandb=True, 'project' must be provided")
if not cls.entity:
raise ValueError("When use_wandb=True, 'entity' must be provided")
return cls
def asdict(self) -> Dict[str, Any]:
"""
Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
"""
return self.model_dump(exclude_none=True, exclude={"use_wandb"})
class DatasetConfig(BaseModel):
"""
Main dataset configuration that groups fields into nested models for a structured and readable JSON.
@ -250,7 +223,6 @@ class DatasetConfig(BaseModel):
common: DatasetCommonConfig = DatasetCommonConfig()
training: DatasetTrainingConfig = DatasetTrainingConfig()
testing: DatasetTestingConfig = DatasetTestingConfig()
wandb: WandbConfig = WandbConfig()
@model_validator(mode="after")
def validate_config(self) -> "DatasetConfig":
@ -265,15 +237,11 @@ class DatasetConfig(BaseModel):
if (self.training.is_split and self.training.pre_split.test_dir) or (not self.training.is_split):
if self.training.test_size > 0 and not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when test_size is non-zero")
if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
else:
if self.testing is None:
raise ValueError("Testing configuration must be provided when is_training is False")
if self.testing.test_size > 0 and not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when test_size is non-zero")
if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
return self
def model_dump(self, **kwargs) -> Dict[str, Any]:
@ -286,12 +254,10 @@ class DatasetConfig(BaseModel):
"is_training": self.is_training,
"common": self.common.model_dump(),
"training": self.training.model_dump() if self.training else {},
"wandb": self.wandb.model_dump()
}
else:
return {
"is_training": self.is_training,
"common": self.common.model_dump(),
"testing": self.testing.model_dump() if self.testing else {},
"wandb": self.wandb.model_dump()
}

@ -0,0 +1,29 @@
from pydantic import BaseModel, model_validator
from typing import Any, Dict, Optional
class WandbConfig(BaseModel):
"""
Configuration for Weights & Biases logging.
"""
use_wandb: bool = False # Whether to enable WandB logging
project: Optional[str] = None # WandB project name
group: Optional[str] = None # WandB group name
entity: Optional[str] = None # WandB entity (user or team)
name: Optional[str] = None # Name of the run
tags: Optional[list[str]] = None # List of tags for the run
notes: Optional[str] = None # Notes or description for the run
save_code: bool = True # Whether to save the code to WandB
@model_validator(mode="after")
def validate_wandb(self) -> "WandbConfig":
if self.use_wandb:
if not self.project:
raise ValueError("When use_wandb=True, 'project' must be provided")
return self
def asdict(self) -> Dict[str, Any]:
"""
Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
"""
return self.model_dump(exclude_none=True, exclude={"use_wandb"})
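A quick check of the new validator's behavior (a sketch; note that unlike the removed version, 'entity' is no longer required):

from config.wandb_config import WandbConfig

cfg = WandbConfig()                        # disabled logging needs no extra fields
try:
    WandbConfig(use_wandb=True)            # no project -> validation error
except ValueError as err:
    print(err)
cfg = WandbConfig(use_wandb=True, project="cells")
print(cfg.asdict())                        # {'project': 'cells', 'save_code': True}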

@ -169,16 +169,21 @@ def get_predict_transforms():
"""
pred_transforms = Compose(
[
# Load the image data in (H, W, C) format.
CustomLoadImage(image_only=False),
# Load image data in (H, W, C) format (allow missing keys).
CustomLoadImaged(keys=["image"], allow_missing_keys=True, image_only=False),
# Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
# Ensure the image is in channel-first format.
EnsureChannelFirst(channel_dim=-1), # image shape: (C, H, W)
CustomNormalizeImaged(
keys=["image"],
allow_missing_keys=True,
channel_wise=False,
percentiles=[0.0, 99.5],
),
# Ensure image is in channel-first format.
EnsureChannelFirstd(keys=["image"], allow_missing_keys=True, channel_dim=-1),
# Scale image intensities.
ScaleIntensity(),
# Convert the image to the required tensor type.
EnsureType(data_type="tensor"),
ScaleIntensityd(keys=["image"], allow_missing_keys=True),
# Ensure that the data types are correct.
EnsureTyped(keys=["image"], allow_missing_keys=True),
]
)
return pred_transforms
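With the switch to dictionary-based transforms, the pipeline is now called with a dict instead of a bare path (a sketch; the file name is illustrative, and CustomLoadImaged is assumed to accept a path under the "image" key):

pred_transforms = get_predict_transforms()
sample = pred_transforms({"image": "data/test/images/cell_00001.tif"})
print(sample["image"].shape)  # channel-first tensor, e.g. torch.Size([3, H, W])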

@ -1,13 +1,18 @@
import copy
import torch
import numpy as np
from typing import Dict, Sequence, Tuple, Union
from skimage.segmentation import find_boundaries
from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore
import logging
__all__ = ["BoundaryExclusion", "IntensityDiversification"]
logger = logging.getLogger("cell_aware")
class BoundaryExclusion(MapTransform):
"""
Map the cell boundary pixel labels to the background class (0).
@ -164,7 +169,8 @@ class IntensityDiversification(MapTransform):
# If there are no unique cell objects in this channel, log a warning and skip it.
if cell_ids.size == 0:
raise ValueError(f"No unique objects found in the label mask for channel {c}")
logger.warning(f"No unique objects found in the label mask for channel {c}")
continue
# Determine the number of cells to modify using the change_cell_ratio.
change_count = int(len(cell_ids) * self.change_cell_ratio)
@ -175,7 +181,10 @@ class IntensityDiversification(MapTransform):
# Create a binary mask for the current channel:
# - Pixels corresponding to the selected cell IDs are set to 1.
# - All other pixels are set to 0.
mask = np.isin(channel_label, selected).astype(np.float32)
mask_np = np.isin(channel_label, selected).astype(np.float32)
# Convert the mask to a float32 tensor on the same device as the label
mask = torch.from_numpy(mask_np).to(dtype=torch.float32, device=channel_label.device)
# Separate the image channel into two components:
# 1. img_orig: The portion of the image that remains unchanged.
@ -183,8 +192,11 @@ class IntensityDiversification(MapTransform):
img_orig = (1 - mask) * img_channel
img_changed = mask * img_channel
# Add a channel dimension for RandScaleIntensity: (1, H, W)
img_changed = img_changed.unsqueeze(0)
# Apply a random intensity scaling transformation to the selected regions.
img_changed = self.randscale_intensity(img_changed)
img_changed = img_changed.squeeze(0) # type: ignore # back to shape (H, W)
# Combine the unchanged and modified parts to update the image channel.
data["image"][c] = img_orig + img_changed

@ -6,7 +6,7 @@ from typing import Dict, Any, Optional
from monai.metrics.cumulative_average import CumulativeAverage
class BaseLoss(abc.ABC):
class BaseLoss(nn.Module, abc.ABC):
"""Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
def __init__(self, params: Optional[BaseModel] = None):

@ -28,7 +28,7 @@ class BCELossParams(BaseModel):
loss_kwargs = self.model_dump()
if not self.with_logits:
loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss
loss_kwargs.pop("with_logits", None)
loss_kwargs.pop("with_logits", None)
weight = loss_kwargs.get("weight")
pos_weight = loss_kwargs.get("pos_weight")

@ -12,6 +12,7 @@ import fastremap
import fill_voids
from skimage import morphology
from skimage.segmentation import find_boundaries
from scipy.special import expit
from scipy.ndimage import mean, find_objects
from monai.data.dataset import Dataset
@ -53,10 +54,11 @@ logger = get_logger()
class CellSegmentator:
def __init__(self, config: Config) -> None:
self._device: torch.device = torch.device(config.dataset_config.common.device or "cpu")
self.__set_seed(config.dataset_config.common.seed)
self.__parse_config(config)
self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu")
self._scaler = (
torch.amp.GradScaler(self._device.type) # type: ignore
if self._dataset_setup.is_training and self._dataset_setup.common.use_amp
@ -153,7 +155,7 @@ class CellSegmentator:
# Train dataloader
train_dataset = self.__get_dataset(
images_dir=os.path.join(train_dir, 'images'),
masks_dir=os.path.join(train_dir, 'masks'),
masks_dir=os.path.join(train_dir, 'masks', self._dataset_setup.common.masks_subdir),
transforms=train_transforms, # type: ignore
size=self._dataset_setup.training.train_size,
offset=train_offset,
@ -168,7 +170,7 @@ class CellSegmentator:
raise RuntimeError("Validation directory or size is not properly configured.")
valid_dataset = self.__get_dataset(
images_dir=os.path.join(valid_dir, 'images'),
masks_dir=os.path.join(valid_dir, 'masks'),
masks_dir=os.path.join(valid_dir, 'masks', self._dataset_setup.common.masks_subdir),
transforms=valid_transforms,
size=self._dataset_setup.training.valid_size,
offset=valid_offset,
@ -183,7 +185,7 @@ class CellSegmentator:
raise RuntimeError("Test directory or size is not properly configured.")
test_dataset = self.__get_dataset(
images_dir=os.path.join(test_dir, 'images'),
masks_dir=os.path.join(test_dir, 'masks'),
masks_dir=os.path.join(test_dir, 'masks', self._dataset_setup.common.masks_subdir),
transforms=test_transforms,
size=self._dataset_setup.training.test_size,
offset=test_offset,
@ -210,7 +212,7 @@ class CellSegmentator:
else:
# Inference mode (no training)
test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images')
test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks')
test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks', self._dataset_setup.common.masks_subdir)
if test_transforms is not None:
test_dataset = self.__get_dataset(
@ -385,7 +387,7 @@ class CellSegmentator:
batch_counter = 0
for batch in tqdm(self._predict_dataloader, desc="Predicting"):
# Move input images to the configured device (CPU/GPU)
inputs = batch["img"].to(self._device)
inputs = batch["image"].to(self._device)
# Use automatic mixed precision if enabled in dataset setup
with torch.amp.autocast( # type: ignore
@ -443,15 +445,40 @@ class CellSegmentator:
def load_from_checkpoint(self, checkpoint_path: str) -> None:
"""
Loads model weights from a specified checkpoint into the current model.
Loads model weights from a specified checkpoint into the current model,
but only for parameters whose shapes match. Parameters with mismatched
shapes (e.g., classification heads with different output sizes) remain
at their initialized values.
Args:
checkpoint_path (str): Path to the checkpoint file containing the model weights.
"""
# Load the checkpoint onto the correct device (CPU or GPU)
checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=True)
# Load the state dict into the model, allowing for missing keys
self._model.load_state_dict(checkpoint['state_dict'], strict=False)
# Load the checkpoint (state_dict) from file onto CPU
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
# Extract nested state_dict if present
state_dict = checkpoint.get("state_dict", checkpoint)
# Get the current model's parameter dictionary
model_dict = self._model.state_dict()
# Filter pretrained parameters to those matching in name and shape
pretrained_dict = {
k: v for k, v in state_dict.items()
if k in model_dict and v.size() == model_dict[k].size()
}
# Log how many parameters are loaded, skipped, or missing
skipped = [k for k in state_dict if k not in pretrained_dict]
missing = [k for k in model_dict if k not in pretrained_dict]
logger.info(
f"Loaded {len(pretrained_dict)} parameters;"
f" skipped {len(skipped)} params from checkpoint;"
f" {len(missing)} params remain uninitialized in model."
)
# Update the model's state_dict and load it
model_dict.update(pretrained_dict)
self._model.load_state_dict(model_dict)
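The rewrite is needed because load_state_dict(strict=False) only tolerates missing or unexpected keys; a shape mismatch still raises. Minimal repro in plain PyTorch:

import torch.nn as nn

src = nn.Linear(10, 5)
dst = nn.Linear(10, 3)  # same parameter names, different head size
try:
    dst.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as err:
    print(err)  # size mismatch for weight / bias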
def save_checkpoint(self, checkpoint_path: str) -> None:
@ -461,12 +488,9 @@ class CellSegmentator:
Args:
checkpoint_path (str): Path where the checkpoint file will be saved.
"""
# Create a checkpoint dictionary containing the models state_dict
checkpoint = {
'state_dict': self._model.state_dict()
}
# Write the checkpoint to disk
torch.save(checkpoint, checkpoint_path)
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(self._model.state_dict(), checkpoint_path)
def __parse_config(self, config: Config) -> None:
@ -492,15 +516,23 @@ class CellSegmentator:
if scheduler:
logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2))
logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2))
logger.info("Wandb Config:\n%s", pformat(config.wandb_config.model_dump(), indent=2))
logger.info("==========================================")
# Initialize model using the model registry
self._model = ModelRegistry.get_model_class(model.name)(model.params)
# Loads model weights from a specified checkpoint
if config.dataset_config.is_training:
if config.dataset_config.training.pretrained_weights:
self.load_from_checkpoint(config.dataset_config.training.pretrained_weights)
pretrained_weights = (
config.dataset_config.training.pretrained_weights
if config.dataset_config.is_training
else config.dataset_config.testing.pretrained_weights
)
if pretrained_weights:
self.load_from_checkpoint(pretrained_weights)
logger.info(f"Loaded pre-trained weights from: {pretrained_weights}")
self._model = self._model.to(self._device)
# Initialize loss criterion if specified
self._criterion = (
@ -555,6 +587,7 @@ class CellSegmentator:
logger.info(f"├─ Seed: {common.seed}")
logger.info(f"├─ Device: {common.device}")
logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
logger.info(f"├─ Masks subdirectory: {common.masks_subdir}")
logger.info(f"└─ Predictions output dir: {common.predictions_dir}")
if config.dataset_config.is_training:
@ -592,18 +625,21 @@ class CellSegmentator:
logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")
wandb_cfg = config.dataset_config.wandb
if wandb_cfg.use_wandb:
self._wandb_config = config.wandb_config
if self._wandb_config.use_wandb:
logger.info("[W&B]")
logger.info(f"├─ Project: {wandb_cfg.project}")
logger.info(f"├─ Entity: {wandb_cfg.entity}")
if wandb_cfg.name:
logger.info(f"├─ Run name: {wandb_cfg.name}")
if wandb_cfg.tags:
logger.info(f"├─ Tags: {', '.join(wandb_cfg.tags)}")
if wandb_cfg.notes:
logger.info(f"├─ Notes: {wandb_cfg.notes}")
logger.info(f"└─ Save code: {'yes' if wandb_cfg.save_code else 'no'}")
logger.info(f"├─ Project: {self._wandb_config.project}")
if self._wandb_config.group:
logger.info(f"├─ Group: {self._wandb_config.group}")
if self._wandb_config.entity:
logger.info(f"├─ Entity: {self._wandb_config.entity}")
if self._wandb_config.name:
logger.info(f"├─ Run name: {self._wandb_config.name}")
if self._wandb_config.tags:
logger.info(f"├─ Tags: {', '.join(self._wandb_config.tags)}")
if self._wandb_config.notes:
logger.info(f"├─ Notes: {self._wandb_config.notes}")
logger.info(f"└─ Save code: {'yes' if self._wandb_config.save_code else 'no'}")
else:
logger.info("[W&B] Logging disabled")
@ -657,13 +693,19 @@ class CellSegmentator:
ValueError: If dataset is too small for requested size or offset.
"""
# Collect sorted list of image paths
images = sorted(glob.glob(images_dir))
images = sorted(
glob.glob(os.path.join(images_dir, '*.tif')) +
glob.glob(os.path.join(images_dir, '*.tiff'))
)
if not images:
raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'")
if masks_dir is not None:
# Collect and validate sorted list of mask paths
masks = sorted(glob.glob(masks_dir))
masks = sorted(
glob.glob(os.path.join(masks_dir, '*.tif')) +
glob.glob(os.path.join(masks_dir, '*.tiff'))
)
if len(images) != len(masks):
raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})")
@ -720,7 +762,7 @@ class CellSegmentator:
tablefmt="fancy_grid"
)
print(table, "\n")
if self._dataset_setup.wandb.use_wandb:
if self._wandb_config.use_wandb:
wandb.log(results, step=step)
@ -765,8 +807,8 @@ class CellSegmentator:
# Iterate over batches
batch_counter = 0
for batch in tqdm(loader, desc=desc):
inputs = batch["img"].to(self._device)
targets = batch["label"].to(self._device)
inputs = batch["image"].to(self._device)
targets = batch["mask"].to(self._device)
# Zero gradients for training
if self._optimizer is not None:
@ -787,7 +829,10 @@ class CellSegmentator:
flow_targets = self.__compute_flows_from_masks(targets)
# Compute loss for this batch
batch_loss = self._criterion(raw_output, flow_targets) # type: ignore
batch_loss = self._criterion(
raw_output,
torch.from_numpy(flow_targets).to(device=raw_output.device)
)
# Post-process and compute F1 during validation and testing
if mode in ("valid", "test"):
@ -842,6 +887,9 @@ class CellSegmentator:
epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( # type: ignore
tp_array, fp_array, fn_array, reduction="micro"
)
epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric( # type: ignore
tp_array, fp_array, fn_array, reduction="imagewise"
)
epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( # type: ignore
tp_array, fp_array, fn_array, reduction="macro"
)
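The two reductions differ only in when the counts are pooled (a sketch of the assumed semantics; edge-case handling inside __compute_f1_metric may differ):

import numpy as np

tp, fp, fn = np.array([8, 0]), np.array([2, 1]), np.array([1, 3])
# micro: pool TP/FP/FN over all images, then compute one F1
micro = 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum())    # ~0.696
# imagewise: compute F1 per image, then average
imagewise = np.mean(2 * tp / np.maximum(2 * tp + fp + fn, 1))  # ~0.421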
@ -914,7 +962,7 @@ class CellSegmentator:
instance_masks[idx] = self.__segment_instances(
probability_map=probabilities[idx],
flow=gradflow[idx],
prob_threshold=0.0,
prob_threshold=0.5,
flow_threshold=0.4,
min_object_size=15
)
@ -1159,7 +1207,7 @@ class CellSegmentator:
Returns:
np.ndarray: Sigmoid of the input.
"""
return 1 / (1 + np.exp(-z))
return expit(z)
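scipy's expit is a numerically safe drop-in: the naive formula overflows inside np.exp for large negative inputs, while expit saturates cleanly:

import numpy as np
from scipy.special import expit

z = np.array([-1000.0, 0.0, 1000.0])
print(expit(z))            # [0.  0.5 1.] with no warnings
# 1 / (1 + np.exp(-z))     # RuntimeWarning: overflow encountered in exp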
def __save_prediction_masks(
@ -1191,7 +1239,7 @@ class CellSegmentator:
# Convert tensors to numpy
def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
if isinstance(x, torch.Tensor):
return x.detach().cpu().numpy()
return x.cpu().numpy()
return x
image_array = to_numpy(image_obj) if image_obj is not None else None
@ -1201,11 +1249,11 @@ class CellSegmentator:
# Handle batch dimension: (B, C, H, W)
if pred_array.ndim == 4:
for idx in range(pred_array.shape[0]):
batch_sample = dict(sample)
batch_sample: Dict[str, Any] = {}
if image_array is not None and image_array.ndim == 4:
batch_sample["image"] = image_array[idx]
if isinstance(image_meta, list):
batch_sample["image_meta_dict"] = image_meta[idx]
if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
batch_sample["image_meta_dict"] = image_meta["filename_or_obj"][idx]
if mask_array is not None and mask_array.ndim == 4:
batch_sample["mask"] = mask_array[idx]
self.__save_prediction_masks(
@ -1216,8 +1264,8 @@ class CellSegmentator:
return
# Determine base filename
if image_meta and "filename_or_obj" in image_meta:
base_name = os.path.splitext(os.path.basename(image_meta["filename_or_obj"]))[0]
if isinstance(image_meta, (str, os.PathLike)):
base_name = os.path.splitext(os.path.basename(image_meta))[0]
else:
# Use provided start_index when metadata missing
base_name = f"prediction_{start_index:04d}"
@ -1228,8 +1276,8 @@ class CellSegmentator:
channel_mask = pred_array[channel_idx]
# File names
mask_filename = f"{base_name}_ch{channel_idx:02d}.tif"
plot_filename = f"{base_name}_ch{channel_idx:02d}.png"
mask_filename = f"{base_name}_ch{channel_idx:01d}.tif"
plot_filename = f"{base_name}_ch{channel_idx:01d}.png"
mask_path = os.path.join(masks_dir, mask_filename)
plot_path = os.path.join(plots_dir, plot_filename)
@ -1402,78 +1450,81 @@ class CellSegmentator:
flows = np.zeros((2*channels, height, width), np.float32)
for channel in range(channels):
padded_height, padded_width = height + 2, width + 2
# Pad the mask with a 1-pixel border
masks_padded = torch.from_numpy(mask.astype(np.int64)).to(self._device)
masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
# Get coordinates of all non-zero pixels in the padded mask
y, x = torch.nonzero(masks_padded, as_tuple=True)
y = y.int(); x = x.int() # ensure integer type
# Generate 8-connected neighbors (including center) via broadcasted offsets
offsets = torch.tensor([
[ 0, 0], # center
[-1, 0], # up
[ 1, 0], # down
[ 0, -1], # left
[ 0, 1], # right
[-1, -1], # up-left
[-1, 1], # up-right
[ 1, -1], # down-left
[ 1, 1], # down-right
], dtype=torch.int32, device=self._device) # (9, 2)
# coords: (N, 2)
coords = torch.stack((y, x), dim=1)
# neighbors: (9, N, 2)
neighbors = offsets[:, None, :] + coords[None, :, :]
# transpose into (2, 9, N) for the GPU kernel
neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index
# Build connectivity mask: True where neighbor label == center label
center_labels = masks_padded[y, x][None, :] # (1, N)
neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N)
is_neighbor = neighbor_labels == center_labels # (9, N)
# Compute object slices and pack into array for get_centers
slices = find_objects(mask)
slices_arr = np.array([
[i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
for i, sl in enumerate(slices) if sl is not None
], dtype=int)
mask_channel = mask[channel]
# Compute centers (pixel indices) and extents via the provided helper
centers, ext = self.__get_mask_centers_and_extents(mask, slices_arr)
# Move centers to GPU and shift by +1 for padding
meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding
# Determine number of diffusion iterations
n_iter = 2 * ext.max()
# Run the GPU diffusion kernel
mu = self.__propagate_centers_gpu(
neighbor_indices=neighbors,
center_indices=meds_p.T,
valid_neighbor_mask=is_neighbor,
output_shape=(padded_height, padded_width),
num_iterations=n_iter
)
if mask_channel.max() > 0:
padded_height, padded_width = height + 2, width + 2
# Pad the mask with a 1-pixel border
masks_padded = torch.from_numpy(mask_channel.astype(np.int64)).to(self._device)
masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
# Get coordinates of all non-zero pixels in the padded mask
y, x = torch.nonzero(masks_padded, as_tuple=True)
y = y.int(); x = x.int() # ensure integer type
# Generate 8-connected neighbors (including center) via broadcasted offsets
offsets = torch.tensor([
[ 0, 0], # center
[-1, 0], # up
[ 1, 0], # down
[ 0, -1], # left
[ 0, 1], # right
[-1, -1], # up-left
[-1, 1], # up-right
[ 1, -1], # down-left
[ 1, 1], # down-right
], dtype=torch.int32, device=self._device) # (9, 2)
# coords: (N, 2)
coords = torch.stack((y, x), dim=1)
# neighbors: (9, N, 2)
neighbors = offsets[:, None, :] + coords[None, :, :]
# transpose into (2, 9, N) for the GPU kernel
neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index
# Build connectivity mask: True where neighbor label == center label
center_labels = masks_padded[y, x][None, :] # (1, N)
neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N)
is_neighbor = neighbor_labels == center_labels # (9, N)
# Compute object slices and pack into array for get_centers
slices = find_objects(mask_channel)
slices_arr = np.array([
[i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
for i, sl in enumerate(slices, start=1) if sl is not None
], dtype=np.int16)
# Compute centers (pixel indices) and extents via the provided helper
centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr)
# Move centers to GPU and shift by +1 for padding
meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding
# Determine number of diffusion iterations
n_iter = 2 * ext.max()
# Run the GPU diffusion kernel
mu = self.__propagate_centers_gpu(
neighbor_indices=neighbors,
center_indices=meds_p.T,
valid_neighbor_mask=is_neighbor,
output_shape=(padded_height, padded_width),
num_iterations=n_iter
)
# Cast to float64 and normalize flow vectors
mu = mu.astype(np.float64)
mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60
# Cast to float64 and normalize flow vectors
mu = mu.astype(np.float64)
mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60
# Remove the padding and write into final output
flow_output = np.zeros((2, height, width), dtype=np.float32)
ys_np = y.cpu().numpy() - 1
xs_np = x.cpu().numpy() - 1
flow_output[:, ys_np, xs_np] = mu
flows[2*channel: 2*channel + 2] = flow_output
# Remove the padding and write into final output
flow_output = np.zeros((2, height, width), dtype=np.float32)
ys_np = y.cpu().numpy() - 1
xs_np = x.cpu().numpy() - 1
flow_output[:, ys_np, xs_np] = mu
flows[2*channel: 2*channel + 2] = flow_output
return flows
@ -1624,8 +1675,10 @@ class CellSegmentator:
# Initialize flow_field with two channels: dy and dx
flow_field = np.zeros((2, height, width), dtype=np.float64)
mask_channel = mask[channel]
# Find bounding box for each labeled mask
mask_slices = find_objects(mask)
mask_slices = find_objects(mask_channel)
# centers: List[Tuple[int, int]] = []
# Iterate over mask labels in parallel
@ -1642,7 +1695,7 @@ class CellSegmentator:
# Get local coordinates of mask pixels within the patch
local_rows, local_cols = np.nonzero(
mask[row_slice, col_slice] == (label_idx + 1)
mask_channel[row_slice, col_slice] == (label_idx + 1)
)
# Shift coords by +1 for the border padding
local_rows = local_rows.astype(np.int32) + 1
@ -1774,8 +1827,8 @@ class CellSegmentator:
Generate instance segmentation masks from probability and flow fields.
Args:
probability_map: 3D array (channels, height, width) of cell probabilities.
flow: 3D array (2*channels, height, width) of forward flow vectors.
probability_map: 3D array `(C, H, W)` of cell probabilities.
flow: 3D array `(2*C, H, W)` of forward flow vectors.
prob_threshold: threshold to binarize probability_map. (Default 0.0)
flow_threshold: threshold for filtering bad flow masks. (Default 0.4)
num_iters: number of iterations for flow-following. (Default 200)
@ -1802,6 +1855,9 @@ class CellSegmentator:
# Extract binary mask for this channel
channel_mask = probability_mask[channel_index]
if not channel_mask.sum():
continue
nonzero_coords = np.stack(np.nonzero(channel_mask))
# Follow the flow vectors to generate coordinate mappings
@ -1810,12 +1866,6 @@ class CellSegmentator:
initial_coords=nonzero_coords,
num_iters=num_iters
)
# If flow following fails, leave this channel empty
if flow_coordinates is None:
labeled_instances[channel_index] = np.zeros(
probability_map.shape[1:], dtype=np.uint16
)
continue
if not torch.is_tensor(flow_coordinates):
flow_coordinates = torch.from_numpy(
@ -1851,11 +1901,6 @@ class CellSegmentator:
)
labeled_instances[channel_index] = channel_instances_mask
else:
# No valid instances found, leave the channel empty
labeled_instances[channel_index] = np.zeros(
probability_map.shape[1:], dtype=np.uint16
)
return labeled_instances
@ -1923,7 +1968,7 @@ class CellSegmentator:
)
# Update each coordinate and clamp to valid range
for i in range(dims):
pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0)
pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0)
# Denormalize back to original pixel coordinates
pts = (pts + 1) * 0.5 * max_idx_pt
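For context on the indexing fix: the flow-following step samples the flow field at the current normalized point positions via grid_sample. Shapes only, names illustrative:

import torch
import torch.nn.functional as F

flow = torch.randn(1, 2, 64, 64)        # (B, 2, H, W) flow field
pts = torch.zeros(1, 128, 1, 2)         # (B, N, 1, 2), xy coords in [-1, 1]
sampled = F.grid_sample(flow, pts, align_corners=False)  # (B, 2, N, 1)
# sampled[:, i] is the i-th flow component for all N points across the batch,
# whereas the old sampled[0, i] read only the first batch element.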
@ -2072,16 +2117,21 @@ class CellSegmentator:
raise
# Step 3: Find peaks via 5x5 max-pooling
k = 5
pooled = F.max_pool2d(
# k = 5
# pooled = F.max_pool2d(
# histogram.float().unsqueeze(0).unsqueeze(1),
# kernel_size=k,
# stride=1,
# padding=k // 2
# ).squeeze()
pooled = self.__max_pool_nd(
histogram.unsqueeze(0),
kernel_size=k,
stride=1,
padding=k // 2
kernel_size=5
).squeeze()
# Seeds are positions where histogram equals local max and count > threshold
seed_positions = torch.nonzero((histogram - pooled == 0) & (histogram > 10))
if seed_positions.numel() == 0:
seed_positions = torch.nonzero((histogram - pooled == 0) * (histogram > 10))
if seed_positions.shape[0] == 0:
logger.warning("No seeds found: returning empty mask")
return np.zeros(original_shape, dtype=np.uint16)
@ -2106,13 +2156,14 @@ class CellSegmentator:
seed_masks[:, 5, 5] = 1
# Iterative dilation and thresholding
for _ in range(5):
seed_masks = F.max_pool2d(
seed_masks,
kernel_size=3,
stride=1,
padding=1
)
seed_masks = seed_masks & (patches > 2)
# seed_masks = F.max_pool2d(
# seed_masks.float().unsqueeze(0),
# kernel_size=3,
# stride=1,
# padding=1
# ).squeeze(0).int()
seed_masks = self.__max_pool_nd(seed_masks, kernel_size=3)
seed_masks *= (patches > 2)
# Compute final mask coordinates
final_coords = []
for idx in range(num_seeds):
@ -2133,7 +2184,7 @@ class CellSegmentator:
# Step 6: Map to original image and remove oversized masks
mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32)
mask_final[valid_indices] = mask_values
mask_final[tuple(valid_indices)] = mask_values
# Prune masks that are too large
labels, counts = fastremap.unique(mask_final, return_counts=True)
@ -2146,6 +2197,96 @@ class CellSegmentator:
return mask_final
def __max_pool1d(
self,
input_tensor: Tensor,
kernel_size: int = 5,
axis: int = 1,
output_tensor: Optional[Tensor] = None
) -> Tensor:
"""
Memory-efficient 1D max pooling along a specified axis using in-place updates.
Requires:
- stride = 1
- padding = kernel_size // 2
- odd kernel_size >= 3
Args:
input_tensor (Tensor): Source tensor for pooling.
kernel_size (int): Size of the pooling window (must be odd and >= 3).
axis (int): Axis along which to compute 1D max pooling.
output_tensor (Optional[Tensor]): Tensor to store the result.
If None, a clone of input_tensor is used.
Returns:
Tensor: The pooled tensor, same shape as input_tensor.
"""
# Initialize or copy data into the output tensor
if output_tensor is None:
output = input_tensor.clone()
else:
output = output_tensor
output.copy_(input_tensor)
# Number of elements along the chosen axis and half-window size
dimension_size = input_tensor.shape[axis]
half_window = kernel_size // 2
# Slide window offsets from -half_window to +half_window
for offset in range(-half_window, half_window + 1):
# Compute slice indices depending on axis
if axis == 1:
target_slice = output[:, max(-offset, 0): min(dimension_size - offset, dimension_size)]
source_slice = input_tensor[:, max(offset, 0): min(dimension_size + offset, dimension_size)]
elif axis == 2:
target_slice = output[:, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
source_slice = input_tensor[:, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
elif axis == 3:
target_slice = output[:, :, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
source_slice = input_tensor[:, :, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
else:
raise ValueError(f"Unsupported axis {axis} for 1D pooling")
# In-place element-wise maximum
torch.maximum(target_slice, source_slice, out=target_slice)
return output
def __max_pool_nd(
self,
input_tensor: Tensor,
kernel_size: int = 5
) -> Tensor:
"""
Memory-efficient N-dimensional max pooling for 2D or 3D spatial data.
Applies 1D max pooling sequentially over each spatial axis.
Args:
input_tensor (Tensor): Input tensor with shape
(batch_size, dim1, dim2, ..., dimN).
kernel_size (int): Size of the pooling window (must be odd and >= 3).
Returns:
Tensor: The pooled tensor, same shape as input_tensor.
"""
# Determine number of spatial dimensions (excluding batch axis)
num_spatial_dims = input_tensor.ndim - 1
# First pass: pool along axis=1
pooled = self.__max_pool1d(input_tensor, kernel_size=kernel_size, axis=1)
# Second pass: pool along axis=2
pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=2)
# If 3D data, apply a third pass along axis=3
if num_spatial_dims == 3:
pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=3)
elif num_spatial_dims != 2:
raise ValueError("max_pool_nd only supports 2D or 3D spatial data")
return pooled
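Because stride is fixed at 1 with 'same' padding, pooling each spatial axis in turn is equivalent to a full 2D max pool; a standalone re-implementation checked against F.max_pool2d (illustrative, mirroring the private helpers above):

import torch
import torch.nn.functional as F

def max_pool_1d_axis(x: torch.Tensor, k: int, axis: int) -> torch.Tensor:
    # Max pool along one axis with stride 1 and padding k // 2.
    out = x.clone()
    n, half = x.shape[axis], k // 2
    for off in range(-half, half + 1):
        tgt = out.narrow(axis, max(-off, 0), n - abs(off))
        src = x.narrow(axis, max(off, 0), n - abs(off))
        torch.maximum(tgt, src, out=tgt)
    return out

x = torch.randn(1, 32, 32)  # (batch, H, W)
sep = max_pool_1d_axis(max_pool_1d_axis(x, 5, 1), 5, 2)
ref = F.max_pool2d(x.unsqueeze(0), kernel_size=5, stride=1, padding=2).squeeze(0)
assert torch.equal(sep, ref)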
def __remove_inconsistent_flow_masks(
self,
mask: np.ndarray,
@ -2175,11 +2316,7 @@ class CellSegmentator:
"""
# If mask is very large and running on CUDA, check memory
num_pixels = mask.size
if (
num_pixels > 10000 * 10000
and self._device.type == 'cuda'
):
if num_pixels > 10000 * 10000 and self._device.type == 'cuda':
# Clear unused GPU cache
torch.cuda.empty_cache()
# Determine PyTorch version
@ -2296,15 +2433,8 @@ class CellSegmentator:
logger.error(msg)
raise ValueError(msg)
# Optionally remove masks smaller than minimum_size
if minimum_size >= 0:
# Compute label counts (skipping background at index 0)
labels, counts = fastremap.unique(masks, return_counts=True)
# Identify labels to remove: those with count < minimum_size
small_labels = labels[counts < minimum_size]
if small_labels.size > 0:
masks = fastremap.mask(masks, small_labels)
fastremap.renumber(masks, in_place=True)
# Initial pruning of too-small masks
masks = self._prune_small_masks(masks, minimum_size)
# Find bounding boxes for each mask label
object_slices = find_objects(masks)
@ -2325,12 +2455,36 @@ class CellSegmentator:
output_masks[slc][filled_region] = new_label
new_label += 1
# Final pruning of small masks after filling (optional)
if minimum_size >= 0:
labels, counts = fastremap.unique(output_masks, return_counts=True)
small_labels = labels[counts < minimum_size]
if small_labels.size > 0:
output_masks = fastremap.mask(output_masks, small_labels)
fastremap.renumber(output_masks, in_place=True)
# Final pruning after hole filling
output_masks = self._prune_small_masks(output_masks, minimum_size)
return output_masks
def _prune_small_masks(
self,
masks: np.ndarray,
minimum_size: int
) -> np.ndarray:
"""
Remove labeled regions in `masks` whose pixel count is below `minimum_size`.
Args:
masks (np.ndarray): Integer mask array (any shape), 0=background.
minimum_size (int): Minimum pixel count; labels smaller are removed. If <0, skip pruning.
Returns:
np.ndarray: Mask array with small labels suppressed and labels renumbered.
"""
if minimum_size < 0:
return masks
labels, counts = fastremap.unique(masks, return_counts=True)
# Skip background label at index 0
non_bg_labels = labels[1:]
non_bg_counts = counts[1:]
# Identify labels to remove
small_labels = non_bg_labels[non_bg_counts < minimum_size]
if small_labels.size > 0:
masks = fastremap.mask(masks, small_labels)
fastremap.renumber(masks, in_place=True)
return masks
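The effect on a toy mask (a sketch using the same fastremap calls):

import numpy as np
import fastremap

masks = np.array([[1, 1, 0],
                  [2, 0, 3],
                  [0, 0, 3]], dtype=np.uint16)
masks = fastremap.mask(masks, [2])             # suppress label 2 (1 px < minimum_size)
masks, _ = fastremap.renumber(masks, in_place=True)
print(masks)  # [[1 1 0] [0 0 2] [0 0 2]] -- labels compacted to 1..N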

@ -11,6 +11,8 @@ from skimage import segmentation
from scipy.optimize import linear_sum_assignment
from typing import Dict, List, Tuple, Any, Union
from core.logger import get_logger
__all__ = [
"compute_batch_segmentation_f1_metrics", "compute_batch_segmentation_average_precision_metrics",
"compute_batch_segmentation_tp_fp_fn",
@ -18,7 +20,9 @@ __all__ = [
"compute_segmentation_tp_fp_fn",
"compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score"
]
logger = get_logger()
def compute_f1_score(
true_positives: int,
@ -92,7 +96,7 @@ def compute_confusion_matrix(
# If no predictions were made, return zeros (with a printout for debugging).
if num_predictions == 0:
print("No segmentation results!")
logger.warning("No segmentation results!")
return 0, 0, 0
# Compute the IoU matrix and ignore the background (first row and column).
@ -586,7 +590,7 @@ def _process_instance_matching(
# If no predictions are found, return with all ground truth as false negatives.
if num_prediction == 0:
print("No segmentation results!")
logger.warning("No segmentation results!")
result = {'tp': 0, 'fp': 0, 'fn': num_ground_truth}
if return_masks:
tp_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8)

@ -1,8 +1,7 @@
import os
from typing import Tuple
from config.config import *
from config.dataset_config import DatasetConfig
from config import Config, WandbConfig, DatasetConfig, ComponentConfig
from core import (
ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
@ -47,7 +46,8 @@ def main():
if is_training is False:
config = Config(
model=ComponentConfig(name=chosen_model, params=model_instance),
dataset_config=dataset_config
dataset_config=dataset_config,
wandb_config=WandbConfig()
)
# Construct a base filename from the selected registry names.
@ -76,6 +76,7 @@ def main():
config = Config(
model=ComponentConfig(name=chosen_model, params=model_instance),
dataset_config=dataset_config,
wandb_config=WandbConfig(),
criterion=ComponentConfig(name=chosen_criterion, params=criterion_instance),
optimizer=ComponentConfig(name=chosen_optimizer, params=optimizer_instance),
scheduler=ComponentConfig(name=chosen_scheduler, params=scheduler_instance)

@ -1,41 +1,78 @@
import os
import argparse
import wandb
from config.config import Config
from config import Config
from core.data import *
from core.segmentator import CellSegmentator
if __name__ == "__main__":
config_path = 'config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json'
# config_path = 'config/templates/predict/ModelV.json'
def main():
parser = argparse.ArgumentParser(
description="Train or predict cell segmentator with specified config file."
)
parser.add_argument(
'-c', '--config',
type=str,
default='config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json',
help='Path to the JSON config file'
)
parser.add_argument(
'-m', '--mode',
choices=['train', 'test', 'predict'],
default='train',
help='Run mode: train, test or predict'
)
args = parser.parse_args()
mode = args.mode
config_path = args.config
config = Config.load_json(config_path)
# config = Config.load_json(config_path)
if config.dataset_config.wandb.use_wandb:
# Initialize W&B
wandb.init(config=config.asdict(), **config.dataset_config.wandb.asdict())
if mode == 'train' and not config.dataset_config.is_training:
raise ValueError(
f"Config is not set for training (is_training=False), but mode 'train' was requested."
)
if mode in ('test', 'predict') and config.dataset_config.is_training:
raise ValueError(
f"Config is set for training (is_training=True), but mode '{mode}' was requested."
)
if config.wandb_config.use_wandb:
# Initialize W&B
wandb.init(config=config.asdict(), **config.wandb_config.asdict())
# How many batches to wait before logging training status
wandb.config.log_interval = 10
segmentator = CellSegmentator(config)
segmentator.create_dataloaders()
segmentator.create_dataloaders(
train_transforms=get_train_transforms() if mode == "train" else None,
valid_transforms=get_valid_transforms() if mode == "train" else None,
test_transforms=get_test_transforms() if mode in ("train", "test") else None,
predict_transforms=get_predict_transforms() if mode == "predict" else None
)
# Watch parameters & gradients of model
if config.dataset_config.wandb.use_wandb:
if config.wandb_config.use_wandb:
wandb.watch(segmentator._model, log="all", log_graph=True)
# Run training (or prediction, if implemented)
segmentator.run()
weights_dir = "weights" if not config.dataset_config.wandb.use_wandb else wandb.run.dir # type: ignore
saving_path = os.path.join(
weights_dir, os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
)
segmentator.save_checkpoint(saving_path)
if config.dataset_config.wandb.use_wandb:
wandb.save(saving_path)
if config.dataset_config.is_training:
# Prepare saving path
weights_dir = (
wandb.run.dir if config.wandb_config.use_wandb else "weights" # type: ignore
)
saving_path = os.path.join(
weights_dir,
os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
)
segmentator.save_checkpoint(saving_path)
if config.wandb_config.use_wandb:
wandb.save(saving_path)
if __name__ == "__main__":
main()
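Typical invocations with the new CLI (paths are the templates already referenced in this file):

python main.py --config config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json --mode train
python main.py -c config/templates/predict/ModelV.json -m predict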
