Fixed bugs that prevented the project from running. Mediar Former values for the F1 metric were achieved.

master
laynholt 2 months ago
parent 4a501ea31a
commit c1ef9d20d5

6
.gitignore vendored

@ -3,4 +3,8 @@ __pycache__/
**/__pycache__/ **/__pycache__/
.vscode/ .vscode/
*.json *.json
outputs/
weights/
wandb/

@ -1,3 +1,5 @@
from .config import Config from .config import Config, ComponentConfig
from .wandb_config import WandbConfig
from .dataset_config import DatasetConfig
__all__ = ["Config"] __all__ = ["Config", "WandbConfig", "DatasetConfig", "ComponentConfig"]

@ -2,6 +2,7 @@ import json
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic import BaseModel from pydantic import BaseModel
from .wandb_config import WandbConfig
from .dataset_config import DatasetConfig from .dataset_config import DatasetConfig
@ -33,6 +34,7 @@ class ComponentConfig(BaseModel):
class Config(BaseModel): class Config(BaseModel):
model: ComponentConfig model: ComponentConfig
dataset_config: DatasetConfig dataset_config: DatasetConfig
wandb_config: WandbConfig
criterion: Optional[ComponentConfig] = None criterion: Optional[ComponentConfig] = None
optimizer: Optional[ComponentConfig] = None optimizer: Optional[ComponentConfig] = None
scheduler: Optional[ComponentConfig] = None scheduler: Optional[ComponentConfig] = None
@ -57,6 +59,7 @@ class Config(BaseModel):
data["optimizer"] = self.optimizer.dump() data["optimizer"] = self.optimizer.dump()
if self.scheduler is not None: if self.scheduler is not None:
data["scheduler"] = self.scheduler.dump() data["scheduler"] = self.scheduler.dump()
data["wandb"] = self.wandb_config.model_dump()
return data return data
@ -88,8 +91,9 @@ class Config(BaseModel):
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
# Parse dataset_config using its Pydantic model. # Parse dataset_config and wandb_config using its Pydantic model.
dataset_config = DatasetConfig(**data.get("dataset_config", {})) dataset_config = DatasetConfig(**data.get("dataset_config", {}))
wandb_config = WandbConfig(**data.get("wandb", {}))
# Helper function to parse registry fields. # Helper function to parse registry fields.
def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]: def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]:
@ -119,5 +123,6 @@ class Config(BaseModel):
dataset_config=dataset_config, dataset_config=dataset_config,
criterion=parsed_criterion, criterion=parsed_criterion,
optimizer=parsed_optimizer, optimizer=parsed_optimizer,
scheduler=parsed_scheduler scheduler=parsed_scheduler,
wandb_config=wandb_config
) )

@ -7,10 +7,11 @@ class DatasetCommonConfig(BaseModel):
""" """
Common configuration fields shared by both training and testing. Common configuration fields shared by both training and testing.
""" """
seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations) seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda0" # Device used for training/testing (e.g., 'cpu' or 'cuda') device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_tta: bool = False # Flag to use Test-Time Augmentation (TTA) use_tta: bool = False # Flag to use Test-Time Augmentation (TTA)
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP) use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars'
predictions_dir: str = "." # Directory to save predictions predictions_dir: str = "." # Directory to save predictions
@model_validator(mode="after") @model_validator(mode="after")
@ -62,8 +63,8 @@ class DatasetTrainingConfig(BaseModel):
split: TrainingSplitInfo = TrainingSplitInfo() split: TrainingSplitInfo = TrainingSplitInfo()
train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic) train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic) valid_size: Union[int, float] = 0.1 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic) test_size: Union[int, float] = 0.2 # Testing data size (int for static, float in (0,1] for dynamic)
train_offset: int = 0 # Offset for training data train_offset: int = 0 # Offset for training data
valid_offset: int = 0 # Offset for validation data valid_offset: int = 0 # Offset for validation data
test_offset: int = 0 # Offset for testing data test_offset: int = 0 # Offset for testing data
@ -99,7 +100,7 @@ class DatasetTrainingConfig(BaseModel):
- If is_split is False, validates split (all_data_dir must be non-empty and exist). - If is_split is False, validates split (all_data_dir must be non-empty and exist).
""" """
if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)): if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)):
if (self.train_size + self.valid_size + self.test_size) > 1: if (self.train_size + self.valid_size + self.test_size) > 1 and not self.is_split:
raise ValueError("The total sample size with dynamically defined sizes must be <= 1") raise ValueError("The total sample size with dynamically defined sizes must be <= 1")
if not self.is_split: if not self.is_split:
@ -214,34 +215,6 @@ class DatasetTestingConfig(BaseModel):
return self return self
class WandbConfig(BaseModel):
"""
Configuration for Weights & Biases logging.
"""
use_wandb: bool = False # Whether to enable WandB logging
project: Optional[str] = None # WandB project name
entity: Optional[str] = None # WandB entity (user or team)
name: Optional[str] = None # Name of the run
tags: Optional[list[str]] = None # List of tags for the run
notes: Optional[str] = None # Notes or description for the run
save_code: bool = True # Whether to save the code to WandB
@model_validator(mode="after")
def validate_wandb(cls) -> "WandbConfig":
if cls.use_wandb:
if not cls.project:
raise ValueError("When use_wandb=True, 'project' must be provided")
if not cls.entity:
raise ValueError("When use_wandb=True, 'entity' must be provided")
return cls
def asdict(self) -> Dict[str, Any]:
"""
Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
"""
return self.model_dump(exclude_none=True, exclude={"use_wandb"})
class DatasetConfig(BaseModel): class DatasetConfig(BaseModel):
""" """
Main dataset configuration that groups fields into nested models for a structured and readable JSON. Main dataset configuration that groups fields into nested models for a structured and readable JSON.
@ -250,7 +223,6 @@ class DatasetConfig(BaseModel):
common: DatasetCommonConfig = DatasetCommonConfig() common: DatasetCommonConfig = DatasetCommonConfig()
training: DatasetTrainingConfig = DatasetTrainingConfig() training: DatasetTrainingConfig = DatasetTrainingConfig()
testing: DatasetTestingConfig = DatasetTestingConfig() testing: DatasetTestingConfig = DatasetTestingConfig()
wandb: WandbConfig = WandbConfig()
@model_validator(mode="after") @model_validator(mode="after")
def validate_config(self) -> "DatasetConfig": def validate_config(self) -> "DatasetConfig":
@ -265,15 +237,11 @@ class DatasetConfig(BaseModel):
if (self.training.is_split and self.training.pre_split.test_dir) or (not self.training.is_split): if (self.training.is_split and self.training.pre_split.test_dir) or (not self.training.is_split):
if self.training.test_size > 0 and not self.common.predictions_dir: if self.training.test_size > 0 and not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when test_size is non-zero") raise ValueError("predictions_dir must be provided when test_size is non-zero")
if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
else: else:
if self.testing is None: if self.testing is None:
raise ValueError("Testing configuration must be provided when is_training is False") raise ValueError("Testing configuration must be provided when is_training is False")
if self.testing.test_size > 0 and not self.common.predictions_dir: if self.testing.test_size > 0 and not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when test_size is non-zero") raise ValueError("predictions_dir must be provided when test_size is non-zero")
if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
return self return self
def model_dump(self, **kwargs) -> Dict[str, Any]: def model_dump(self, **kwargs) -> Dict[str, Any]:
@ -286,12 +254,10 @@ class DatasetConfig(BaseModel):
"is_training": self.is_training, "is_training": self.is_training,
"common": self.common.model_dump(), "common": self.common.model_dump(),
"training": self.training.model_dump() if self.training else {}, "training": self.training.model_dump() if self.training else {},
"wandb": self.wandb.model_dump()
} }
else: else:
return { return {
"is_training": self.is_training, "is_training": self.is_training,
"common": self.common.model_dump(), "common": self.common.model_dump(),
"testing": self.testing.model_dump() if self.testing else {}, "testing": self.testing.model_dump() if self.testing else {},
"wandb": self.wandb.model_dump()
} }

