refactor: migrate all type annotations from Python 3.9 to 3.10 syntax

Replaced typing module generics such as List, Dict, and Tuple with their built-in equivalents (list, dict, tuple), per PEP 585.
Updated annotations to the PEP 604 union syntax (X | Y) in place of Union[X, Y], and X | None in place of Optional[X].
Branch: master | Author: laynholt (1 month ago) | Commit: b0e7b21a21, parent 40c21d5456
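For reference, a minimal sketch of the annotation style this commit adopts (the function below is hypothetical and not taken from this repository; both spellings behave identically at runtime on Python 3.10+):

# Before: Python 3.9 style, generics and unions imported from typing
from typing import Dict, List, Optional, Union

def summarize_py39(values: List[Union[int, float]], labels: Optional[Dict[int, str]] = None) -> Dict[str, float]:
    # Hypothetical helper used only to illustrate the annotation change
    return {"count": float(len(values)), "total": float(sum(values))}

# After: Python 3.10 style, built-in generics (PEP 585) and the | union operator (PEP 604)
def summarize_py310(values: list[int | float], labels: dict[int, str] | None = None) -> dict[str, float]:
    return {"count": float(len(values)), "total": float(sum(values))}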

@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, Optional
+from typing import Any
 from pydantic import BaseModel
 from .wandb_config import WandbConfig
@@ -13,7 +13,7 @@ class ComponentConfig(BaseModel):
     name: str
     params: BaseModel
-    def dump(self) -> Dict[str, Any]:
+    def dump(self) -> dict[str, Any]:
         """
         Recursively serializes the component into a dictionary.
@@ -24,22 +24,18 @@ class ComponentConfig(BaseModel):
             params_dump = self.params.model_dump()
         else:
             params_dump = self.params
-        return {
-            "name": self.name,
-            "params": params_dump
-        }
+        return {"name": self.name, "params": params_dump}
 class Config(BaseModel):
     model: ComponentConfig
     dataset_config: DatasetConfig
     wandb_config: WandbConfig
-    criterion: Optional[ComponentConfig] = None
-    optimizer: Optional[ComponentConfig] = None
-    scheduler: Optional[ComponentConfig] = None
+    criterion: ComponentConfig | None = None
+    optimizer: ComponentConfig | None = None
+    scheduler: ComponentConfig | None = None
-    def asdict(self) -> Dict[str, Any]:
+    def asdict(self) -> dict[str, Any]:
         """
         Produce a JSON-serializable dict of this config, including nested
         ComponentConfig and DatasetConfig entries. Useful for saving to file
@@ -49,7 +45,7 @@ class Config(BaseModel):
             A dict with keys 'model', 'dataset_config', and (if set)
             'criterion', 'optimizer', 'scheduler'.
         """
-        data: Dict[str, Any] = {
+        data: dict[str, Any] = {
            "model": self.model.dump(),
            "dataset_config": self.dataset_config.model_dump(),
        }
@@ -62,7 +58,6 @@ class Config(BaseModel):
             data["wandb"] = self.wandb_config.model_dump()
         return data
     def save_json(self, file_path: str, indent: int = 4) -> None:
         """
         Save this config to a JSON file.
@@ -75,7 +70,6 @@ class Config(BaseModel):
         with open(file_path, "w", encoding="utf-8") as f:
             f.write(json.dumps(config_dict, indent=indent))
     @classmethod
     def load_json(cls, file_path: str) -> "Config":
         """
@@ -96,10 +90,12 @@ class Config(BaseModel):
         wandb_config = WandbConfig(**data.get("wandb", {}))
         # Helper function to parse registry fields.
-        def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]:
+        def parse_field(
+            component_data: dict[str, Any], registry_getter
+        ) -> ComponentConfig | None:
             name = component_data.get("name")
             params_data = component_data.get("params", {})
             if name is not None:
                 expected = registry_getter(name)
                 params = expected(**params_data)
@@ -107,16 +103,31 @@ class Config(BaseModel):
             return None
         from core import (
-            ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
+            ModelRegistry,
+            CriterionRegistry,
+            OptimizerRegistry,
+            SchedulerRegistry,
         )
-        parsed_model = parse_field(data.get("model", {}), lambda key: ModelRegistry.get_model_params(key))
-        parsed_criterion = parse_field(data.get("criterion", {}), lambda key: CriterionRegistry.get_criterion_params(key))
-        parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key))
-        parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key))
+        parsed_model = parse_field(
+            data.get("model", {}),
+            lambda key: ModelRegistry.get_model_params(key),
+        )
+        parsed_criterion = parse_field(
+            data.get("criterion", {}),
+            lambda key: CriterionRegistry.get_criterion_params(key),
+        )
+        parsed_optimizer = parse_field(
+            data.get("optimizer", {}),
+            lambda key: OptimizerRegistry.get_optimizer_params(key),
+        )
+        parsed_scheduler = parse_field(
+            data.get("scheduler", {}),
+            lambda key: SchedulerRegistry.get_scheduler_params(key),
+        )
         if parsed_model is None:
-            raise ValueError('Failed to load model information')
+            raise ValueError("Failed to load model information")
         return cls(
             model=parsed_model,
@@ -124,5 +135,5 @@ class Config(BaseModel):
             criterion=parsed_criterion,
             optimizer=parsed_optimizer,
             scheduler=parsed_scheduler,
-            wandb_config=wandb_config
+            wandb_config=wandb_config,
         )

@@ -1,5 +1,5 @@
 from pydantic import BaseModel, model_validator, field_validator
-from typing import Any, Dict, Optional, Union
+from typing import Any
 import os
@@ -7,7 +7,7 @@ class DatasetCommonConfig(BaseModel):
     """
     Common configuration fields shared by both training and testing.
     """
-    seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
+    seed: int | None = 0 # Seed for splitting if data is not pre-split (and all random operations)
     device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
     use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
     roi_size: int = 512 # The size of the square window for cropping
@@ -65,9 +65,9 @@ class DatasetTrainingConfig(BaseModel):
     pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo()
     split: TrainingSplitInfo = TrainingSplitInfo()
-    train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
-    valid_size: Union[int, float] = 0.1 # Validation data size (int for static, float in (0,1] for dynamic)
-    test_size: Union[int, float] = 0.2 # Testing data size (int for static, float in (0,1] for dynamic)
+    train_size: int | float = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
+    valid_size: int | float = 0.1 # Validation data size (int for static, float in (0,1] for dynamic)
+    test_size: int | float = 0.2 # Testing data size (int for static, float in (0,1] for dynamic)
     train_offset: int = 0 # Offset for training data
     valid_offset: int = 0 # Offset for validation data
     test_offset: int = 0 # Offset for testing data
@@ -78,7 +78,7 @@ class DatasetTrainingConfig(BaseModel):
     @field_validator("train_size", "valid_size", "test_size", mode="before")
-    def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
+    def validate_sizes(cls, v: int | float) -> int | float:
         """
         Validates size values:
         - If provided as a float, must be in the range (0, 1].
@@ -145,12 +145,12 @@ class DatasetTestingConfig(BaseModel):
     Configuration fields used only in testing mode.
     """
     test_dir: str = "." # Test data directory; must be non-empty
-    test_size: Union[int, float] = 1.0 # Testing data size (int for static, float in (0,1] for dynamic)
+    test_size: int | float = 1.0 # Testing data size (int for static, float in (0,1] for dynamic)
     test_offset: int = 0 # Offset for testing data
     shuffle: bool = True # Shuffle data
     @field_validator("test_size", mode="before")
-    def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]:
+    def validate_test_size(cls, v: int | float) -> int | float:
         """
         Validates the test_size value.
         """
@@ -224,7 +224,7 @@ class DatasetConfig(BaseModel):
             raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}")
         return self
-    def model_dump(self, **kwargs) -> Dict[str, Any]:
+    def model_dump(self, **kwargs) -> dict[str, Any]:
         """
         Dumps only the relevant configuration depending on the is_training flag.
         Only the nested configuration (training or testing) along with common fields is returned.

@@ -1,19 +1,19 @@
 from pydantic import BaseModel, model_validator
-from typing import Any, Dict, Optional
+from typing import Any
 class WandbConfig(BaseModel):
     """
     Configuration for Weights & Biases logging.
     """
     use_wandb: bool = False # Whether to enable WandB logging
-    project: Optional[str] = None # WandB project name
-    group: Optional[str] = None # WandB group name
-    entity: Optional[str] = None # WandB entity (user or team)
-    name: Optional[str] = None # Name of the run
-    tags: Optional[list[str]] = None # List of tags for the run
-    notes: Optional[str] = None # Notes or description for the run
+    project: str | None = None # WandB project name
+    group: str | None = None # WandB group name
+    entity: str | None = None # WandB entity (user or team)
+    name: str | None = None # Name of the run
+    tags: list[str] | None = None # List of tags for the run
+    notes: str | None = None # Notes or description for the run
     save_code: bool = True # Whether to save the code to WandB
     @model_validator(mode="after")
     def validate_wandb(self) -> "WandbConfig":
@@ -22,7 +22,7 @@ class WandbConfig(BaseModel):
             raise ValueError("When use_wandb=True, 'project' must be provided")
         return self
-    def asdict(self) -> Dict[str, Any]:
+    def asdict(self) -> dict[str, Any]:
         """
         Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
         """

@@ -1,6 +1,6 @@
 from .cell_aware import IntensityDiversification
-from .load_image import CustomLoadImage, CustomLoadImaged
-from .normalize_image import CustomNormalizeImage, CustomNormalizeImaged
+from .load_image import CustomLoadImaged
+from .normalize_image import CustomNormalizeImaged
 from monai.transforms import * # type: ignore

@@ -1,7 +1,7 @@
 import copy
 import torch
 import numpy as np
-from typing import Dict, Sequence, Tuple, Union
+from typing import Sequence
 from skimage.segmentation import find_boundaries
 from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore
@@ -26,14 +26,14 @@ class BoundaryExclusion(MapTransform):
     def __init__(self, keys: Sequence[str] = ("mask",), allow_missing_keys: bool = False) -> None:
         """
         Args:
-            keys (Sequence[str]): Keys in the input dictionary corresponding to the label image.
+            keys (Sequence(str)): Keys in the input dictionary corresponding to the label image.
                 Default is ("mask",).
             allow_missing_keys (bool): If True, missing keys in the input will be ignored.
                 Default is False.
         """
         super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
-    def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+    def __call__(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
         """
         Apply the boundary exclusion transform to the label image.
@@ -46,10 +46,10 @@ class BoundaryExclusion(MapTransform):
         6. Assigning the transformed label back into the input dictionary.
         Args:
-            data (Dict[str, np.ndarray]): Dictionary containing at least the "mask" key with a label image.
+            data (Dict(str, np.ndarray)): Dictionary containing at least the "mask" key with a label image.
         Returns:
-            Dict[str, np.ndarray]: The input dictionary with the "mask" key updated after boundary exclusion.
+            Dict(str, np.ndarray): The input dictionary with the "mask" key updated after boundary exclusion.
         """
         # Retrieve the original label image.
         label_original: np.ndarray = data["mask"]
@@ -100,17 +100,17 @@ class IntensityDiversification(MapTransform):
         self,
         keys: Sequence[str] = ("image",),
         change_cell_ratio: float = 0.4,
-        scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7),
+        scale_factors: tuple[float, float] | float = (0.0, 0.7),
         allow_missing_keys: bool = False,
     ) -> None:
         """
         Args:
-            keys (Sequence[str]): Keys in the input dictionary corresponding to the image.
+            keys (Sequence(str)): Keys in the input dictionary corresponding to the image.
                 Default is ("image",).
             change_cell_ratio (float): Ratio of cells to apply the intensity scaling.
                 For example, 0.4 means 40% of the cells will be transformed.
                 Default is 0.4.
-            scale_factors (Sequence[float]): Factors used for random intensity scaling.
+            scale_factors (tuple(float, float) | float): Factors used for random intensity scaling.
                 Default is (0.0, 0.7).
             allow_missing_keys (bool): If True, missing keys in the input will be ignored.
                 Default is False.
@@ -120,7 +120,7 @@ class IntensityDiversification(MapTransform):
         # Compose a random intensity scaling transform with 100% probability.
         self.randscale_intensity = Compose([RandScaleIntensity(prob=1.0, factors=scale_factors)])
-    def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+    def __call__(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
         """
         Apply a cell-wise intensity diversification transform to an input image.
@@ -141,12 +141,12 @@ class IntensityDiversification(MapTransform):
         9. Combine the unchanged and modified parts to update the image for that channel.
         Args:
-            data (Dict[str, np.ndarray]): A dictionary containing:
+            data (dict(str, np.ndarray)): A dictionary containing:
                 - "image": The original image array.
                 - "mask": The corresponding cell label image array.
         Returns:
-            Dict[str, np.ndarray]: The updated dictionary with the "image" key modified after applying
+            dict(str, np.ndarray): The updated dictionary with the "image" key modified after applying
                 the intensity transformation.
         Raises:

@@ -1,7 +1,7 @@
 import numpy as np
 import tifffile as tif
 import skimage.io as io
-from typing import Final, List, Optional, Sequence, Type, Union
+from typing import Final, Sequence, Type
 from monai.utils.enums import PostFix
 from monai.utils.module import optional_import
@@ -45,7 +45,7 @@ class CustomLoadImage(LoadImage):
     """
     def __init__(
         self,
-        reader: Optional[Union[ImageReader, Type[ImageReader], str]] = None,
+        reader: ImageReader | Type[ImageReader] | str | None = None,
         image_only: bool = False,
         dtype: DtypeLike = np.float32,
         ensure_channel_first: bool = False,
@@ -75,9 +75,9 @@ class CustomLoadImaged(LoadImaged):
     def __init__(
         self,
         keys: KeysCollection,
-        reader: Optional[Union[Type[ImageReader], str]] = None,
+        reader: Type[ImageReader] | str | None = None,
         dtype: DtypeLike = np.float32,
-        meta_keys: Optional[KeysCollection] = None,
+        meta_keys: KeysCollection | None = None,
         meta_key_postfix: str = DEFAULT_POST_FIX,
         overwriting: bool = False,
         image_only: bool = False,
@@ -141,13 +141,13 @@ class UniversalImageReader(NumpyReader):
     (e.g., repeating or cropping channels).
     """
     def __init__(
-        self, channel_dim: Optional[int] = None, **kwargs,
-    ):
+        self, channel_dim: int | None = None, **kwargs,
+    ) -> None:
         super().__init__(channel_dim=channel_dim, **kwargs)
         self.kwargs = kwargs
         self.channel_dim = channel_dim
-    def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
+    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
         """
         Check if the file format is supported for reading.
@@ -155,7 +155,7 @@ class UniversalImageReader(NumpyReader):
         """
         return has_itk or is_supported_format(filename, SUPPORTED_IMAGE_FORMATS)
-    def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs):
+    def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         """
         Read image(s) from the given path.
@@ -166,7 +166,7 @@ class UniversalImageReader(NumpyReader):
         Returns:
             A single image or a list of images depending on the number of paths provided.
         """
-        images: List[np.ndarray] = [] # List to store the loaded images
+        images: list[np.ndarray] = [] # List to store the loaded images
         # Convert data to a tuple to support multiple files
         filenames: Sequence[PathLike] = ensure_tuple(data)

@@ -2,7 +2,7 @@ import numpy as np
 from skimage import exposure
 from monai.config.type_definitions import KeysCollection
 from monai.transforms.transform import Transform, MapTransform
-from typing import Dict, Hashable, Mapping, Sequence
+from typing import Hashable, Mapping, Sequence
 __all__ = [
     "CustomNormalizeImage",
@@ -23,7 +23,7 @@ class CustomNormalizeImage(Transform):
     def __init__(self, percentiles: Sequence[float] = (0, 99), channel_wise: bool = False) -> None:
         """
         Args:
-            percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
+            percentiles (Sequence(float)): Lower and upper percentiles used for intensity scaling.
                 Default is (0, 99).
             channel_wise (bool): Whether to apply normalization on each channel individually.
                 Default is False.
@@ -106,7 +106,7 @@ class CustomNormalizeImaged(MapTransform):
         """
         Args:
             keys (KeysCollection): Keys identifying the image entries in the dictionary.
-            percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
+            percentiles (Sequence(float)): Lower and upper percentiles used for intensity scaling.
                 Default is (1, 99).
             channel_wise (bool): Whether to apply normalization on each channel individually.
                 Default is False.
@@ -117,7 +117,7 @@ class CustomNormalizeImaged(MapTransform):
         # Create an instance of the normalization transform with specified parameters.
         self.normalizer: CustomNormalizeImage = CustomNormalizeImage(percentiles, channel_wise)
-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
         """
         Apply the normalization transform to each image in the input dictionary.
@@ -125,10 +125,10 @@ class CustomNormalizeImaged(MapTransform):
             data (Mapping[Hashable, np.ndarray]): A dictionary mapping keys to numpy arrays representing images.
         Returns:
-            Dict[Hashable, np.ndarray]: A dictionary with the same keys where the images have been normalized.
+            dict(Hashable, np.ndarray): A dictionary with the same keys where the images have been normalized.
         """
         # Copy the input dictionary to avoid modifying the original data.
-        d: Dict[Hashable, np.ndarray] = dict(data)
+        d: dict[Hashable, np.ndarray] = dict(data)
         # Iterate over each key specified in the transform and normalize the corresponding image.
         for key in self.keys:
             d[key] = self.normalizer(d[key])

@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Hashable, List, Sequence, Optional, Tuple
+from typing import Sequence
 from monai.utils.misc import fall_back_tuple
 from monai.data.meta_tensor import MetaTensor
@@ -14,7 +14,7 @@ logger = get_logger(__name__)
 def _compute_multilabel_bbox(
     mask: np.ndarray
-) -> Optional[Tuple[List[int], List[int], List[int], List[int]]]:
+) -> tuple[list[int], list[int], list[int], list[int]] | None:
     """
     Compute per-channel bounding-box constraints and return lists of limits for each axis.
@@ -33,10 +33,10 @@ def _compute_multilabel_bbox(
     if channels.size == 0:
         return None
-    top_mins: List[int] = []
-    top_maxs: List[int] = []
-    left_mins: List[int] = []
-    left_maxs: List[int] = []
+    top_mins: list[int] = []
+    top_maxs: list[int] = []
+    left_mins: list[int] = []
+    left_maxs: list[int] = []
     C = mask.shape[0]
     for ch in range(C):
         rs, cs = np.nonzero(mask[ch])
@@ -74,7 +74,7 @@ class SpatialCropAllClasses(Randomizable, Crop):
         super().__init__(lazy=lazy)
         self.roi_size = tuple(roi_size)
         self.num_candidates = num_candidates
-        self._slices: Optional[Tuple[slice, ...]] = None
+        self._slices: tuple[slice, ...] | None = None
     def randomize(self, img_size: Sequence[int]) -> None: # type: ignore
         """
@@ -139,7 +139,7 @@ class SpatialCropAllClasses(Randomizable, Crop):
             slice(left, left + crop_w),
         )
-    def __call__(self, img: torch.Tensor, lazy: Optional[bool] = None) -> torch.Tensor: # type: ignore
+    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore
         """
         On first call (mask), computes crop. On subsequent (image), just applies.
         Raises if mask not provided first.

@@ -1,4 +1,4 @@
-from typing import Dict, Final, Tuple, Type, List, Any, Union
+from typing import Final, Type, Any
 from pydantic import BaseModel
 from .base import BaseLoss
@@ -16,7 +16,7 @@ __all__ = [
 class CriterionRegistry:
     """Registry of loss functions and their parameter classes with case-insensitive lookup."""
-    __CRITERIONS: Final[Dict[str, Dict[str, Any]]] = {
+    __CRITERIONS: Final[dict[str, dict[str, Any]]] = {
         "CrossEntropyLoss": {
             "class": CrossEntropyLoss,
             "params": CrossEntropyLossParams,
@@ -36,7 +36,7 @@ class CriterionRegistry:
     }
     @classmethod
-    def __get_entry(cls, name: str) -> Dict[str, Any]:
+    def __get_entry(cls, name: str) -> dict[str, Any]:
         """
         Private method to retrieve the criterion entry from the registry using case-insensitive lookup.
@@ -44,7 +44,7 @@ class CriterionRegistry:
            name (str): The name of the loss function.
         Returns:
-            Dict[str, Any]: A dictionary containing the keys 'class' and 'params'.
+            dict(str, Any): A dictionary containing the keys 'class' and 'params'.
         Raises:
             ValueError: If the loss function is not found.
@@ -67,7 +67,7 @@ class CriterionRegistry:
             name (str): Name of the loss function.
         Returns:
-            Type[BaseLoss]: The loss function class.
+            Type(BaseLoss): The loss function class.
         """
         entry = cls.__get_entry(name)
         return entry["class"]
@@ -81,17 +81,17 @@ class CriterionRegistry:
             name (str): Name of the loss function.
         Returns:
-            Type[BaseModel]: The loss function parameter class.
+            Type(BaseModel): The loss function parameter class.
         """
         entry = cls.__get_entry(name)
         return entry["params"]
     @classmethod
-    def get_available_criterions(cls) -> Tuple[str, ...]:
+    def get_available_criterions(cls) -> tuple[str, ...]:
         """
         Returns a tuple of available loss function names in their original case.
         Returns:
-            Tuple[str]: Tuple of available loss function names.
+            tuple(str): Tuple of available loss function names.
         """
         return tuple(cls.__CRITERIONS.keys())

@@ -1,15 +1,12 @@
 import abc
 import torch
-import torch.nn as nn
 from pydantic import BaseModel
-from typing import Dict, Any, Optional
-from monai.metrics.cumulative_average import CumulativeAverage
-class BaseLoss(nn.Module, abc.ABC):
+class BaseLoss(torch.nn.Module, abc.ABC):
     """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
-    def __init__(self, params: Optional[BaseModel] = None):
+    def __init__(self, params: BaseModel | None = None) -> None:
         super().__init__()
@@ -28,16 +25,16 @@ class BaseLoss(nn.Module, abc.ABC):
     @abc.abstractmethod
-    def get_loss_metrics(self) -> Dict[str, float]:
+    def get_loss_metrics(self) -> dict[str, float]:
         """
         Retrieves the tracked loss metrics.
         Returns:
-            Dict[str, float]: A dictionary containing the loss name and average loss value.
+            dict(str, float): A dictionary containing the loss name and average loss value.
         """
     @abc.abstractmethod
-    def reset_metrics(self):
+    def reset_metrics(self) -> None:
         """Resets the stored loss metrics."""

@@ -1,6 +1,8 @@
-from .base import *
-from typing import List, Literal, Union
+from .base import BaseLoss
+import torch
+from typing import Any, Literal
 from pydantic import BaseModel, ConfigDict
+from monai.metrics.cumulative_average import CumulativeAverage
 class BCELossParams(BaseModel):
@@ -11,11 +13,11 @@ class BCELossParams(BaseModel):
     with_logits: bool = False
-    weight: Optional[List[Union[int, float]]] = None # Sample weights
+    weight: list[int | float] | None = None # Sample weights
     reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method
-    pos_weight: Optional[List[Union[int, float]]] = None # Used only for BCEWithLogitsLoss
+    pos_weight: list[int | float] | None = None # Used only for BCEWithLogitsLoss
-    def asdict(self) -> Dict[str, Any]:
+    def asdict(self) -> dict[str, Any]:
         """
         Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`.
@@ -23,7 +25,7 @@ class BCELossParams(BaseModel):
         - Ensures only the valid parameters are passed based on the loss function.
         Returns:
-            Dict[str, Any]: Filtered dictionary of parameters.
+            dict(str, Any): Filtered dictionary of parameters.
         """
         loss_kwargs = self.model_dump()
         if not self.with_logits:
@@ -47,19 +49,23 @@ class BCELoss(BaseLoss):
     Custom loss function wrapper for `nn.BCELoss and nn.BCEWithLogitsLoss` with tracking of loss metrics.
     """
-    def __init__(self, params: Optional[BCELossParams] = None):
+    def __init__(self, params: BCELossParams | None = None) -> None:
         """
         Initializes the loss function with optional BCELoss parameters.
         Args:
-            params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None).
+            params (BCELossParams | None): Parameters for nn.BCELoss (default: None).
         """
         super().__init__(params=params)
         with_logits = params.with_logits if params is not None else False
         _bce_params = params.asdict() if params is not None else {}
         # Initialize loss functions with user-provided parameters or PyTorch defaults
-        self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params)
+        self.bce_loss = (
+            torch.nn.BCEWithLogitsLoss(**_bce_params)
+            if with_logits
+            else torch.nn.BCELoss(**_bce_params)
+        )
         # Using CumulativeAverage from MONAI to track loss metrics
         self.loss_bce_metric = CumulativeAverage()
@@ -90,18 +96,18 @@ class BCELoss(BaseLoss):
         return loss
-    def get_loss_metrics(self) -> Dict[str, float]:
+    def get_loss_metrics(self) -> dict[str, float]:
         """
         Retrieves the tracked loss metrics.
         Returns:
-            Dict[str, float]: A dictionary containing the average BCE loss.
+            dict(str, float): A dictionary containing the average BCE loss.
         """
         return {
             "loss": round(self.loss_bce_metric.aggregate().item(), 4),
         }
-    def reset_metrics(self):
+    def reset_metrics(self) -> None:
         """Resets the stored loss metrics."""
         self.loss_bce_metric.reset()

@@ -1,6 +1,8 @@
-from .base import *
-from typing import List, Literal, Union
+from .base import BaseLoss
+import torch
+from typing import Any, Literal
 from pydantic import BaseModel, ConfigDict
+from monai.metrics.cumulative_average import CumulativeAverage
 class CrossEntropyLossParams(BaseModel):
@@ -9,17 +11,17 @@ class CrossEntropyLossParams(BaseModel):
     """
     model_config = ConfigDict(frozen=True)
-    weight: Optional[List[Union[int, float]]] = None
+    weight: list[int | float] | None = None
     ignore_index: int = -100
     reduction: Literal["none", "mean", "sum"] = "mean"
     label_smoothing: float = 0.0
-    def asdict(self):
+    def asdict(self) -> dict[str, Any]:
         """
         Returns a dictionary of valid parameters for `nn.CrossEntropyLoss`.
         Returns:
-            Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss.
+            dict(str, Any): Dictionary of parameters for nn.CrossEntropyLoss.
         """
         loss_kwargs = self.model_dump()
@@ -36,18 +38,18 @@ class CrossEntropyLoss(BaseLoss):
     Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics.
     """
-    def __init__(self, params: Optional[CrossEntropyLossParams] = None):
+    def __init__(self, params: CrossEntropyLossParams | None = None) -> None:
         """
         Initializes the loss function with optional CrossEntropyLoss parameters.
         Args:
-            params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None).
+            params (CrossEntropyLossParams | None): Parameters for nn.CrossEntropyLoss (default: None).
         """
         super().__init__(params=params)
         _ce_params = params.asdict() if params is not None else {}
         # Initialize loss functions with user-provided parameters or PyTorch defaults
-        self.ce_loss = nn.CrossEntropyLoss(**_ce_params)
+        self.ce_loss = torch.nn.CrossEntropyLoss(**_ce_params)
         # Using CumulativeAverage from MONAI to track loss metrics
         self.loss_ce_metric = CumulativeAverage()
@@ -78,18 +80,18 @@ class CrossEntropyLoss(BaseLoss):
         return loss
-    def get_loss_metrics(self) -> Dict[str, float]:
+    def get_loss_metrics(self) -> dict[str, float]:
         """
         Retrieves the tracked loss metrics.
         Returns:
-            Dict[str, float]: A dictionary containing the average CrossEntropy loss.
+            dict(str, float): A dictionary containing the average CrossEntropy loss.
         """
         return {
             "loss": round(self.loss_ce_metric.aggregate().item(), 4),
         }
-    def reset_metrics(self):
+    def reset_metrics(self) -> None:
         """Resets the stored loss metrics."""
         self.loss_ce_metric.reset()

@@ -1,6 +1,8 @@
-from .base import *
-from typing import Literal
+from .base import BaseLoss
+import torch
+from typing import Any, Literal
 from pydantic import BaseModel, ConfigDict
+from monai.metrics.cumulative_average import CumulativeAverage
 class MSELossParams(BaseModel):
@@ -11,12 +13,12 @@ class MSELossParams(BaseModel):
     reduction: Literal["none", "mean", "sum"] = "mean"
-    def asdict(self) -> Dict[str, Any]:
+    def asdict(self) -> dict[str, Any]:
         """
         Returns a dictionary of valid parameters for `nn.MSELoss`.
         Returns:
-            Dict[str, Any]: Dictionary of parameters for `nn.MSELoss`.
+            dict(str, Any): Dictionary of parameters for `nn.MSELoss`.
         """
         loss_kwargs = self.model_dump()
         return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
@@ -27,18 +29,18 @@ class MSELoss(BaseLoss):
     Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics.
     """
-    def __init__(self, params: Optional[MSELossParams] = None):
+    def __init__(self, params: MSELossParams | None = None):
         """
         Initializes the loss function with optional MSELoss parameters.
         Args:
-            params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None).
+            params (MSELossParams | None): Parameters for `nn.MSELoss` (default: None).
         """
         super().__init__(params=params)
         _mse_params = params.asdict() if params is not None else {}
         # Initialize MSE loss with user-provided parameters or PyTorch defaults
-        self.mse_loss = nn.MSELoss(**_mse_params)
+        self.mse_loss = torch.nn.MSELoss(**_mse_params)
         # Using CumulativeAverage from MONAI to track loss metrics
         self.loss_mse_metric = CumulativeAverage()
@@ -67,12 +69,12 @@ class MSELoss(BaseLoss):
         return loss
-    def get_loss_metrics(self) -> Dict[str, float]:
+    def get_loss_metrics(self) -> dict[str, float]:
         """
         Retrieves the tracked loss metrics.
         Returns:
-            Dict[str, float]: A dictionary containing the average MSE loss.
+            dict(str, float): A dictionary containing the average MSE loss.
         """
         return {
             "loss": round(self.loss_mse_metric.aggregate().item(), 4),

@@ -1,8 +1,11 @@
-from .base import *
+from .base import BaseLoss
 from .bce import BCELossParams
 from .mse import MSELossParams
+import torch
+from typing import Any
 from pydantic import BaseModel, ConfigDict
+from monai.metrics.cumulative_average import CumulativeAverage
 class BCE_MSE_LossParams(BaseModel):
@@ -15,12 +18,12 @@ class BCE_MSE_LossParams(BaseModel):
     bce_params: BCELossParams = BCELossParams()
     mse_params: MSELossParams = MSELossParams()
-    def asdict(self) -> Dict[str, Any]:
+    def asdict(self) -> dict[str, Any]:
         """
         Returns a dictionary of valid parameters for `nn.BCELoss` and `nn.MSELoss`.
         Returns:
-            Dict[str, Any]: Dictionary of parameters.
+            dict(str, Any): Dictionary of parameters.
         """
         return {
@@ -35,7 +38,7 @@ class BCE_MSE_Loss(BaseLoss):
     Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction.
     """
-    def __init__(self, params: Optional[BCE_MSE_LossParams]):
+    def __init__(self, params: BCE_MSE_LossParams | None = None):
         """
         Initializes the loss function with optional BCE and MSE parameters.
         """
@@ -50,14 +53,16 @@ class BCE_MSE_Loss(BaseLoss):
         # Choose BCE loss function
         self.bce_loss = (
-            nn.BCEWithLogitsLoss(**_bce_params) if _params.bce_params.with_logits else nn.BCELoss(**_bce_params)
+            torch.nn.BCEWithLogitsLoss(**_bce_params)
+            if _params.bce_params.with_logits
+            else torch.nn.BCELoss(**_bce_params)
         )
         # Process MSE parameters
         _mse_params = _params.mse_params.asdict()
         # Initialize MSE loss
-        self.mse_loss = nn.MSELoss(**_mse_params)
+        self.mse_loss = torch.nn.MSELoss(**_mse_params)
         # Using CumulativeAverage from MONAI to track loss metrics
         self.loss_bce_metric = CumulativeAverage()
@@ -101,12 +106,12 @@ class BCE_MSE_Loss(BaseLoss):
         return total_loss
-    def get_loss_metrics(self) -> Dict[str, float]:
+    def get_loss_metrics(self) -> dict[str, float]:
         """
         Retrieves the tracked loss metrics.
         Returns:
-            Dict[str, float]: A dictionary containing the average BCE and MSE loss.
+            dict(str, float): A dictionary containing the average BCE and MSE loss.
         """
         return {
             "bce_loss": round(self.loss_bce_metric.aggregate().item(), 4),

@@ -1,5 +1,5 @@
-import torch.nn as nn
-from typing import Dict, Final, Tuple, Type, Any, List, Union
+from torch import nn
+from typing import Final, Type, Any
 from pydantic import BaseModel
 from .model_v import ModelV, ModelVParams
@@ -16,7 +16,7 @@ class ModelRegistry:
     """Registry for models and their parameter classes with case-insensitive lookup."""
     # Single dictionary storing both model classes and parameter classes.
-    __MODELS: Final[Dict[str, Dict[str, Type[Any]]]] = {
+    __MODELS: Final[dict[str, dict[str, Type[Any]]]] = {
         "ModelV": {
            "class": ModelV,
            "params": ModelVParams,
@@ -24,7 +24,7 @@ class ModelRegistry:
     }
     @classmethod
-    def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
+    def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
         """
         Private method to retrieve the model entry from the registry using case-insensitive lookup.
@@ -32,7 +32,7 @@ class ModelRegistry:
             name (str): The name of the model.
         Returns:
-            Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
+            dict(str, Type[Any]): A dictionary containing the keys 'class' and 'params'.
         Raises:
             ValueError: If the model is not found.
@@ -55,7 +55,7 @@ class ModelRegistry:
             name (str): Name of the model.
         Returns:
-            Type[nn.Module]: The model class.
+            Type(torch.nn.Module): The model class.
         """
         entry = cls.__get_entry(name)
         return entry["class"]
@@ -69,17 +69,17 @@ class ModelRegistry:
             name (str): Name of the model.
         Returns:
-            Type[BaseModel]: The model parameter class.
+            Type(BaseModel): The model parameter class.
         """
         entry = cls.__get_entry(name)
         return entry["params"]
     @classmethod
-    def get_available_models(cls) -> Tuple[str, ...]:
+    def get_available_models(cls) -> tuple[str, ...]:
         """
         Returns a tuple of available model names in their original case.
         Returns:
-            Tuple[str]: Tuple of available model names.
+            Tuple(str): Tuple of available model names.
         """
         return tuple(cls.__MODELS.keys())

@@ -1,8 +1,7 @@
-from typing import List, Optional
 import torch
-import torch.nn as nn
+from torch import nn
+from typing import Any
 from segmentation_models_pytorch import MAnet
 from segmentation_models_pytorch.base.modules import Activation
@@ -15,18 +14,18 @@ class ModelVParams(BaseModel):
     model_config = ConfigDict(frozen=True)
     encoder_name: str = "mit_b5" # Default encoder
-    encoder_weights: Optional[str] = "imagenet" # Pre-trained weights
-    decoder_channels: List[int] = [1024, 512, 256, 128, 64] # Decoder configuration
+    encoder_weights: str | None = "imagenet" # Pre-trained weights
+    decoder_channels: list[int] = [1024, 512, 256, 128, 64] # Decoder configuration
     decoder_pab_channels: int = 256 # Decoder Pyramid Attention Block channels
     in_channels: int = 3 # Number of input channels
     out_classes: int = 1 # Number of output classes
-    def asdict(self):
+    def asdict(self) -> dict[str, Any]:
         """
         Returns a dictionary of valid parameters for `nn.ModelV`.
         Returns:
-            Dict[str, Any]: Dictionary of parameters for nn.ModelV.
+            dict(str, Any): Dictionary of parameters for nn.ModelV.
         """
         loss_kwargs = self.model_dump()
         return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
@@ -84,11 +83,11 @@ class DeepSegmentationHead(nn.Sequential):
         in_channels: int,
         out_channels: int,
         kernel_size: int = 3,
-        activation: Optional[str] = None,
+        activation: str | None = None,
         upsampling: int = 1,
     ) -> None:
         # Define a sequence of layers for the segmentation head
-        layers: List[nn.Module] = [
+        layers: list[nn.Module] = [
             nn.Conv2d(
                 in_channels,
                 in_channels // 2,

@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import Dict, Final, Tuple, Type, List, Any, Union
+from typing import Final, Type, Any
 from .base import BaseOptimizer
 from .adam import AdamParams, AdamOptimizer
@@ -16,7 +16,7 @@ class OptimizerRegistry:
     """Registry for optimizers and their parameter classes with case-insensitive lookup."""
     # Single dictionary storing both optimizer classes and parameter classes.
-    __OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
+    __OPTIMIZERS: Final[dict[str, dict[str, Type[Any]]]] = {
         "SGD": {
            "class": SGDOptimizer,
            "params": SGDParams,
@@ -32,7 +32,7 @@ class OptimizerRegistry:
     }
     @classmethod
-    def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
+    def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
         """
         Private method to retrieve the optimizer entry from the registry using case-insensitive lookup.
@@ -40,7 +40,7 @@ class OptimizerRegistry:
             name (str): The name of the optimizer.
         Returns:
-            Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
+            dict(str, Type(Any)): A dictionary containing the keys 'class' and 'params'.
         Raises:
             ValueError: If the optimizer is not found.
@@ -63,7 +63,7 @@ class OptimizerRegistry:
             name (str): Name of the optimizer.
         Returns:
-            Type[BaseOptimizer]: The optimizer class.
+            Type(BaseOptimizer): The optimizer class.
         """
         entry = cls.__get_entry(name)
         return entry["class"]
@@ -77,17 +77,17 @@ class OptimizerRegistry:
             name (str): Name of the optimizer.
         Returns:
-            Type[BaseModel]: The optimizer parameter class.
+            Type(BaseModel): The optimizer parameter class.
         """
         entry = cls.__get_entry(name)
         return entry["params"]
     @classmethod
-    def get_available_optimizers(cls) -> Tuple[str, ...]:
+    def get_available_optimizers(cls) -> tuple[str, ...]:
         """
         Returns a tuple of available optimizer names in their original case.
         Returns:
-            Tuple[str]: Tuple of available optimizer names.
+            Tuple(str): Tuple of available optimizer names.
         """
         return tuple(cls.__OPTIMIZERS.keys())

@@ -1,6 +1,6 @@
 import torch
 from torch import optim
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Iterable
 from pydantic import BaseModel, ConfigDict
 from .base import BaseOptimizer
@@ -10,12 +10,12 @@ class AdamParams(BaseModel):
     model_config = ConfigDict(frozen=True)
     lr: float = 1e-3 # Learning rate
-    betas: Tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages
+    betas: tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages
     eps: float = 1e-8 # Term added to denominator for numerical stability
     weight_decay: float = 0.0 # L2 regularization
     amsgrad: bool = False # Whether to use the AMSGrad variant
-    def asdict(self) -> Dict[str, Any]:
+    def asdict(self) -> dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.Adam`."""
         return self.model_dump()
@@ -25,7 +25,7 @@ class AdamOptimizer(BaseOptimizer):
     Wrapper around torch.optim.Adam.
     """
-    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams):
+    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams) -> None:
         """
         Initializes the Adam optimizer with given parameters.

@@ -1,6 +1,6 @@
 import torch
 from torch import optim
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Iterable
 from pydantic import BaseModel, ConfigDict
 from .base import BaseOptimizer
@@ -10,12 +10,12 @@ class AdamWParams(BaseModel):
     model_config = ConfigDict(frozen=True)
     lr: float = 1e-3 # Learning rate
-    betas: Tuple[float, ...] = (0.9, 0.999) # Adam coefficients
+    betas: tuple[float, ...] = (0.9, 0.999) # Adam coefficients
     eps: float = 1e-8 # Numerical stability
     weight_decay: float = 1e-2 # L2 penalty (AdamW uses decoupled weight decay)
     amsgrad: bool = False # Whether to use the AMSGrad variant
-    def asdict(self) -> Dict[str, Any]:
+    def asdict(self) -> dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.AdamW`."""
         return self.model_dump()
@@ -25,7 +25,7 @@ class AdamWOptimizer(BaseOptimizer):
     Wrapper around torch.optim.AdamW.
     """
-    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams):
+    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams) -> None:
         """
         Initializes the AdamW optimizer with given parameters.

@@ -1,15 +1,15 @@
 import torch
-import torch.optim as optim
+from torch import optim
 from pydantic import BaseModel
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable
 class BaseOptimizer:
     """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
-    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel):
+    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel) -> None:
         super().__init__()
-        self.optim: Optional[optim.Optimizer] = None
+        self.optim: optim.Optimizer | None = None
     def zero_grad(self, set_to_none: bool = True) -> None:
@@ -25,12 +25,12 @@ class BaseOptimizer:
         self.optim.zero_grad(set_to_none=set_to_none)
-    def step(self, closure: Optional[Any] = None) -> Any:
+    def step(self, closure: Any | None = None) -> Any:
         """
         Performs a single optimization step (parameter update).
         Args:
-            closure (Optional[Callable]): A closure that reevaluates the model and returns the loss.
+            closure (Any | None): A closure that reevaluates the model and returns the loss.
                 This is required for optimizers like LBFGS that need multiple forward passes.
         Returns:

@@ -1,6 +1,6 @@
 import torch
 from torch import optim
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Iterable
 from pydantic import BaseModel, ConfigDict
 from .base import BaseOptimizer
@@ -16,7 +16,7 @@ class SGDParams(BaseModel):
     weight_decay: float = 0.0 # L2 penalty
     nesterov: bool = False # Enables Nesterov momentum
-    def asdict(self) -> Dict[str, Any]:
+    def asdict(self) -> dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.SGD`."""
         return self.model_dump()
@@ -26,7 +26,7 @@ class SGDOptimizer(BaseOptimizer):
     Wrapper around torch.optim.SGD.
     """
-    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams):
+    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams) -> None:
         """
         Initializes the SGD optimizer with given parameters.

@ -1,5 +1,4 @@
import torch.optim.lr_scheduler as lr_scheduler from typing import Final, Type, Any
from typing import Dict, Final, Tuple, Type, List, Any, Union
from pydantic import BaseModel from pydantic import BaseModel
from .base import BaseScheduler from .base import BaseScheduler
@ -17,7 +16,7 @@ __all__ = [
class SchedulerRegistry: class SchedulerRegistry:
"""Registry for learning rate schedulers and their parameter classes with case-insensitive lookup.""" """Registry for learning rate schedulers and their parameter classes with case-insensitive lookup."""
__SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = { __SCHEDULERS: Final[dict[str, dict[str, Type[Any]]]] = {
"Step": { "Step": {
"class": StepLRScheduler, "class": StepLRScheduler,
"params": StepLRParams, "params": StepLRParams,
@ -37,7 +36,7 @@ class SchedulerRegistry:
} }
@classmethod @classmethod
def __get_entry(cls, name: str) -> Dict[str, Type[Any]]: def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
""" """
Private method to retrieve the scheduler entry from the registry using case-insensitive lookup. Private method to retrieve the scheduler entry from the registry using case-insensitive lookup.
@ -45,7 +44,7 @@ class SchedulerRegistry:
name (str): The name of the scheduler. name (str): The name of the scheduler.
Returns: Returns:
Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'. dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
Raises: Raises:
ValueError: If the scheduler is not found. ValueError: If the scheduler is not found.
@ -68,7 +67,7 @@ class SchedulerRegistry:
name (str): Name of the scheduler. name (str): Name of the scheduler.
Returns: Returns:
Type[BaseScheduler]: The scheduler class. Type[BaseScheduler]: The scheduler class.
""" """
entry = cls.__get_entry(name) entry = cls.__get_entry(name)
return entry["class"] return entry["class"]
@ -82,17 +81,17 @@ class SchedulerRegistry:
name (str): Name of the scheduler. name (str): Name of the scheduler.
Returns: Returns:
Type[BaseModel]: The scheduler parameter class. Type[BaseModel]: The scheduler parameter class.
""" """
entry = cls.__get_entry(name) entry = cls.__get_entry(name)
return entry["params"] return entry["params"]
@classmethod @classmethod
def get_available_schedulers(cls) -> Tuple[str, ...]: def get_available_schedulers(cls) -> tuple[str, ...]:
""" """
Returns a tuple of available scheduler names in their original case. Returns a tuple of available scheduler names in their original case.
Returns: Returns:
Tuple[str]: Tuple of available scheduler names. tuple[str, ...]: Tuple of available scheduler names.
""" """
return tuple(cls.__SCHEDULERS.keys()) return tuple(cls.__SCHEDULERS.keys())

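A minimal sketch of the case-insensitive lookup idea behind __get_entry and get_available_schedulers; the entry values are placeholders, not the real scheduler and params classes, and the key names are illustrative.

    from typing import Any, Final, Type

    class RegistrySketch:
        __SCHEDULERS: Final[dict[str, dict[str, Type[Any]]]] = {
            "Step": {"class": object, "params": object},     # stand-ins for StepLRScheduler / StepLRParams
            "Cosine": {"class": object, "params": object},   # stand-ins for the cosine annealing pair
        }

        @classmethod
        def _get_entry(cls, name: str) -> dict[str, Type[Any]]:
            for key, entry in cls.__SCHEDULERS.items():
                if key.lower() == name.lower():
                    return entry
            raise ValueError(f"Scheduler '{name}' is not found.")

        @classmethod
        def get_available_schedulers(cls) -> tuple[str, ...]:
            return tuple(cls.__SCHEDULERS.keys())

    assert RegistrySketch.get_available_schedulers() == ("Step", "Cosine")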
@ -1,6 +1,5 @@
import torch.optim as optim from torch import optim
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional
class BaseScheduler: class BaseScheduler:
@ -9,8 +8,8 @@ class BaseScheduler:
Wraps a PyTorch LR scheduler and provides a unified interface. Wraps a PyTorch LR scheduler and provides a unified interface.
""" """
def __init__(self, optimizer: optim.Optimizer, params: BaseModel): def __init__(self, optimizer: optim.Optimizer, params: BaseModel) -> None:
self.scheduler: Optional[optim.lr_scheduler.LRScheduler] = None self.scheduler: optim.lr_scheduler.LRScheduler | None = None
def step(self) -> None: def step(self) -> None:
""" """
@ -20,7 +19,7 @@ class BaseScheduler:
if self.scheduler is not None: if self.scheduler is not None:
self.scheduler.step() self.scheduler.step()
def get_last_lr(self) -> List[float]: def get_last_lr(self) -> list[float]:
""" """
Returns the most recent learning rate(s). Returns the most recent learning rate(s).
""" """

@ -1,10 +1,9 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from .base import BaseScheduler from .base import BaseScheduler
from typing import Any
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from pydantic import BaseModel, ConfigDict
class CosineAnnealingLRParams(BaseModel): class CosineAnnealingLRParams(BaseModel):
@ -16,7 +15,7 @@ class CosineAnnealingLRParams(BaseModel):
last_epoch: int = -1 last_epoch: int = -1
def asdict(self) -> Dict[str, Any]: def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`.""" """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`."""
return self.model_dump() return self.model_dump()
@ -26,7 +25,7 @@ class CosineAnnealingLRScheduler(BaseScheduler):
Wrapper around torch.optim.lr_scheduler.CosineAnnealingLR. Wrapper around torch.optim.lr_scheduler.CosineAnnealingLR.
""" """
def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams): def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams) -> None:
""" """
Args: Args:
optimizer (Optimizer): Wrapped optimizer. optimizer (Optimizer): Wrapped optimizer.

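For reference, a small sketch of the underlying torch calls the wrapper delegates to (step() and get_last_lr() as in BaseScheduler); the optimizer and T_max value here are arbitrary.

    import torch
    from torch import optim
    from torch.optim.lr_scheduler import CosineAnnealingLR

    model = torch.nn.Linear(4, 2)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=50)

    for _ in range(3):
        optimizer.step()        # normally preceded by loss.backward()
        scheduler.step()
    print(scheduler.get_last_lr())   # most recent learning rate(s)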
@ -1,9 +1,9 @@
from typing import Any, Dict from .base import BaseScheduler
from pydantic import BaseModel, ConfigDict
from typing import Any
from torch import optim from torch import optim
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from pydantic import BaseModel, ConfigDict
from .base import BaseScheduler
class ExponentialLRParams(BaseModel): class ExponentialLRParams(BaseModel):
@ -13,7 +13,7 @@ class ExponentialLRParams(BaseModel):
gamma: float = 0.95 # Multiplicative factor of learning rate decay gamma: float = 0.95 # Multiplicative factor of learning rate decay
last_epoch: int = -1 last_epoch: int = -1
def asdict(self) -> Dict[str, Any]: def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`.""" """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`."""
return self.model_dump() return self.model_dump()
@ -23,7 +23,7 @@ class ExponentialLRScheduler(BaseScheduler):
Wrapper around torch.optim.lr_scheduler.ExponentialLR. Wrapper around torch.optim.lr_scheduler.ExponentialLR.
""" """
def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams): def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams) -> None:
""" """
Args: Args:
optimizer (Optimizer): Wrapped optimizer. optimizer (Optimizer): Wrapped optimizer.

@ -1,20 +1,20 @@
from typing import Any, Dict, Tuple from .base import BaseScheduler
from pydantic import BaseModel, ConfigDict
from typing import Any
from torch import optim from torch import optim
from torch.optim.lr_scheduler import MultiStepLR from torch.optim.lr_scheduler import MultiStepLR
from pydantic import BaseModel, ConfigDict
from .base import BaseScheduler
class MultiStepLRParams(BaseModel): class MultiStepLRParams(BaseModel):
"""Configuration for `torch.optim.lr_scheduler.MultiStepLR`.""" """Configuration for `torch.optim.lr_scheduler.MultiStepLR`."""
model_config = ConfigDict(frozen=True) model_config = ConfigDict(frozen=True)
milestones: Tuple[int, ...] = (30, 80) # List of epoch indices for LR decay milestones: tuple[int, ...] = (30, 80) # List of epoch indices for LR decay
gamma: float = 0.1 # Multiplicative factor of learning rate decay gamma: float = 0.1 # Multiplicative factor of learning rate decay
last_epoch: int = -1 last_epoch: int = -1
def asdict(self) -> Dict[str, Any]: def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`.""" """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`."""
return self.model_dump() return self.model_dump()
@ -24,7 +24,7 @@ class MultiStepLRScheduler(BaseScheduler):
Wrapper around torch.optim.lr_scheduler.MultiStepLR. Wrapper around torch.optim.lr_scheduler.MultiStepLR.
""" """
def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams): def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams) -> None:
""" """
Args: Args:
optimizer (Optimizer): Wrapped optimizer. optimizer (Optimizer): Wrapped optimizer.

@ -1,9 +1,9 @@
from typing import Any, Dict from .base import BaseScheduler
from pydantic import BaseModel, ConfigDict
from typing import Any
from torch import optim from torch import optim
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from pydantic import BaseModel, ConfigDict
from .base import BaseScheduler
class StepLRParams(BaseModel): class StepLRParams(BaseModel):
@ -14,7 +14,7 @@ class StepLRParams(BaseModel):
gamma: float = 0.1 # Multiplicative factor of learning rate decay gamma: float = 0.1 # Multiplicative factor of learning rate decay
last_epoch: int = -1 last_epoch: int = -1
def asdict(self) -> Dict[str, Any]: def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`.""" """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`."""
return self.model_dump() return self.model_dump()
@ -25,7 +25,7 @@ class StepLRScheduler(BaseScheduler):
Wrapper around torch.optim.lr_scheduler.StepLR. Wrapper around torch.optim.lr_scheduler.StepLR.
""" """
def __init__(self, optimizer: optim.Optimizer, params: StepLRParams): def __init__(self, optimizer: optim.Optimizer, params: StepLRParams) -> None:
""" """
Args: Args:
optimizer (Optimizer): Wrapped optimizer. optimizer (Optimizer): Wrapped optimizer.

@ -19,15 +19,14 @@ from torch.utils.data import DataLoader
import fastremap import fastremap
import fill_voids import fill_voids
from skimage import morphology # from skimage import morphology
from skimage.segmentation import find_boundaries from skimage.segmentation import find_boundaries
from scipy.special import expit from scipy.special import expit
from scipy.ndimage import mean, find_objects from scipy.ndimage import mean, find_objects
from monai.data.dataset import Dataset from monai.data.dataset import Dataset
from monai.transforms import * # type: ignore from monai.transforms.compose import Compose
from monai.inferers.utils import sliding_window_inference from monai.inferers.utils import sliding_window_inference
from monai.metrics.cumulative_average import CumulativeAverage
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.colors as mcolors import matplotlib.colors as mcolors
@ -42,16 +41,16 @@ from itertools import chain
from pprint import pformat from pprint import pformat
from tabulate import tabulate from tabulate import tabulate
from typing import Any, Dict, Literal, Optional, Tuple, List, Union from typing import Any, Literal
from tqdm import tqdm from tqdm import tqdm
import wandb import wandb
from config import Config from config import Config
from core.models import * from core.models import ModelRegistry
from core.losses import * from core.losses import CriterionRegistry
from core.optimizers import * from core.optimizers import OptimizerRegistry
from core.schedulers import * from core.schedulers import SchedulerRegistry
from core.utils import ( from core.utils import (
compute_batch_segmentation_tp_fp_fn, compute_batch_segmentation_tp_fp_fn,
compute_f1_score, compute_f1_score,
@ -78,30 +77,30 @@ class CellSegmentator:
else None else None
) )
self._train_dataloader: Optional[DataLoader] = None self._train_dataloader: DataLoader | None = None
self._valid_dataloader: Optional[DataLoader] = None self._valid_dataloader: DataLoader | None = None
self._test_dataloader: Optional[DataLoader] = None self._test_dataloader: DataLoader | None = None
self._predict_dataloader: Optional[DataLoader] = None self._predict_dataloader: DataLoader | None = None
self._best_weights = None self._best_weights = None
def create_dataloaders( def create_dataloaders(
self, self,
train_transforms: Optional[Compose] = None, train_transforms: Compose | None = None,
valid_transforms: Optional[Compose] = None, valid_transforms: Compose | None = None,
test_transforms: Optional[Compose] = None, test_transforms: Compose | None = None,
predict_transforms: Optional[Compose] = None predict_transforms: Compose | None = None
) -> None: ) -> None:
""" """
Creates train, validation, test, and prediction dataloaders based on dataset configuration Creates train, validation, test, and prediction dataloaders based on dataset configuration
and provided transforms. and provided transforms.
Args: Args:
train_transforms (Optional[Compose]): Transformations for training data. train_transforms (Compose | None): Transformations for training data.
valid_transforms (Optional[Compose]): Transformations for validation data. valid_transforms (Compose | None): Transformations for validation data.
test_transforms (Optional[Compose]): Transformations for testing data. test_transforms (Compose | None): Transformations for testing data.
predict_transforms (Optional[Compose]): Transformations for prediction data. predict_transforms (Compose | None): Transformations for prediction data.
Raises: Raises:
ValueError: If required transforms are missing. ValueError: If required transforms are missing.
@ -257,7 +256,7 @@ class CellSegmentator:
def print_data_info( def print_data_info(
self, self,
loader_type: Literal["train", "valid", "test", "predict"], loader_type: Literal["train", "valid", "test", "predict"],
index: Optional[int] = None index: int | None = None
) -> None: ) -> None:
""" """
Prints statistics for a single sample from the specified dataloader. Prints statistics for a single sample from the specified dataloader.
@ -267,7 +266,7 @@ class CellSegmentator:
index: The sample index; if None, a random index is selected. index: The sample index; if None, a random index is selected.
""" """
# Retrieve the dataloader attribute, e.g., self._train_dataloader # Retrieve the dataloader attribute, e.g., self._train_dataloader
loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None) loader: DataLoader | None = getattr(self, f"_{loader_type}_dataloader", None)
if loader is None: if loader is None:
logger.error(f"Dataloader '{loader_type}' is not initialized.") logger.error(f"Dataloader '{loader_type}' is not initialized.")
return return
@ -326,8 +325,8 @@ class CellSegmentator:
lines.append("=" * 40) lines.append("=" * 40)
# Output via logger # Output via logger
for l in lines: for line in lines:
logger.info(l) logger.info(line)
def train(self, save_results: bool = True, only_masks: bool = False) -> None: def train(self, save_results: bool = True, only_masks: bool = False) -> None:
@ -661,16 +660,16 @@ class CellSegmentator:
logger.info(f"├─ Validation frequency: {training.val_freq}") logger.info(f"├─ Validation frequency: {training.val_freq}")
if training.is_split: if training.is_split:
logger.info(f"├─ Using pre-split directories:") logger.info( "├─ Using pre-split directories:")
logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}") logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}")
logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}") logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}")
logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}") logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}")
else: else:
logger.info(f"├─ Using unified dataset with splits:") logger.info( "├─ Using unified dataset with splits:")
logger.info(f"│ ├─ All data dir: {training.split.all_data_dir}") logger.info( "│ ├─ All data dir: {training.split.all_data_dir}")
logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}") logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}")
logger.info(f"└─ Dataset split:") logger.info( "└─ Dataset split:")
logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}") logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}")
logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}") logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}")
logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}") logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}")
@ -703,12 +702,12 @@ class CellSegmentator:
logger.info("===================================") logger.info("===================================")
def __set_seed(self, seed: Optional[int]) -> None: def __set_seed(self, seed: int | None) -> None:
""" """
Sets the random seed for reproducibility across Python, NumPy, and PyTorch. Sets the random seed for reproducibility across Python, NumPy, and PyTorch.
Args: Args:
seed (Optional[int]): Seed value. If None, no seeding is performed. seed (int | None): Seed value. If None, no seeding is performed.
""" """
if seed is not None: if seed is not None:
random.seed(seed) random.seed(seed)
@ -724,9 +723,9 @@ class CellSegmentator:
def __get_dataset( def __get_dataset(
self, self,
images_dir: str, images_dir: str,
masks_dir: Optional[str], masks_dir: str | None,
transforms: Compose, transforms: Compose,
size: Union[int, float], size: int | float,
offset: int, offset: int,
shuffle: bool shuffle: bool
) -> Dataset: ) -> Dataset:
@ -735,9 +734,9 @@ class CellSegmentator:
Args: Args:
images_dir (str): Path to directory or glob pattern for input images. images_dir (str): Path to directory or glob pattern for input images.
masks_dir (Optional[str]): Path to directory or glob pattern for masks. masks_dir (str | None): Path to directory or glob pattern for masks.
transforms (Compose): Transformations to apply to each image or pair. transforms (Compose): Transformations to apply to each image or pair.
size (Union[int, float]): Either an integer or a fraction of the dataset. size (int | float): Either an integer or a fraction of the dataset.
offset (int): Number of images to skip from the start. offset (int): Number of images to skip from the start.
shuffle (bool): Whether to shuffle the dataset before slicing. shuffle (bool): Whether to shuffle the dataset before slicing.
@ -806,12 +805,12 @@ class CellSegmentator:
return Dataset(data, transforms) return Dataset(data, transforms)
def __print_with_logging(self, metrics: Dict[str, Union[float, np.ndarray]], step: int) -> None: def __print_with_logging(self, metrics: dict[str, float | np.ndarray], step: int) -> None:
""" """
Print metrics in a tabular format and log to W&B. Print metrics in a tabular format and log to W&B.
Args: Args:
metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names metrics (dict[str, float | np.ndarray]): Mapping from metric names
to either a float or a ND numpy array. to either a float or a ND numpy array.
step (int): epoch index. step (int): epoch index.
""" """
@ -846,14 +845,14 @@ class CellSegmentator:
def __save_metrics_to_csv( def __save_metrics_to_csv(
self, self,
metrics: Dict[str, Union[float, np.ndarray]], metrics: dict[str, float | np.ndarray],
output_path: str output_path: str
) -> None: ) -> None:
""" """
Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'. Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'.
Args: Args:
metrics (Dict[str, Union[float, np.ndarray]]): metrics (dict[str, float | np.ndarray]):
Mapping from metric names to scalar values or numpy arrays. Mapping from metric names to scalar values or numpy arrays.
output_path (str): output_path (str):
Path to the output CSV file. Path to the output CSV file.
@ -874,22 +873,22 @@ class CellSegmentator:
def __run_epoch(self, def __run_epoch(self,
mode: Literal["train", "valid", "test"], mode: Literal["train", "valid", "test"],
epoch: Optional[int] = None, epoch: int | None = None,
save_results: bool = True, save_results: bool = True,
only_masks: bool = False only_masks: bool = False
) -> Dict[str, Union[float, np.ndarray]]: ) -> dict[str, float | np.ndarray]:
""" """
Execute one epoch of training, validation, or testing. Execute one epoch of training, validation, or testing.
Args: Args:
mode (str): One of 'train', 'valid', or 'test'. mode (str): One of 'train', 'valid', or 'test'.
epoch (int, optional): Current epoch number for logging. epoch (int | None): Current epoch number for logging.
save_results (bool): If True, the predicted masks and test metrics will be saved. save_results (bool): If True, the predicted masks and test metrics will be saved.
only_masks (bool): If True and save_results is True, only raw predicted masks are saved, only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
without visualization overlays. without visualization overlays.
Returns: Returns:
Dict[str, Union[float, np.ndarray]]: Metrics for valid/test. dict[str, float | np.ndarray]: Metrics for valid/test.
""" """
# Ensure required components are available # Ensure required components are available
if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None): if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
@ -988,7 +987,7 @@ class CellSegmentator:
if self._criterion is not None: if self._criterion is not None:
# Collect loss metrics # Collect loss metrics
epoch_metrics: Dict[str, Union[float, np.ndarray]] = { epoch_metrics: dict[str, float | np.ndarray] = {
f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items() f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()
} }
# Reset internal loss metrics accumulator # Reset internal loss metrics accumulator
@ -1051,17 +1050,17 @@ class CellSegmentator:
def __post_process_predictions( def __post_process_predictions(
self, self,
raw_outputs: torch.Tensor, raw_outputs: torch.Tensor,
ground_truth: Optional[torch.Tensor] = None ground_truth: torch.Tensor | None = None
) -> Tuple[np.ndarray, Optional[np.ndarray]]: ) -> tuple[np.ndarray, np.ndarray | None]:
""" """
Post-process raw network outputs to extract instance segmentation masks. Post-process raw network outputs to extract instance segmentation masks.
Args: Args:
raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W). raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W).
ground_truth (torch.Tensor): Ground truth masks of shape (B, C, H, W). ground_truth (torch.Tensor | None): Ground truth masks of shape (B, C, H, W).
Returns: Returns:
Tuple[np.ndarray, Optional[np.ndarray]]: tuple[np.ndarray, np.ndarray | None]:
- instance_masks: Instance-wise masks array of shape (B, C, H, W). - instance_masks: Instance-wise masks array of shape (B, C, H, W).
- labels_np: Converted ground truth of shape (B, C, H, W) or None if - labels_np: Converted ground truth of shape (B, C, H, W) or None if
ground_truth was not provided. ground_truth was not provided.
@ -1097,8 +1096,8 @@ class CellSegmentator:
ground_truth_masks: np.ndarray, ground_truth_masks: np.ndarray,
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_error_masks: bool = False return_error_masks: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, ) -> tuple[np.ndarray, np.ndarray, np.ndarray,
Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: np.ndarray | None, np.ndarray | None, np.ndarray | None]:
""" """
Compute batch-wise true positives, false positives, and false negatives Compute batch-wise true positives, false positives, and false negatives
for instance segmentation, using a configurable IoU threshold. for instance segmentation, using a configurable IoU threshold.
@ -1111,7 +1110,7 @@ class CellSegmentator:
return_error_masks (bool): Whether to also return binary error masks. return_error_masks (bool): Whether to also return binary error masks.
Returns: Returns:
Tuple(np.ndarray, np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray,
np.ndarray | None, np.ndarray | None, np.ndarray | None): np.ndarray | None, np.ndarray | None, np.ndarray | None]:
- tp: True positives per batch and class, shape (B, C) - tp: True positives per batch and class, shape (B, C)
- fp: False positives per batch and class, shape (B, C) - fp: False positives per batch and class, shape (B, C)
@ -1143,7 +1142,7 @@ class CellSegmentator:
false_positives: np.ndarray, false_positives: np.ndarray,
false_negatives: np.ndarray, false_negatives: np.ndarray,
reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro" reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
) -> Union[float, np.ndarray]: ) -> float | np.ndarray:
""" """
Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes. Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes.
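As a rough illustration of the two most common reductions over per-image, per-class counts of shape (B, C): "micro" pools all counts before computing the score, while "macro" computes a per-class score and then averages. The numbers below are arbitrary.

    import numpy as np

    tp = np.array([[10, 4], [8, 5]])   # (B, C) true positives
    fp = np.array([[2, 1], [1, 2]])
    fn = np.array([[3, 2], [2, 1]])

    def f1(tp_, fp_, fn_):
        return 2 * tp_ / (2 * tp_ + fp_ + fn_ + 1e-12)

    micro_f1 = f1(tp.sum(), fp.sum(), fn.sum())                             # one global score
    macro_f1 = f1(tp.sum(axis=0), fp.sum(axis=0), fn.sum(axis=0)).mean()    # mean of per-class scores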
@ -1266,7 +1265,7 @@ class CellSegmentator:
false_positives: np.ndarray, false_positives: np.ndarray,
false_negatives: np.ndarray, false_negatives: np.ndarray,
reduction: Literal["micro", "macro", "weighted", "imagewise", 'per_class', "none"] = "micro" reduction: Literal["micro", "macro", "weighted", "imagewise", 'per_class', "none"] = "micro"
) -> Union[float, np.ndarray]: ) -> float | np.ndarray:
""" """
Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes. Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes.
@ -1399,23 +1398,23 @@ class CellSegmentator:
def __save_prediction_masks( def __save_prediction_masks(
self, self,
sample: Dict[str, Any], sample: dict[str, Any],
predicted_mask: Union[np.ndarray, torch.Tensor], predicted_mask: np.ndarray | torch.Tensor,
start_index: int = 0, start_index: int = 0,
only_masks: bool = False, only_masks: bool = False,
masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
) -> None: ) -> None:
""" """
Save multi-channel predicted masks as TIFFs and Save multi-channel predicted masks as TIFFs and
corresponding visualizations as PNGs in separate folders. corresponding visualizations as PNGs in separate folders.
Args: Args:
sample (Dict[str, Any]): Batch sample from MONAI sample (dict[str, Any]): Batch sample from MONAI
LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict'). LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W). predicted_mask (np.ndarray | torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
start_index (int): Starting index for naming when metadata is missing. start_index (int): Starting index for naming when metadata is missing.
only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations. only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations.
masks (Tuple[np.ndarray, np.ndarray, np.ndarray] | None): masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None):
A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None. A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None.
""" """
# Base directories (created once per call) # Base directories (created once per call)
@ -1428,14 +1427,14 @@ class CellSegmentator:
os.makedirs(evaluate_dir, exist_ok=True) os.makedirs(evaluate_dir, exist_ok=True)
# Convert tensors to numpy if necessary # Convert tensors to numpy if necessary
def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: def to_numpy(x: np.ndarray | torch.Tensor) -> np.ndarray:
return x.cpu().numpy() if isinstance(x, torch.Tensor) else x return x.cpu().numpy() if isinstance(x, torch.Tensor) else x
pred_array = to_numpy(predicted_mask).astype(np.uint16) pred_array = to_numpy(predicted_mask).astype(np.uint16)
# Handle batch dimension # Handle batch dimension
for idx in range(pred_array.shape[0]): for idx in range(pred_array.shape[0]):
batch_sample: Dict[str, Any] = {} batch_sample: dict[str, Any] = {}
# copy per-sample image and meta # copy per-sample image and meta
img = to_numpy(sample["image"]) img = to_numpy(sample["image"])
if img.ndim == 4: if img.ndim == 4:
@ -1467,21 +1466,21 @@ class CellSegmentator:
def __save_single_prediction_mask( def __save_single_prediction_mask(
self, self,
sample: Dict[str, Any], sample: dict[str, Any],
pred_array: np.ndarray, pred_array: np.ndarray,
start_index: int, start_index: int,
masks_dir: str, masks_dir: str,
plots_dir: str, plots_dir: str,
evaluate_dir: str, evaluate_dir: str,
only_masks: bool = False, only_masks: bool = False,
masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
) -> None: ) -> None:
""" """
Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations. Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations.
Assumes output directories already exist. Assumes output directories already exist.
Args: Args:
sample (Dict[str, Any]): Dictionary containing 'image', 'mask', sample (dict[str, Any]): Dictionary containing 'image', 'mask',
and optional 'image_meta_dict' for metadata. and optional 'image_meta_dict' for metadata.
pred_array (np.ndarray): Predicted mask array of shape (C,H,W). pred_array (np.ndarray): Predicted mask array of shape (C,H,W).
start_index (int): Base index for generating filenames when metadata is missing. start_index (int): Base index for generating filenames when metadata is missing.
@ -1489,7 +1488,7 @@ class CellSegmentator:
plots_dir (str): Directory for saving PNG visualizations. plots_dir (str): Directory for saving PNG visualizations.
evaluate_dir (str): Directory for saving PNG visualizations of evaluation results. evaluate_dir (str): Directory for saving PNG visualizations of evaluation results.
only_masks (bool): If True, saves only TIFF mask files; skips PNG plots. only_masks (bool): If True, saves only TIFF mask files; skips PNG plots.
masks (Tuple[np.ndarray, np.ndarray, np.ndarray], optional): A tuple of masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None): A tuple of
true-positive, false-positive, and false-negative mask arrays, true-positive, false-positive, and false-negative mask arrays,
each of shape (C,H,W). Defaults to None. each of shape (C,H,W). Defaults to None.
""" """
@ -1510,7 +1509,7 @@ class CellSegmentator:
"Expected 2D (H,W) or 3D (C,H,W)." "Expected 2D (H,W) or 3D (C,H,W)."
) )
true_mask_array: Optional[np.ndarray] = sample.get("mask") true_mask_array: np.ndarray | None = sample.get("mask")
if isinstance(true_mask_array, np.ndarray): if isinstance(true_mask_array, np.ndarray):
if true_mask_array.ndim == 2: if true_mask_array.ndim == 2:
true_mask_array = np.expand_dims(true_mask_array, axis=0) true_mask_array = np.expand_dims(true_mask_array, axis=0)
@ -1562,7 +1561,7 @@ class CellSegmentator:
file_path: str, file_path: str,
image_data: np.ndarray, image_data: np.ndarray,
predicted_mask: np.ndarray, predicted_mask: np.ndarray,
true_mask: Optional[np.ndarray] = None, true_mask: np.ndarray | None = None,
) -> None: ) -> None:
""" """
Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided. Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided.
@ -1572,7 +1571,7 @@ class CellSegmentator:
image_data (np.ndarray): The original input image array, expected shape (C, H, W). image_data (np.ndarray): The original input image array, expected shape (C, H, W).
predicted_mask (np.ndarray): The predicted mask array, shape (H, W), predicted_mask (np.ndarray): The predicted mask array, shape (H, W),
depending on the task. depending on the task.
true_mask (Optional[np.ndarray], optional): The ground-truth mask array. true_mask (np.ndarray | None): The ground-truth mask array.
If provided, an additional row with true mask and overlap visualization If provided, an additional row with true mask and overlap visualization
will be added to the plot. Default is None. will be added to the plot. Default is None.
@ -1603,7 +1602,7 @@ class CellSegmentator:
img: np.ndarray, img: np.ndarray,
mask: np.ndarray, mask: np.ndarray,
contour_color: str, contour_color: str,
titles: Tuple[str, ...] titles: tuple[str, ...]
): ):
""" """
Plot a row of three panels: original image, mask, and mask boundaries on image. Plot a row of three panels: original image, mask, and mask boundaries on image.
@ -1618,7 +1617,8 @@ class CellSegmentator:
# Panel 1: Original image # Panel 1: Original image
ax0, ax1, ax2 = axes ax0, ax1, ax2 = axes
ax0.imshow(img, cmap='gray' if img.ndim == 2 else None) ax0.imshow(img, cmap='gray' if img.ndim == 2 else None)
ax0.set_title(titles[0]); ax0.axis('off') ax0.set_title(titles[0])
ax0.axis('off')
# Compute boundaries once # Compute boundaries once
boundaries = find_boundaries(mask, mode='thick') boundaries = find_boundaries(mask, mode='thick')
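A self-contained sketch of the same three-panel idea on synthetic data (image, mask, boundaries overlaid on the image); figure size, colors and the output path are arbitrary choices, not taken from the project.

    import numpy as np
    import matplotlib.pyplot as plt
    from skimage.segmentation import find_boundaries

    img = np.random.rand(64, 64)
    mask = np.zeros((64, 64), dtype=np.uint16)
    mask[20:40, 20:40] = 1

    boundaries = find_boundaries(mask, mode='thick')
    overlay = np.stack([img, img, img], axis=-1)
    overlay[boundaries] = (1.0, 0.0, 0.0)            # draw contours in red

    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    panels = [(img, "Image"), (mask, "Mask"), (overlay, "Boundaries on image")]
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel, cmap='gray' if panel.ndim == 2 else None)
        ax.set_title(title)
        ax.axis('off')
    fig.savefig("panels.png", dpi=150)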
@ -1793,7 +1793,8 @@ class CellSegmentator:
# Get coordinates of all non-zero pixels in the padded mask # Get coordinates of all non-zero pixels in the padded mask
y, x = torch.nonzero(masks_padded, as_tuple=True) y, x = torch.nonzero(masks_padded, as_tuple=True)
y = y.int(); x = x.int() # ensure integer type y = y.int()
x = x.int() # ensure integer type
# Generate 8-connected neighbors (including center) via broadcasted offsets # Generate 8-connected neighbors (including center) via broadcasted offsets
offsets = torch.tensor([ offsets = torch.tensor([
@ -1830,9 +1831,12 @@ class CellSegmentator:
], dtype=np.int16) ], dtype=np.int16)
# Compute centers (pixel indices) and extents via the provided helper # Compute centers (pixel indices) and extents via the provided helper
centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr) centers, ext = self.__get_mask_centers_and_extents(
mask_channel, slices_arr
)
# Move centers to GPU and shift by +1 for padding # Move centers to GPU and shift by +1 for padding
meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding # (M, 2); +1 for padding
meds_p = torch.from_numpy(centers).to(self._device).long() + 1
# Determine number of diffusion iterations # Determine number of diffusion iterations
n_iter = 2 * ext.max() n_iter = 2 * ext.max()
@ -1865,7 +1869,7 @@ class CellSegmentator:
def __get_mask_centers_and_extents( def __get_mask_centers_and_extents(
label_map: np.ndarray, label_map: np.ndarray,
slices_arr: np.ndarray slices_arr: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Compute the centroids and extents of labeled regions in a 2D mask array. Compute the centroids and extents of labeled regions in a 2D mask array.
@ -1923,7 +1927,7 @@ class CellSegmentator:
neighbor_indices: torch.Tensor, neighbor_indices: torch.Tensor,
center_indices: torch.Tensor, center_indices: torch.Tensor,
valid_neighbor_mask: torch.Tensor, valid_neighbor_mask: torch.Tensor,
output_shape: Tuple[int, int], output_shape: tuple[int, int],
num_iterations: int = 200 num_iterations: int = 200
) -> np.ndarray: ) -> np.ndarray:
""" """
@ -1933,7 +1937,7 @@ class CellSegmentator:
neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel. neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel.
center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers. center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers.
valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid. valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid.
output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W). output_shape (tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W).
num_iterations (int, optional): Number of diffusion iterations. Defaults to 200. num_iterations (int, optional): Number of diffusion iterations. Defaults to 200.
Returns: Returns:
@ -2242,7 +2246,7 @@ class CellSegmentator:
flow_field: np.ndarray, flow_field: np.ndarray,
initial_coords: np.ndarray, initial_coords: np.ndarray,
num_iters: int = 200 num_iters: int = 200
) -> Union[np.ndarray, torch.Tensor]: ) -> np.ndarray | torch.Tensor:
""" """
Trace pixel positions through a flow field via iterative interpolation. Trace pixel positions through a flow field via iterative interpolation.
@ -2252,7 +2256,7 @@ class CellSegmentator:
num_iters (int): Number of integration steps. num_iters (int): Number of integration steps.
Returns: Returns:
np.ndarray or torch.Tensor: Final (y, x) positions of each point. np.ndarray | torch.Tensor: Final (y, x) positions of each point.
""" """
dims = 2 dims = 2
# Extract spatial dimensions # Extract spatial dimensions
@ -2383,7 +2387,7 @@ class CellSegmentator:
self, self,
pixel_positions: torch.Tensor, pixel_positions: torch.Tensor,
valid_indices: np.ndarray, valid_indices: np.ndarray,
original_shape: Tuple[int, ...], original_shape: tuple[int, ...],
pad_radius: int = 20, pad_radius: int = 20,
max_size_fraction: float = 0.4 max_size_fraction: float = 0.4
) -> np.ndarray: ) -> np.ndarray:
@ -2534,7 +2538,7 @@ class CellSegmentator:
input_tensor: Tensor, input_tensor: Tensor,
kernel_size: int = 5, kernel_size: int = 5,
axis: int = 1, axis: int = 1,
output_tensor: Optional[Tensor] = None output_tensor: Tensor | None = None
) -> Tensor: ) -> Tensor:
""" """
Memory-efficient 1D max pooling along a specified axis using in-place updates. Memory-efficient 1D max pooling along a specified axis using in-place updates.
@ -2547,7 +2551,7 @@ class CellSegmentator:
input_tensor (Tensor): Source tensor for pooling. input_tensor (Tensor): Source tensor for pooling.
kernel_size (int): Size of the pooling window (must be odd and >= 3). kernel_size (int): Size of the pooling window (must be odd and >= 3).
axis (int): Axis along which to compute 1D max pooling. axis (int): Axis along which to compute 1D max pooling.
output_tensor (Optional[Tensor]): Tensor to store the result. output_tensor (Tensor | None): Tensor to store the result.
If None, a clone of input_tensor is used. If None, a clone of input_tensor is used.
Returns: Returns:
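For comparison, a naive (not memory-efficient) reference for the same operation, assuming stride 1 and a same-length output; padding with -inf so border windows only see real values. This is a sketch, not the project's in-place implementation.

    import torch
    import torch.nn.functional as F

    def maxpool_1d_reference(x: torch.Tensor, kernel_size: int = 5, axis: int = 1) -> torch.Tensor:
        pad = kernel_size // 2
        moved = x.movedim(axis, -1)                                 # bring the pooling axis last
        padded = F.pad(moved, (pad, pad), value=float('-inf'))      # constant padding on that axis
        windows = padded.unfold(padded.dim() - 1, kernel_size, 1)   # sliding windows of size kernel_size
        return windows.max(dim=-1).values.movedim(-1, axis)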
@ -2691,7 +2695,7 @@ class CellSegmentator:
self, self,
mask: np.ndarray, mask: np.ndarray,
flow_network: np.ndarray flow_network: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Compute mean squared error between network-predicted flows and flows derived from masks. Compute mean squared error between network-predicted flows and flows derived from masks.
@ -2700,7 +2704,7 @@ class CellSegmentator:
flow_network (np.ndarray): Network predicted flows of shape [axis, ...]. flow_network (np.ndarray): Network predicted flows of shape [axis, ...].
Returns: Returns:
Tuple[np.ndarray, np.ndarray]: tuple[np.ndarray, np.ndarray]:
- flow_errors: 1D array (length = max label) of mean squared error per label. - flow_errors: 1D array (length = max label) of mean squared error per label.
- computed_flows: Array of flows derived from the mask, same shape as flow_network. - computed_flows: Array of flows derived from the mask, same shape as flow_network.

@ -8,7 +8,7 @@ from numpy.typing import NDArray
from numba import jit from numba import jit
from skimage import segmentation from skimage import segmentation
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from typing import Dict, List, Tuple, Any, Union from typing import Any
from core.logger import get_logger from core.logger import get_logger
@ -27,7 +27,7 @@ def compute_f1_score(
true_positives: int, true_positives: int,
false_positives: int, false_positives: int,
false_negatives: int false_negatives: int
) -> Tuple[float, float, float]: ) -> tuple[float, float, float]:
""" """
Computes the precision, recall, and F1-score given the numbers of Computes the precision, recall, and F1-score given the numbers of
true positives, false positives, and false negatives. true positives, false positives, and false negatives.
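For reference, the standard definitions these metrics follow, evaluated on example counts:

    tp, fp, fn = 40, 10, 20
    precision = tp / (tp + fp)                                  # 0.8
    recall = tp / (tp + fn)                                     # ~0.667
    f1_score = 2 * precision * recall / (precision + recall)    # ~0.727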
@ -76,7 +76,7 @@ def compute_confusion_matrix(
ground_truth_mask: np.ndarray, ground_truth_mask: np.ndarray,
predicted_mask: np.ndarray, predicted_mask: np.ndarray,
iou_threshold: float = 0.5 iou_threshold: float = 0.5
) -> Tuple[int, int, int]: ) -> tuple[int, int, int]:
""" """
Computes the confusion matrix elements (true positives, false positives, false negatives) Computes the confusion matrix elements (true positives, false positives, false negatives)
for a single image given the ground truth and predicted masks. for a single image given the ground truth and predicted masks.
@ -114,7 +114,7 @@ def compute_segmentation_tp_fp_fn(
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_error_masks: bool = False, return_error_masks: bool = False,
remove_boundary_objects: bool = True remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]: ) -> dict[str, np.ndarray]:
""" """
Computes TP, FP and FN for segmentation on a single image. Computes TP, FP and FN for segmentation on a single image.
@ -176,7 +176,7 @@ def compute_segmentation_tp_fp_fn(
false_positive_mask_list.append(results.get('fp_mask')) # type: ignore false_positive_mask_list.append(results.get('fp_mask')) # type: ignore
false_negative_mask_list.append(results.get('fn_mask')) # type: ignore false_negative_mask_list.append(results.get('fn_mask')) # type: ignore
output: Dict[str, np.ndarray] = { output: dict[str, np.ndarray] = {
'tp': np.array(true_positive_list), 'tp': np.array(true_positive_list),
'fp': np.array(false_positive_list), 'fp': np.array(false_positive_list),
'fn': np.array(false_negative_list) 'fn': np.array(false_negative_list)
@ -194,7 +194,7 @@ def compute_segmentation_f1_metrics(
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_error_masks: bool = False, return_error_masks: bool = False,
remove_boundary_objects: bool = True remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]: ) -> dict[str, np.ndarray]:
""" """
Computes F1 metrics (precision, recall, F1-score) for segmentation on a single image. Computes F1 metrics (precision, recall, F1-score) for segmentation on a single image.
@ -240,7 +240,7 @@ def compute_segmentation_f1_metrics(
recall_list.append(recall) recall_list.append(recall)
f1_score_list.append(f1_score) f1_score_list.append(f1_score)
output: Dict[str, np.ndarray] = { output: dict[str, np.ndarray] = {
'precision': np.array(precision_list), 'precision': np.array(precision_list),
'recall': np.array(recall_list), 'recall': np.array(recall_list),
'f1_score': np.array(f1_score_list), 'f1_score': np.array(f1_score_list),
@ -255,7 +255,7 @@ def compute_segmentation_average_precision_metrics(
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_error_masks: bool = False, return_error_masks: bool = False,
remove_boundary_objects: bool = True remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]: ) -> dict[str, np.ndarray]:
""" """
Computes the average precision (AP) for segmentation on a single image. Computes the average precision (AP) for segmentation on a single image.
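For context, per-threshold average precision in instance-segmentation benchmarks is commonly computed directly from the matched counts; whether this implementation uses exactly this form is an assumption.

    tp, fp, fn = 40, 10, 20
    avg_precision = tp / (tp + fp + fn)   # ~0.571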
@ -298,7 +298,7 @@ def compute_segmentation_average_precision_metrics(
) )
avg_precision_list.append(avg_precision) avg_precision_list.append(avg_precision)
output: Dict[str, np.ndarray] = { output: dict[str, np.ndarray] = {
'avg_precision': np.array(avg_precision_list) 'avg_precision': np.array(avg_precision_list)
} }
output.update(results) output.update(results)
@ -311,7 +311,7 @@ def compute_batch_segmentation_tp_fp_fn(
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_error_masks: bool = False, return_error_masks: bool = False,
remove_boundary_objects: bool = True remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]: ) -> dict[str, np.ndarray]:
""" """
Computes segmentation TP, FP and FN for a batch of images. Computes segmentation TP, FP and FN for a batch of images.
@ -361,7 +361,7 @@ def compute_batch_segmentation_tp_fp_fn(
fp_mask_list.append(result.get('fp_mask')) # type: ignore fp_mask_list.append(result.get('fp_mask')) # type: ignore
fn_mask_list.append(result.get('fn_mask')) # type: ignore fn_mask_list.append(result.get('fn_mask')) # type: ignore
output: Dict[str, np.ndarray] = { output: dict[str, np.ndarray] = {
'tp': np.stack(tp_list, axis=0), 'tp': np.stack(tp_list, axis=0),
'fp': np.stack(fp_list, axis=0), 'fp': np.stack(fp_list, axis=0),
'fn': np.stack(fn_list, axis=0) 'fn': np.stack(fn_list, axis=0)
@ -379,7 +379,7 @@ def compute_batch_segmentation_f1_metrics(
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_error_masks: bool = False, return_error_masks: bool = False,
remove_boundary_objects: bool = True remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]: ) -> dict[str, np.ndarray]:
""" """
Computes segmentation F1 metrics for a batch of images. Computes segmentation F1 metrics for a batch of images.
@ -435,7 +435,7 @@ def compute_batch_segmentation_f1_metrics(
fp_mask_list.append(result.get('fp_mask')) # type: ignore fp_mask_list.append(result.get('fp_mask')) # type: ignore
fn_mask_list.append(result.get('fn_mask')) # type: ignore fn_mask_list.append(result.get('fn_mask')) # type: ignore
output: Dict[str, np.ndarray] = { output: dict[str, np.ndarray] = {
'precision': np.stack(precision_list, axis=0), 'precision': np.stack(precision_list, axis=0),
'recall': np.stack(recall_list, axis=0), 'recall': np.stack(recall_list, axis=0),
'f1_score': np.stack(f1_score_list, axis=0), 'f1_score': np.stack(f1_score_list, axis=0),
@ -456,7 +456,7 @@ def compute_batch_segmentation_average_precision_metrics(
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_error_masks: bool = False, return_error_masks: bool = False,
remove_boundary_objects: bool = True remove_boundary_objects: bool = True
) -> Dict[str, NDArray]: ) -> dict[str, np.ndarray]:
""" """
Computes segmentation average precision metrics for a batch of images. Computes segmentation average precision metrics for a batch of images.
@ -508,7 +508,7 @@ def compute_batch_segmentation_average_precision_metrics(
fp_mask_list.append(result.get('fp_mask')) # type: ignore fp_mask_list.append(result.get('fp_mask')) # type: ignore
fn_mask_list.append(result.get('fn_mask')) # type: ignore fn_mask_list.append(result.get('fn_mask')) # type: ignore
output: Dict[str, NDArray] = { output: dict[str, np.ndarray] = {
'avg_precision': np.stack(avg_precision_list, axis=0), 'avg_precision': np.stack(avg_precision_list, axis=0),
'tp': np.stack(tp_list, axis=0), 'tp': np.stack(tp_list, axis=0),
'fp': np.stack(fp_list, axis=0), 'fp': np.stack(fp_list, axis=0),
@ -555,7 +555,7 @@ def _process_instance_matching(
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_masks: bool = False, return_masks: bool = False,
without_boundary_objects: bool = True without_boundary_objects: bool = True
) -> Dict[str, Union[int, NDArray[np.uint8]]]: ) -> dict[str, int | NDArray[np.uint8]]:
""" """
Processes instance matching on a full image by performing the following steps: Processes instance matching on a full image by performing the following steps:
- Removes objects that touch the image boundary and reindexes the masks. - Removes objects that touch the image boundary and reindexes the masks.
@ -597,8 +597,8 @@ def _process_instance_matching(
fn_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8) fn_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8)
# Mark all ground truth objects as false negatives. # Mark all ground truth objects as false negatives.
fn_mask[ground_truth_mask > 0] = 1 fn_mask[ground_truth_mask > 0] = 1
result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) # type: ignore
return result return result # type: ignore
# Compute the IoU matrix for the processed masks. # Compute the IoU matrix for the processed masks.
iou_matrix = _calculate_iou(processed_ground_truth, processed_prediction) iou_matrix = _calculate_iou(processed_ground_truth, processed_prediction)
@ -640,11 +640,11 @@ def _process_instance_matching(
for pred_label in (all_prediction_labels - matched_prediction_labels): for pred_label in (all_prediction_labels - matched_prediction_labels):
fp_mask[processed_prediction == pred_label] = 1 fp_mask[processed_prediction == pred_label] = 1
result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) # type: ignore
return result return result # type: ignore
def _compute_optimal_matching_pairs(iou_matrix: np.ndarray, iou_threshold: float) -> List[Any]: def _compute_optimal_matching_pairs(iou_matrix: np.ndarray, iou_threshold: float) -> list[Any]:
""" """
Computes the optimal matching pairs between ground truth and predicted masks using the IoU matrix. Computes the optimal matching pairs between ground truth and predicted masks using the IoU matrix.
@ -687,7 +687,7 @@ def _compute_patch_based_metrics(
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
return_masks: bool = False, return_masks: bool = False,
without_boundary_objects: bool = True without_boundary_objects: bool = True
) -> Dict[str, Union[int, NDArray[np.uint8]]]: ) -> dict[str, int | NDArray[np.uint8]]:
""" """
Computes segmentation metrics using a patch-based approach for very large images. Computes segmentation metrics using a patch-based approach for very large images.
@ -747,7 +747,7 @@ def _compute_patch_based_metrics(
padded_fp_mask[y_start:y_end, x_start:x_end] = patch_results.get('fp_mask', 0) # type: ignore padded_fp_mask[y_start:y_end, x_start:x_end] = patch_results.get('fp_mask', 0) # type: ignore
padded_fn_mask[y_start:y_end, x_start:x_end] = patch_results.get('fn_mask', 0) # type: ignore padded_fn_mask[y_start:y_end, x_start:x_end] = patch_results.get('fn_mask', 0) # type: ignore
results: Dict[str, Union[int, np.ndarray]] = {'tp': total_tp, 'fp': total_fp, 'fn': total_fn} results: dict[str, int | np.ndarray] = {'tp': total_tp, 'fp': total_fp, 'fn': total_fn}
if return_masks: if return_masks:
# Crop the padded masks back to the original image size. # Crop the padded masks back to the original image size.
results.update({ results.update({

@ -1,5 +1,4 @@
import os import os
from typing import Tuple
from config import Config, WandbConfig, DatasetConfig, ComponentConfig from config import Config, WandbConfig, DatasetConfig, ComponentConfig
@ -8,7 +7,7 @@ from core import (
) )
def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str: def prompt_choice(prompt_message: str, options: tuple[str, ...]) -> str:
""" """
Prompt the user with a list of options and return the selected option. Prompt the user with a list of options and return the selected option.
""" """

@ -1,20 +1,25 @@
import os import os
import sys
import argparse import argparse
import wandb import wandb
from config import Config from config import Config
from core.data import * from core.data import (
get_train_transforms,
get_valid_transforms,
get_test_transforms,
get_predict_transforms
)
from core.segmentator import CellSegmentator from core.segmentator import CellSegmentator
def main(): def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Train or predict cell segmentator with specified config file." description="Train or predict cell segmentator with specified config file."
) )
parser.add_argument( parser.add_argument(
'-c', '--config', '-c', '--config',
type=str, type=str,
default='config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json',
help='Path to the JSON config file' help='Path to the JSON config file'
) )
parser.add_argument( parser.add_argument(
@ -36,6 +41,10 @@ def main():
' masks without additional visualizations') ' masks without additional visualizations')
) )
if len(sys.argv) == 1:
parser.print_help()
sys.exit(0)
args = parser.parse_args() args = parser.parse_args()
mode = args.mode mode = args.mode
@ -44,7 +53,7 @@ def main():
if mode == 'train' and not config.dataset_config.is_training: if mode == 'train' and not config.dataset_config.is_training:
raise ValueError( raise ValueError(
f"Config is not set for training (is_training=False), but mode 'train' was requested." "Config is not set for training (is_training=False), but mode 'train' was requested."
) )
if mode in ('test', 'predict') and config.dataset_config.is_training: if mode in ('test', 'predict') and config.dataset_config.is_training:
raise ValueError( raise ValueError(

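Putting the pieces together, a hedged end-to-end sketch of how the refactored API is presumably driven. Config.load_json, the CellSegmentator constructor and the transform factories' arguments are assumptions; the create_dataloaders and train signatures are taken from this diff.

    from config import Config
    from core.data import get_train_transforms, get_valid_transforms, get_test_transforms
    from core.segmentator import CellSegmentator

    config = Config.load_json("config/my_experiment.json")   # illustrative path
    segmentator = CellSegmentator(config)
    segmentator.create_dataloaders(
        train_transforms=get_train_transforms(),
        valid_transforms=get_valid_transforms(),
        test_transforms=get_test_transforms(),
    )
    segmentator.train(save_results=True, only_masks=False)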