@ -0,0 +1,29 @@
from pydantic import BaseModel, model_validator
from typing import Any, Dict, Optional
class WandbConfig(BaseModel):
"""
Configuration for Weights & Biases logging.
"""
use_wandb: bool = False # Whether to enable WandB logging
project: Optional[str] = None # WandB project name
group: Optional[str] = None # WandB group name
entity: Optional[str] = None # WandB entity (user or team)
name: Optional[str] = None # Name of the run
tags: Optional[list[str]] = None # List of tags for the run
notes: Optional[str] = None # Notes or description for the run
save_code: bool = True # Whether to save the code to WandB
@model_validator(mode="after")
def validate_wandb(self) -> "WandbConfig":
if self.use_wandb:
if not self.project:
raise ValueError("When use_wandb=True, 'project' must be provided")
return self
def asdict(self) -> Dict[str, Any]:
"""
Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
"""
return self.model_dump(exclude_none=True, exclude={"use_wandb"})

@ -169,16 +169,21 @@ def get_predict_transforms():
""" """
pred_transforms = Compose( pred_transforms = Compose(
[ [
# Load the image data in (H, W, C) format. # Load image data in (H, W, C) format (allow missing keys).
CustomLoadImage(image_only=False), CustomLoadImaged(keys=["image"], allow_missing_keys=True, image_only=False),
# Normalize the (H, W, C) image using the specified percentiles. # Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]), CustomNormalizeImaged(
# Ensure the image is in channel-first format. keys=["image"],
EnsureChannelFirst(channel_dim=-1), # image shape: (C, H, W) allow_missing_keys=True,
channel_wise=False,
percentiles=[0.0, 99.5],
),
# Ensure image is in channel-first format.
EnsureChannelFirstd(keys=["image"], allow_missing_keys=True, channel_dim=-1),
# Scale image intensities. # Scale image intensities.
ScaleIntensity(), ScaleIntensityd(keys=["image"], allow_missing_keys=True),
# Convert the image to the required tensor type. # Ensure that the data types are correct.
EnsureType(data_type="tensor"), EnsureTyped(keys=["image"], allow_missing_keys=True),
] ]
) )
return pred_transforms return pred_transforms

@ -1,13 +1,18 @@
import copy import copy
import torch
import numpy as np import numpy as np
from typing import Dict, Sequence, Tuple, Union from typing import Dict, Sequence, Tuple, Union
from skimage.segmentation import find_boundaries from skimage.segmentation import find_boundaries
from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore
import logging
__all__ = ["BoundaryExclusion", "IntensityDiversification"] __all__ = ["BoundaryExclusion", "IntensityDiversification"]
logger = logging.getLogger("cell_aware")
class BoundaryExclusion(MapTransform): class BoundaryExclusion(MapTransform):
""" """
Map the cell boundary pixel labels to the background class (0). Map the cell boundary pixel labels to the background class (0).
@ -164,7 +169,8 @@ class IntensityDiversification(MapTransform):
# If there are no unique cell objects in this channel, raise an exception. # If there are no unique cell objects in this channel, raise an exception.
if cell_ids.size == 0: if cell_ids.size == 0:
raise ValueError(f"No unique objects found in the label mask for channel {c}") logger.warning(f"No unique objects found in the label mask for channel {c}")
continue
# Determine the number of cells to modify using the change_cell_ratio. # Determine the number of cells to modify using the change_cell_ratio.
change_count = int(len(cell_ids) * self.change_cell_ratio) change_count = int(len(cell_ids) * self.change_cell_ratio)
@ -175,7 +181,10 @@ class IntensityDiversification(MapTransform):
# Create a binary mask for the current channel: # Create a binary mask for the current channel:
# - Pixels corresponding to the selected cell IDs are set to 1. # - Pixels corresponding to the selected cell IDs are set to 1.
# - All other pixels are set to 0. # - All other pixels are set to 0.
mask = np.isin(channel_label, selected).astype(np.float32) mask_np = np.isin(channel_label, selected).astype(np.float32)
# Convert mask to same dtype and device
mask = torch.from_numpy(mask_np).to(dtype=torch.float32, device=channel_label.device)
# Separate the image channel into two components: # Separate the image channel into two components:
# 1. img_orig: The portion of the image that remains unchanged. # 1. img_orig: The portion of the image that remains unchanged.
@ -183,8 +192,11 @@ class IntensityDiversification(MapTransform):
img_orig = (1 - mask) * img_channel img_orig = (1 - mask) * img_channel
img_changed = mask * img_channel img_changed = mask * img_channel
# Add a channel dimension for RandScaleIntensity: (1, H, W)
img_changed = img_changed.unsqueeze(0)
# Apply a random intensity scaling transformation to the selected regions. # Apply a random intensity scaling transformation to the selected regions.
img_changed = self.randscale_intensity(img_changed) img_changed = self.randscale_intensity(img_changed)
img_changed = img_changed.squeeze(0) # type: ignore # back to shape (H, W)
# Combine the unchanged and modified parts to update the image channel. # Combine the unchanged and modified parts to update the image channel.
data["image"][c] = img_orig + img_changed data["image"][c] = img_orig + img_changed

@ -6,7 +6,7 @@ from typing import Dict, Any, Optional
from monai.metrics.cumulative_average import CumulativeAverage from monai.metrics.cumulative_average import CumulativeAverage
class BaseLoss(abc.ABC): class BaseLoss(nn.Module, abc.ABC):
"""Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction.""" """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
def __init__(self, params: Optional[BaseModel] = None): def __init__(self, params: Optional[BaseModel] = None):

@ -28,7 +28,7 @@ class BCELossParams(BaseModel):
loss_kwargs = self.model_dump() loss_kwargs = self.model_dump()
if not self.with_logits: if not self.with_logits:
loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss
loss_kwargs.pop("with_logits", None) loss_kwargs.pop("with_logits", None)
weight = loss_kwargs.get("weight") weight = loss_kwargs.get("weight")
pos_weight = loss_kwargs.get("pos_weight") pos_weight = loss_kwargs.get("pos_weight")

@ -12,6 +12,7 @@ import fastremap
import fill_voids import fill_voids
from skimage import morphology from skimage import morphology
from skimage.segmentation import find_boundaries from skimage.segmentation import find_boundaries
from scipy.special import expit
from scipy.ndimage import mean, find_objects from scipy.ndimage import mean, find_objects
from monai.data.dataset import Dataset from monai.data.dataset import Dataset
@ -53,10 +54,11 @@ logger = get_logger()
class CellSegmentator: class CellSegmentator:
def __init__(self, config: Config) -> None: def __init__(self, config: Config) -> None:
self._device: torch.device = torch.device(config.dataset_config.common.device or "cpu")
self.__set_seed(config.dataset_config.common.seed) self.__set_seed(config.dataset_config.common.seed)
self.__parse_config(config) self.__parse_config(config)
self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu")
self._scaler = ( self._scaler = (
torch.amp.GradScaler(self._device.type) # type: ignore torch.amp.GradScaler(self._device.type) # type: ignore
if self._dataset_setup.is_training and self._dataset_setup.common.use_amp if self._dataset_setup.is_training and self._dataset_setup.common.use_amp
@ -153,7 +155,7 @@ class CellSegmentator:
# Train dataloader # Train dataloader
train_dataset = self.__get_dataset( train_dataset = self.__get_dataset(
images_dir=os.path.join(train_dir, 'images'), images_dir=os.path.join(train_dir, 'images'),
masks_dir=os.path.join(train_dir, 'masks'), masks_dir=os.path.join(train_dir, 'masks', self._dataset_setup.common.masks_subdir),
transforms=train_transforms, # type: ignore transforms=train_transforms, # type: ignore
size=self._dataset_setup.training.train_size, size=self._dataset_setup.training.train_size,
offset=train_offset, offset=train_offset,
@ -168,7 +170,7 @@ class CellSegmentator:
raise RuntimeError("Validation directory or size is not properly configured.") raise RuntimeError("Validation directory or size is not properly configured.")
valid_dataset = self.__get_dataset( valid_dataset = self.__get_dataset(
images_dir=os.path.join(valid_dir, 'images'), images_dir=os.path.join(valid_dir, 'images'),
masks_dir=os.path.join(valid_dir, 'masks'), masks_dir=os.path.join(valid_dir, 'masks', self._dataset_setup.common.masks_subdir),
transforms=valid_transforms, transforms=valid_transforms,
size=self._dataset_setup.training.valid_size, size=self._dataset_setup.training.valid_size,
offset=valid_offset, offset=valid_offset,
@ -183,7 +185,7 @@ class CellSegmentator:
raise RuntimeError("Test directory or size is not properly configured.") raise RuntimeError("Test directory or size is not properly configured.")
test_dataset = self.__get_dataset( test_dataset = self.__get_dataset(
images_dir=os.path.join(test_dir, 'images'), images_dir=os.path.join(test_dir, 'images'),
masks_dir=os.path.join(test_dir, 'masks'), masks_dir=os.path.join(test_dir, 'masks', self._dataset_setup.common.masks_subdir),
transforms=test_transforms, transforms=test_transforms,
size=self._dataset_setup.training.test_size, size=self._dataset_setup.training.test_size,
offset=test_offset, offset=test_offset,
@ -210,7 +212,7 @@ class CellSegmentator:
else: else:
# Inference mode (no training) # Inference mode (no training)
test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images') test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images')
test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks') test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks', self._dataset_setup.common.masks_subdir)
if test_transforms is not None: if test_transforms is not None:
test_dataset = self.__get_dataset( test_dataset = self.__get_dataset(
@ -385,7 +387,7 @@ class CellSegmentator:
batch_counter = 0 batch_counter = 0
for batch in tqdm(self._predict_dataloader, desc="Predicting"): for batch in tqdm(self._predict_dataloader, desc="Predicting"):
# Move input images to the configured device (CPU/GPU) # Move input images to the configured device (CPU/GPU)
inputs = batch["img"].to(self._device) inputs = batch["image"].to(self._device)
# Use automatic mixed precision if enabled in dataset setup # Use automatic mixed precision if enabled in dataset setup
with torch.amp.autocast( # type: ignore with torch.amp.autocast( # type: ignore
@ -443,15 +445,40 @@ class CellSegmentator:
def load_from_checkpoint(self, checkpoint_path: str) -> None: def load_from_checkpoint(self, checkpoint_path: str) -> None:
""" """
Loads model weights from a specified checkpoint into the current model. Loads model weights from a specified checkpoint into the current model,
but only for parameters whose shapes match. Parameters with mismatched
shapes (e.g., classification heads with different output sizes) remain
at their initialized values.
Args: Args:
checkpoint_path (str): Path to the checkpoint file containing the model weights. checkpoint_path (str): Path to the checkpoint file containing the model weights.
""" """
# Load the checkpoint onto the correct device (CPU or GPU) # Load the checkpoint (state_dict) from file onto CPU
checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=True) checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
# Load the state dict into the model, allowing for missing keys # Extract nested state_dict if present
self._model.load_state_dict(checkpoint['state_dict'], strict=False) state_dict = checkpoint.get("state_dict", checkpoint)
# Get the current model's parameter dictionary
model_dict = self._model.state_dict()
# Filter pretrained parameters to those matching in name and shape
pretrained_dict = {
k: v for k, v in state_dict.items()
if k in model_dict and v.size() == model_dict[k].size()
}
# Log how many parameters are loaded, skipped, or missing
skipped = [k for k in state_dict if k not in pretrained_dict]
missing = [k for k in model_dict if k not in pretrained_dict]
logger.info(
f"Loaded {len(pretrained_dict)} parameters;"
f" skipped {len(skipped)} params from checkpoint;"
f" {len(missing)} params remain uninitialized in model."
)
# Update the model's state_dict and load it
model_dict.update(pretrained_dict)
self._model.load_state_dict(model_dict)
def save_checkpoint(self, checkpoint_path: str) -> None: def save_checkpoint(self, checkpoint_path: str) -> None:
@ -461,12 +488,9 @@ class CellSegmentator:
Args: Args:
checkpoint_path (str): Path where the checkpoint file will be saved. checkpoint_path (str): Path where the checkpoint file will be saved.
""" """
# Create a checkpoint dictionary containing the models state_dict
checkpoint = {
'state_dict': self._model.state_dict()
}
# Write the checkpoint to disk # Write the checkpoint to disk
torch.save(checkpoint, checkpoint_path) os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(self._model.state_dict(), checkpoint_path)
def __parse_config(self, config: Config) -> None: def __parse_config(self, config: Config) -> None:
@ -492,15 +516,23 @@ class CellSegmentator:
if scheduler: if scheduler:
logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2)) logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2))
logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2)) logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2))
logger.info("Wandb Config:\n%s", pformat(config.wandb_config.model_dump(), indent=2))
logger.info("==========================================") logger.info("==========================================")
# Initialize model using the model registry # Initialize model using the model registry
self._model = ModelRegistry.get_model_class(model.name)(model.params) self._model = ModelRegistry.get_model_class(model.name)(model.params)
# Loads model weights from a specified checkpoint # Loads model weights from a specified checkpoint
if config.dataset_config.is_training: pretrained_weights = (
if config.dataset_config.training.pretrained_weights: config.dataset_config.training.pretrained_weights
self.load_from_checkpoint(config.dataset_config.training.pretrained_weights) if config.dataset_config.is_training
else config.dataset_config.testing.pretrained_weights
)
if pretrained_weights:
self.load_from_checkpoint(pretrained_weights)
logger.info(f"Loaded pre-trained weights from: {pretrained_weights}")
self._model = self._model.to(self._device)
# Initialize loss criterion if specified # Initialize loss criterion if specified
self._criterion = ( self._criterion = (
@ -555,6 +587,7 @@ class CellSegmentator:
logger.info(f"├─ Seed: {common.seed}") logger.info(f"├─ Seed: {common.seed}")
logger.info(f"├─ Device: {common.device}") logger.info(f"├─ Device: {common.device}")
logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}") logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
logger.info(f"├─ Masks subdirectory: {common.masks_subdir}")
logger.info(f"└─ Predictions output dir: {common.predictions_dir}") logger.info(f"└─ Predictions output dir: {common.predictions_dir}")
if config.dataset_config.is_training: if config.dataset_config.is_training:
@ -592,18 +625,21 @@ class CellSegmentator:
logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}") logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}") logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")
wandb_cfg = config.dataset_config.wandb self._wandb_config = config.wandb_config
if wandb_cfg.use_wandb: if self._wandb_config.use_wandb:
logger.info("[W&B]") logger.info("[W&B]")
logger.info(f"├─ Project: {wandb_cfg.project}") logger.info(f"├─ Project: {self._wandb_config.project}")
logger.info(f"├─ Entity: {wandb_cfg.entity}") if self._wandb_config.group:
if wandb_cfg.name: logger.info(f"├─ Group: {self._wandb_config.group}")
logger.info(f"├─ Run name: {wandb_cfg.name}") if self._wandb_config.entity:
if wandb_cfg.tags: logger.info(f"├─ Entity: {self._wandb_config.entity}")
logger.info(f"├─ Tags: {', '.join(wandb_cfg.tags)}") if self._wandb_config.name:
if wandb_cfg.notes: logger.info(f"├─ Run name: {self._wandb_config.name}")
logger.info(f"├─ Notes: {wandb_cfg.notes}") if self._wandb_config.tags:
logger.info(f"└─ Save code: {'yes' if wandb_cfg.save_code else 'no'}") logger.info(f"├─ Tags: {', '.join(self._wandb_config.tags)}")
if self._wandb_config.notes:
logger.info(f"├─ Notes: {self._wandb_config.notes}")
logger.info(f"└─ Save code: {'yes' if self._wandb_config.save_code else 'no'}")
else: else:
logger.info("[W&B] Logging disabled") logger.info("[W&B] Logging disabled")
@ -657,13 +693,19 @@ class CellSegmentator:
ValueError: If dataset is too small for requested size or offset. ValueError: If dataset is too small for requested size or offset.
""" """
# Collect sorted list of image paths # Collect sorted list of image paths
images = sorted(glob.glob(images_dir)) images = sorted(
glob.glob(os.path.join(images_dir, '*.tif')) +
glob.glob(os.path.join(images_dir, '*.tiff'))
)
if not images: if not images:
raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'") raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'")
if masks_dir is not None: if masks_dir is not None:
# Collect and validate sorted list of mask paths # Collect and validate sorted list of mask paths
masks = sorted(glob.glob(masks_dir)) masks = sorted(
glob.glob(os.path.join(masks_dir, '*.tif')) +
glob.glob(os.path.join(masks_dir, '*.tiff'))
)
if len(images) != len(masks): if len(images) != len(masks):
raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})") raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})")
@ -720,7 +762,7 @@ class CellSegmentator:
tablefmt="fancy_grid" tablefmt="fancy_grid"
) )
print(table, "\n") print(table, "\n")
if self._dataset_setup.wandb.use_wandb: if self._wandb_config.use_wandb:
wandb.log(results, step=step) wandb.log(results, step=step)
@ -765,8 +807,8 @@ class CellSegmentator:
# Iterate over batches # Iterate over batches
batch_counter = 0 batch_counter = 0
for batch in tqdm(loader, desc=desc): for batch in tqdm(loader, desc=desc):
inputs = batch["img"].to(self._device) inputs = batch["image"].to(self._device)
targets = batch["label"].to(self._device) targets = batch["mask"].to(self._device)
# Zero gradients for training # Zero gradients for training
if self._optimizer is not None: if self._optimizer is not None:
@ -787,7 +829,10 @@ class CellSegmentator:
flow_targets = self.__compute_flows_from_masks(targets) flow_targets = self.__compute_flows_from_masks(targets)
# Compute loss for this batch # Compute loss for this batch
batch_loss = self._criterion(raw_output, flow_targets) # type: ignore batch_loss = self._criterion(
raw_output,
torch.from_numpy(flow_targets).to(device=raw_output.device)
)
# Post-process and compute F1 during validation and testing # Post-process and compute F1 during validation and testing
if mode in ("valid", "test"): if mode in ("valid", "test"):
@ -842,6 +887,9 @@ class CellSegmentator:
epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( # type: ignore epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( # type: ignore
tp_array, fp_array, fn_array, reduction="micro" tp_array, fp_array, fn_array, reduction="micro"
) )
epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric( # type: ignore
tp_array, fp_array, fn_array, reduction="imagewise"
)
epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( # type: ignore epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( # type: ignore
tp_array, fp_array, fn_array, reduction="macro" tp_array, fp_array, fn_array, reduction="macro"
) )
@ -914,7 +962,7 @@ class CellSegmentator:
instance_masks[idx] = self.__segment_instances( instance_masks[idx] = self.__segment_instances(
probability_map=probabilities[idx], probability_map=probabilities[idx],
flow=gradflow[idx], flow=gradflow[idx],
prob_threshold=0.0, prob_threshold=0.5,
flow_threshold=0.4, flow_threshold=0.4,
min_object_size=15 min_object_size=15
) )
@ -1159,7 +1207,7 @@ class CellSegmentator:
Returns: Returns:
np.ndarray: Sigmoid of the input. np.ndarray: Sigmoid of the input.
""" """
return 1 / (1 + np.exp(-z)) return expit(z)
def __save_prediction_masks( def __save_prediction_masks(
@ -1191,7 +1239,7 @@ class CellSegmentator:
# Convert tensors to numpy # Convert tensors to numpy
def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
return x.detach().cpu().numpy() return x.cpu().numpy()
return x return x
image_array = to_numpy(image_obj) if image_obj is not None else None image_array = to_numpy(image_obj) if image_obj is not None else None
@ -1201,11 +1249,11 @@ class CellSegmentator:
# Handle batch dimension: (B, C, H, W) # Handle batch dimension: (B, C, H, W)
if pred_array.ndim == 4: if pred_array.ndim == 4:
for idx in range(pred_array.shape[0]): for idx in range(pred_array.shape[0]):
batch_sample = dict(sample) batch_sample: Dict[str, Any] = {}
if image_array is not None and image_array.ndim == 4: if image_array is not None and image_array.ndim == 4:
batch_sample["image"] = image_array[idx] batch_sample["image"] = image_array[idx]
if isinstance(image_meta, list): if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
batch_sample["image_meta_dict"] = image_meta[idx] batch_sample["image_meta_dict"] = image_meta["filename_or_obj"][idx]
if mask_array is not None and mask_array.ndim == 4: if mask_array is not None and mask_array.ndim == 4:
batch_sample["mask"] = mask_array[idx] batch_sample["mask"] = mask_array[idx]
self.__save_prediction_masks( self.__save_prediction_masks(
@ -1216,8 +1264,8 @@ class CellSegmentator:
return return
# Determine base filename # Determine base filename
if image_meta and "filename_or_obj" in image_meta: if isinstance(image_meta, (str, os.PathLike)):
base_name = os.path.splitext(os.path.basename(image_meta["filename_or_obj"]))[0] base_name = os.path.splitext(os.path.basename(image_meta))[0]
else: else:
# Use provided start_index when metadata missing # Use provided start_index when metadata missing
base_name = f"prediction_{start_index:04d}" base_name = f"prediction_{start_index:04d}"
@ -1228,8 +1276,8 @@ class CellSegmentator:
channel_mask = pred_array[channel_idx] channel_mask = pred_array[channel_idx]
# File names # File names
mask_filename = f"{base_name}_ch{channel_idx:02d}.tif" mask_filename = f"{base_name}_ch{channel_idx:01d}.tif"
plot_filename = f"{base_name}_ch{channel_idx:02d}.png" plot_filename = f"{base_name}_ch{channel_idx:01d}.png"
mask_path = os.path.join(masks_dir, mask_filename) mask_path = os.path.join(masks_dir, mask_filename)
plot_path = os.path.join(plots_dir, plot_filename) plot_path = os.path.join(plots_dir, plot_filename)
@ -1402,78 +1450,81 @@ class CellSegmentator:
flows = np.zeros((2*channels, height, width), np.float32) flows = np.zeros((2*channels, height, width), np.float32)
for channel in range(channels): for channel in range(channels):
padded_height, padded_width = height + 2, width + 2 mask_channel = mask[channel]
# Pad the mask with a 1-pixel border
masks_padded = torch.from_numpy(mask.astype(np.int64)).to(self._device)
masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
# Get coordinates of all non-zero pixels in the padded mask
y, x = torch.nonzero(masks_padded, as_tuple=True)
y = y.int(); x = x.int() # ensure integer type
# Generate 8-connected neighbors (including center) via broadcasted offsets
offsets = torch.tensor([
[ 0, 0], # center
[-1, 0], # up
[ 1, 0], # down
[ 0, -1], # left
[ 0, 1], # right
[-1, -1], # up-left
[-1, 1], # up-right
[ 1, -1], # down-left
[ 1, 1], # down-right
], dtype=torch.int32, device=self._device) # (9, 2)
# coords: (N, 2)
coords = torch.stack((y, x), dim=1)
# neighbors: (9, N, 2)
neighbors = offsets[:, None, :] + coords[None, :, :]
# transpose into (2, 9, N) for the GPU kernel
neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index
# Build connectivity mask: True where neighbor label == center label
center_labels = masks_padded[y, x][None, :] # (1, N)
neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N)
is_neighbor = neighbor_labels == center_labels # (9, N)
# Compute object slices and pack into array for get_centers
slices = find_objects(mask)
slices_arr = np.array([
[i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
for i, sl in enumerate(slices) if sl is not None
], dtype=int)
# Compute centers (pixel indices) and extents via the provided helper if mask_channel.max() > 0:
centers, ext = self.__get_mask_centers_and_extents(mask, slices_arr) padded_height, padded_width = height + 2, width + 2
# Move centers to GPU and shift by +1 for padding
meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding # Pad the mask with a 1-pixel border
masks_padded = torch.from_numpy(mask_channel.astype(np.int64)).to(self._device)
# Determine number of diffusion iterations masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
n_iter = 2 * ext.max()
# Get coordinates of all non-zero pixels in the padded mask
# Run the GPU diffusion kernel y, x = torch.nonzero(masks_padded, as_tuple=True)
mu = self.__propagate_centers_gpu( y = y.int(); x = x.int() # ensure integer type
neighbor_indices=neighbors,
center_indices=meds_p.T, # Generate 8-connected neighbors (including center) via broadcasted offsets
valid_neighbor_mask=is_neighbor, offsets = torch.tensor([
output_shape=(padded_height, padded_width), [ 0, 0], # center
num_iterations=n_iter [-1, 0], # up
) [ 1, 0], # down
[ 0, -1], # left
[ 0, 1], # right
[-1, -1], # up-left
[-1, 1], # up-right
[ 1, -1], # down-left
[ 1, 1], # down-right
], dtype=torch.int32, device=self._device) # (9, 2)
# coords: (N, 2)
coords = torch.stack((y, x), dim=1)
# neighbors: (9, N, 2)
neighbors = offsets[:, None, :] + coords[None, :, :]
# transpose into (2, 9, N) for the GPU kernel
neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index
# Build connectivity mask: True where neighbor label == center label
center_labels = masks_padded[y, x][None, :] # (1, N)
neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N)
is_neighbor = neighbor_labels == center_labels # (9, N)
# Compute object slices and pack into array for get_centers
slices = find_objects(mask_channel)
slices_arr = np.array([
[i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
for i, sl in enumerate(slices, start=1) if sl is not None
], dtype=np.int16)
# Compute centers (pixel indices) and extents via the provided helper
centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr)
# Move centers to GPU and shift by +1 for padding
meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding
# Determine number of diffusion iterations
n_iter = 2 * ext.max()
# Run the GPU diffusion kernel
mu = self.__propagate_centers_gpu(
neighbor_indices=neighbors,
center_indices=meds_p.T,
valid_neighbor_mask=is_neighbor,
output_shape=(padded_height, padded_width),
num_iterations=n_iter
)
# Cast to float64 and normalize flow vectors # Cast to float64 and normalize flow vectors
mu = mu.astype(np.float64) mu = mu.astype(np.float64)
mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60 mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60
# Remove the padding and write into final output
flow_output = np.zeros((2, height, width), dtype=np.float32)
ys_np = y.cpu().numpy() - 1
xs_np = x.cpu().numpy() - 1
flow_output[:, ys_np, xs_np] = mu
flows[2*channel: 2*channel + 2] = flow_output
# Remove the padding and write into final output
flow_output = np.zeros((2, height, width), dtype=np.float32)
ys_np = y.cpu().numpy() - 1
xs_np = x.cpu().numpy() - 1
flow_output[:, ys_np, xs_np] = mu
flows[2*channel: 2*channel + 2] = flow_output
return flows return flows
@ -1624,8 +1675,10 @@ class CellSegmentator:
# Initialize flow_field with two channels: dy and dx # Initialize flow_field with two channels: dy and dx
flow_field = np.zeros((2, height, width), dtype=np.float64) flow_field = np.zeros((2, height, width), dtype=np.float64)
mask_channel = mask[channel]
# Find bounding box for each labeled mask # Find bounding box for each labeled mask
mask_slices = find_objects(mask) mask_slices = find_objects(mask_channel)
# centers: List[Tuple[int, int]] = [] # centers: List[Tuple[int, int]] = []
# Iterate over mask labels in parallel # Iterate over mask labels in parallel
@ -1642,7 +1695,7 @@ class CellSegmentator:
# Get local coordinates of mask pixels within the patch # Get local coordinates of mask pixels within the patch
local_rows, local_cols = np.nonzero( local_rows, local_cols = np.nonzero(
mask[row_slice, col_slice] == (label_idx + 1) mask_channel[row_slice, col_slice] == (label_idx + 1)
) )
# Shift coords by +1 for the border padding # Shift coords by +1 for the border padding
local_rows = local_rows.astype(np.int32) + 1 local_rows = local_rows.astype(np.int32) + 1
@ -1774,8 +1827,8 @@ class CellSegmentator:
Generate instance segmentation masks from probability and flow fields. Generate instance segmentation masks from probability and flow fields.
Args: Args:
probability_map: 3D array (channels, height, width) of cell probabilities. probability_map: 3D array `(C, H, W)` of cell probabilities.
flow: 3D array (2*channels, height, width) of forward flow vectors. flow: 3D array `(2*C, H, W)` of forward flow vectors.
prob_threshold: threshold to binarize probability_map. (Default 0.0) prob_threshold: threshold to binarize probability_map. (Default 0.0)
flow_threshold: threshold for filtering bad flow masks. (Default 0.4) flow_threshold: threshold for filtering bad flow masks. (Default 0.4)
num_iters: number of iterations for flow-following. (Default 200) num_iters: number of iterations for flow-following. (Default 200)
@ -1802,6 +1855,9 @@ class CellSegmentator:
# Extract binary mask for this channel # Extract binary mask for this channel
channel_mask = probability_mask[channel_index] channel_mask = probability_mask[channel_index]
if not channel_mask.sum():
continue
nonzero_coords = np.stack(np.nonzero(channel_mask)) nonzero_coords = np.stack(np.nonzero(channel_mask))
# Follow the flow vectors to generate coordinate mappings # Follow the flow vectors to generate coordinate mappings
@ -1810,12 +1866,6 @@ class CellSegmentator:
initial_coords=nonzero_coords, initial_coords=nonzero_coords,
num_iters=num_iters num_iters=num_iters
) )
# If flow following fails, leave this channel empty
if flow_coordinates is None:
labeled_instances[channel_index] = np.zeros(
probability_map.shape[1:], dtype=np.uint16
)
continue
if not torch.is_tensor(flow_coordinates): if not torch.is_tensor(flow_coordinates):
flow_coordinates = torch.from_numpy( flow_coordinates = torch.from_numpy(
@ -1851,11 +1901,6 @@ class CellSegmentator:
) )
labeled_instances[channel_index] = channel_instances_mask labeled_instances[channel_index] = channel_instances_mask
else:
# No valid instances found, leave the channel empty
labeled_instances[channel_index] = np.zeros(
probability_map.shape[1:], dtype=np.uint16
)
return labeled_instances return labeled_instances
@ -1923,7 +1968,7 @@ class CellSegmentator:
) )
# Update each coordinate and clamp to valid range # Update each coordinate and clamp to valid range
for i in range(dims): for i in range(dims):
pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0) pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0)
# Denormalize back to original pixel coordinates # Denormalize back to original pixel coordinates
pts = (pts + 1) * 0.5 * max_idx_pt pts = (pts + 1) * 0.5 * max_idx_pt
@ -2072,16 +2117,21 @@ class CellSegmentator:
raise raise
# Step 3: Find peaks via 5x5 max-pooling # Step 3: Find peaks via 5x5 max-pooling
k = 5 # k = 5
pooled = F.max_pool2d( # pooled = F.max_pool2d(
# histogram.float().unsqueeze(0).unsqueeze(1),
# kernel_size=k,
# stride=1,
# padding=k // 2
# ).squeeze()
pooled = self.__max_pool_nd(
histogram.unsqueeze(0), histogram.unsqueeze(0),
kernel_size=k, kernel_size=5
stride=1,
padding=k // 2
).squeeze() ).squeeze()
# Seeds are positions where histogram equals local max and count > threshold # Seeds are positions where histogram equals local max and count > threshold
seed_positions = torch.nonzero((histogram - pooled == 0) & (histogram > 10)) seed_positions = torch.nonzero((histogram - pooled == 0) * (histogram > 10))
if seed_positions.numel() == 0: if seed_positions.shape[0] == 0:
logger.warning("No seeds found: returning empty mask") logger.warning("No seeds found: returning empty mask")
return np.zeros(original_shape, dtype=np.uint16) return np.zeros(original_shape, dtype=np.uint16)
@ -2106,13 +2156,14 @@ class CellSegmentator:
seed_masks[:, 5, 5] = 1 seed_masks[:, 5, 5] = 1
# Iterative dilation and thresholding # Iterative dilation and thresholding
for _ in range(5): for _ in range(5):
seed_masks = F.max_pool2d( # seed_masks = F.max_pool2d(
seed_masks, # seed_masks.float().unsqueeze(0),
kernel_size=3, # kernel_size=3,
stride=1, # stride=1,
padding=1 # padding=1
) # ).squeeze(0).int()
seed_masks = seed_masks & (patches > 2) seed_masks = self.__max_pool_nd(seed_masks, kernel_size=3)
seed_masks *= (patches > 2)
# Compute final mask coordinates # Compute final mask coordinates
final_coords = [] final_coords = []
for idx in range(num_seeds): for idx in range(num_seeds):
@ -2133,7 +2184,7 @@ class CellSegmentator:
# Step 6: Map to original image and remove oversized masks # Step 6: Map to original image and remove oversized masks
mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32) mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32)
mask_final[valid_indices] = mask_values mask_final[tuple(valid_indices)] = mask_values
# Prune masks that are too large # Prune masks that are too large
labels, counts = fastremap.unique(mask_final, return_counts=True) labels, counts = fastremap.unique(mask_final, return_counts=True)
@ -2146,6 +2197,96 @@ class CellSegmentator:
return mask_final return mask_final
def __max_pool1d(
self,
input_tensor: Tensor,
kernel_size: int = 5,
axis: int = 1,
output_tensor: Optional[Tensor] = None
) -> Tensor:
"""
Memory-efficient 1D max pooling along a specified axis using in-place updates.
Requires:
- stride = 1
- padding = kernel_size // 2
- odd kernel_size >= 3
Args:
input_tensor (Tensor): Source tensor for pooling.
kernel_size (int): Size of the pooling window (must be odd and >= 3).
axis (int): Axis along which to compute 1D max pooling.
output_tensor (Optional[Tensor]): Tensor to store the result.
If None, a clone of input_tensor is used.
Returns:
Tensor: The pooled tensor, same shape as input_tensor.
"""
# Initialize or copy data into the output tensor
if output_tensor is None:
output = input_tensor.clone()
else:
output = output_tensor
output.copy_(input_tensor)
# Number of elements along the chosen axis and half-window size
dimension_size = input_tensor.shape[axis]
half_window = kernel_size // 2
# Slide window offsets from -half_window to +half_window
for offset in range(-half_window, half_window + 1):
# Compute slice indices depending on axis
if axis == 1:
target_slice = output[:, max(-offset, 0): min(dimension_size - offset, dimension_size)]
source_slice = input_tensor[:, max(offset, 0): min(dimension_size + offset, dimension_size)]
elif axis == 2:
target_slice = output[:, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
source_slice = input_tensor[:, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
elif axis == 3:
target_slice = output[:, :, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
source_slice = input_tensor[:, :, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
else:
raise ValueError(f"Unsupported axis {axis} for 1D pooling")
# In-place element-wise maximum
torch.maximum(target_slice, source_slice, out=target_slice)
return output
def __max_pool_nd(
self,
input_tensor: Tensor,
kernel_size: int = 5
) -> Tensor:
"""
Memory-efficient N-dimensional max pooling for 2D or 3D spatial data.
Applies 1D max pooling sequentially over each spatial axis.
Args:
input_tensor (Tensor): Input tensor with shape
(batch_size, dim1, dim2, ..., dimN).
kernel_size (int): Size of the pooling window (must be odd and >= 3).
Returns:
Tensor: The pooled tensor, same shape as input_tensor.
"""
# Determine number of spatial dimensions (excluding batch axis)
num_spatial_dims = input_tensor.ndim - 1
# First pass: pool along axis=1
pooled = self.__max_pool1d(input_tensor, kernel_size=kernel_size, axis=1)
# Second pass: pool along axis=2
pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=2)
# If 3D data, apply a third pass along axis=3
if num_spatial_dims == 3:
pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=3)
elif num_spatial_dims != 2:
raise ValueError("max_pool_nd only supports 2D or 3D spatial data")
return pooled
def __remove_inconsistent_flow_masks( def __remove_inconsistent_flow_masks(
self, self,
mask: np.ndarray, mask: np.ndarray,
@ -2175,11 +2316,7 @@ class CellSegmentator:
""" """
# If mask is very large and running on CUDA, check memory # If mask is very large and running on CUDA, check memory
num_pixels = mask.size num_pixels = mask.size
if ( if num_pixels > 10000 * 10000 and self._device.type == 'cuda':
num_pixels > 10000 * 10000
and self._device.type == 'cuda'
):
# Clear unused GPU cache # Clear unused GPU cache
torch.cuda.empty_cache() torch.cuda.empty_cache()
# Determine PyTorch version # Determine PyTorch version
@ -2296,15 +2433,8 @@ class CellSegmentator:
logger.error(msg) logger.error(msg)
raise ValueError(msg) raise ValueError(msg)
# Optionally remove masks smaller than minimum_size # Initial pruning of too-small masks
if minimum_size >= 0: masks = self._prune_small_masks(masks, minimum_size)
# Compute label counts (skipping background at index 0)
labels, counts = fastremap.unique(masks, return_counts=True)
# Identify labels to remove: those with count < minimum_size
small_labels = labels[counts < minimum_size]
if small_labels.size > 0:
masks = fastremap.mask(masks, small_labels)
fastremap.renumber(masks, in_place=True)
# Find bounding boxes for each mask label # Find bounding boxes for each mask label
object_slices = find_objects(masks) object_slices = find_objects(masks)
@ -2325,12 +2455,36 @@ class CellSegmentator:
output_masks[slc][filled_region] = new_label output_masks[slc][filled_region] = new_label
new_label += 1 new_label += 1
# Final pruning of small masks after filling (optional) # Final pruning after hole filling
if minimum_size >= 0: output_masks = self._prune_small_masks(output_masks, minimum_size)
labels, counts = fastremap.unique(output_masks, return_counts=True) return output_masks
small_labels = labels[counts < minimum_size]
if small_labels.size > 0:
output_masks = fastremap.mask(output_masks, small_labels) def _prune_small_masks(
fastremap.renumber(output_masks, in_place=True) self,
masks: np.ndarray,
minimum_size: int
) -> np.ndarray:
"""
Remove labeled regions in `masks` whose pixel count is below `minimum_size`.
return output_masks Args:
masks (np.ndarray): Integer mask array (any shape), 0=background.
minimum_size (int): Minimum pixel count; labels smaller are removed. If <0, skip pruning.
Returns:
np.ndarray: Mask array with small labels suppressed and labels renumbered.
"""
if minimum_size < 0:
return masks
labels, counts = fastremap.unique(masks, return_counts=True)
# Skip background label at index 0
non_bg_labels = labels[1:]
non_bg_counts = counts[1:]
# Identify labels to remove
small_labels = non_bg_labels[non_bg_counts < minimum_size]
if small_labels.size > 0:
masks = fastremap.mask(masks, small_labels)
fastremap.renumber(masks, in_place=True)
return masks

@ -11,6 +11,8 @@ from skimage import segmentation
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from typing import Dict, List, Tuple, Any, Union from typing import Dict, List, Tuple, Any, Union
from core.logger import get_logger
__all__ = [ __all__ = [
"compute_batch_segmentation_f1_metrics", "compute_batch_segmentation_average_precision_metrics", "compute_batch_segmentation_f1_metrics", "compute_batch_segmentation_average_precision_metrics",
"compute_batch_segmentation_tp_fp_fn", "compute_batch_segmentation_tp_fp_fn",
@ -18,7 +20,9 @@ __all__ = [
"compute_segmentation_tp_fp_fn", "compute_segmentation_tp_fp_fn",
"compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score" "compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score"
] ]
logger = get_logger()
def compute_f1_score( def compute_f1_score(
true_positives: int, true_positives: int,
@ -92,7 +96,7 @@ def compute_confusion_matrix(
# If no predictions were made, return zeros (with a printout for debugging). # If no predictions were made, return zeros (with a printout for debugging).
if num_predictions == 0: if num_predictions == 0:
print("No segmentation results!") logger.warning("No segmentation results!")
return 0, 0, 0 return 0, 0, 0
# Compute the IoU matrix and ignore the background (first row and column). # Compute the IoU matrix and ignore the background (first row and column).
@ -586,7 +590,7 @@ def _process_instance_matching(
# If no predictions are found, return with all ground truth as false negatives. # If no predictions are found, return with all ground truth as false negatives.
if num_prediction == 0: if num_prediction == 0:
print("No segmentation results!") logger.warning("No segmentation results!")
result = {'tp': 0, 'fp': 0, 'fn': num_ground_truth} result = {'tp': 0, 'fp': 0, 'fn': num_ground_truth}
if return_masks: if return_masks:
tp_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8) tp_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8)

@ -1,8 +1,7 @@
import os import os
from typing import Tuple from typing import Tuple
from config.config import * from config import Config, WandbConfig, DatasetConfig, ComponentConfig
from config.dataset_config import DatasetConfig
from core import ( from core import (
ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
@ -47,7 +46,8 @@ def main():
if is_training is False: if is_training is False:
config = Config( config = Config(
model=ComponentConfig(name=chosen_model, params=model_instance), model=ComponentConfig(name=chosen_model, params=model_instance),
dataset_config=dataset_config dataset_config=dataset_config,
wandb_config=WandbConfig()
) )
# Construct a base filename from the selected registry names. # Construct a base filename from the selected registry names.
@ -76,6 +76,7 @@ def main():
config = Config( config = Config(
model=ComponentConfig(name=chosen_model, params=model_instance), model=ComponentConfig(name=chosen_model, params=model_instance),
dataset_config=dataset_config, dataset_config=dataset_config,
wandb_config=WandbConfig(),
criterion=ComponentConfig(name=chosen_criterion, params=criterion_instance), criterion=ComponentConfig(name=chosen_criterion, params=criterion_instance),
optimizer=ComponentConfig(name=chosen_optimizer, params=optimizer_instance), optimizer=ComponentConfig(name=chosen_optimizer, params=optimizer_instance),
scheduler=ComponentConfig(name=chosen_scheduler, params=scheduler_instance) scheduler=ComponentConfig(name=chosen_scheduler, params=scheduler_instance)

@ -1,41 +1,78 @@
import os import os
import argparse
import wandb import wandb
from config.config import Config from config import Config
from core.data import * from core.data import *
from core.segmentator import CellSegmentator from core.segmentator import CellSegmentator
if __name__ == "__main__": def main():
config_path = 'config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json' parser = argparse.ArgumentParser(
# config_path = 'config/templates/predict/ModelV.json' description="Train or predict cell segmentator with specified config file."
)
parser.add_argument(
'-c', '--config',
type=str,
default='config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json',
help='Path to the JSON config file'
)
parser.add_argument(
'-m', '--mode',
choices=['train', 'test', 'predict'],
default='train',
help='Run mode: train, test or predict'
)
args = parser.parse_args()
mode = args.mode
config_path = args.config
config = Config.load_json(config_path) config = Config.load_json(config_path)
# config = Config.load_json(config_path)
if config.dataset_config.wandb.use_wandb:
# Initialize W&B
wandb.init(config=config.asdict(), **config.dataset_config.wandb.asdict())
if mode == 'train' and not config.dataset_config.is_training:
raise ValueError(
f"Config is not set for training (is_training=False), but mode 'train' was requested."
)
if mode in ('test', 'predict') and config.dataset_config.is_training:
raise ValueError(
f"Config is set for training (is_training=True), but mode '{mode}' was requested."
)
if config.wandb_config.use_wandb:
# Initialize W&B
wandb.init(config=config.asdict(), **config.wandb_config.asdict())
# How many batches to wait before logging training status # How many batches to wait before logging training status
wandb.config.log_interval = 10 wandb.config.log_interval = 10
segmentator = CellSegmentator(config) segmentator = CellSegmentator(config)
segmentator.create_dataloaders() segmentator.create_dataloaders(
train_transforms=get_train_transforms() if mode == "train" else None,
valid_transforms=get_valid_transforms() if mode == "train" else None,
test_transforms=get_test_transforms() if mode in ("train", "test") else None,
predict_transforms=get_predict_transforms() if mode == "predict" else None
)
# Watch parameters & gradients of model # Watch parameters & gradients of model
if config.dataset_config.wandb.use_wandb: if config.wandb_config.use_wandb:
wandb.watch(segmentator._model, log="all", log_graph=True) wandb.watch(segmentator._model, log="all", log_graph=True)
# Run training (or prediction, if implemented)
segmentator.run() segmentator.run()
weights_dir = "weights" if not config.dataset_config.wandb.use_wandb else wandb.run.dir # type: ignore if config.dataset_config.is_training:
saving_path = os.path.join( # Prepare saving path
weights_dir, os.path.basename(config.dataset_config.common.predictions_dir) + '.pth' weights_dir = (
) wandb.run.dir if config.wandb_config.use_wandb else "weights" # type: ignore
segmentator.save_checkpoint(saving_path) )
saving_path = os.path.join(
if config.dataset_config.wandb.use_wandb: weights_dir,
wandb.save(saving_path) os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
)
segmentator.save_checkpoint(saving_path)
if config.wandb_config.use_wandb:
wandb.save(saving_path)
if __name__ == "__main__":
main()

Loading…
Cancel
Save