refactor: migrate all type annotations from Python 3.9 to 3.10 syntax

Replaced typing module generics (List, Dict, Tuple, etc.) with their built-in equivalents (list, dict, tuple).
Updated code to use the PEP 604 union syntax (X | Y, X | None) in place of Union[X, Y] and Optional[X].
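A minimal before/after sketch of the pattern applied throughout this commit (hypothetical function and parameter names, not code from this repository). Worth noting: the bare X | Y annotation form is evaluated at class-definition time by pydantic models, so it requires Python 3.10 or newer at runtime.

# Before: typing-module generics plus Union/Optional
from typing import Dict, List, Optional, Union

def summarize(values: List[Union[int, float]],
              labels: Optional[Dict[str, str]] = None) -> Dict[str, float]:
    ...

# After: built-in generics (PEP 585) and union syntax (PEP 604); no typing imports needed
def summarize(values: list[int | float],
              labels: dict[str, str] | None = None) -> dict[str, float]:
    ...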
master
laynholt 1 month ago
parent 40c21d5456
commit b0e7b21a21

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel
from .wandb_config import WandbConfig
@ -13,7 +13,7 @@ class ComponentConfig(BaseModel):
name: str
params: BaseModel
def dump(self) -> Dict[str, Any]:
def dump(self) -> dict[str, Any]:
"""
Recursively serializes the component into a dictionary.
@ -24,22 +24,18 @@ class ComponentConfig(BaseModel):
params_dump = self.params.model_dump()
else:
params_dump = self.params
return {
"name": self.name,
"params": params_dump
}
return {"name": self.name, "params": params_dump}
class Config(BaseModel):
model: ComponentConfig
dataset_config: DatasetConfig
wandb_config: WandbConfig
criterion: Optional[ComponentConfig] = None
optimizer: Optional[ComponentConfig] = None
scheduler: Optional[ComponentConfig] = None
criterion: ComponentConfig | None = None
optimizer: ComponentConfig | None = None
scheduler: ComponentConfig | None = None
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""
Produce a JSON-serializable dict of this config, including nested
ComponentConfig and DatasetConfig entries. Useful for saving to file
@ -49,7 +45,7 @@ class Config(BaseModel):
A dict with keys 'model', 'dataset_config', and (if set)
'criterion', 'optimizer', 'scheduler'.
"""
data: Dict[str, Any] = {
data: dict[str, Any] = {
"model": self.model.dump(),
"dataset_config": self.dataset_config.model_dump(),
}
@ -62,7 +58,6 @@ class Config(BaseModel):
data["wandb"] = self.wandb_config.model_dump()
return data
def save_json(self, file_path: str, indent: int = 4) -> None:
"""
Save this config to a JSON file.
@ -75,7 +70,6 @@ class Config(BaseModel):
with open(file_path, "w", encoding="utf-8") as f:
f.write(json.dumps(config_dict, indent=indent))
@classmethod
def load_json(cls, file_path: str) -> "Config":
"""
@ -96,7 +90,9 @@ class Config(BaseModel):
wandb_config = WandbConfig(**data.get("wandb", {}))
# Helper function to parse registry fields.
def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]:
def parse_field(
component_data: dict[str, Any], registry_getter
) -> ComponentConfig | None:
name = component_data.get("name")
params_data = component_data.get("params", {})
@ -107,16 +103,31 @@ class Config(BaseModel):
return None
from core import (
ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
ModelRegistry,
CriterionRegistry,
OptimizerRegistry,
SchedulerRegistry,
)
parsed_model = parse_field(data.get("model", {}), lambda key: ModelRegistry.get_model_params(key))
parsed_criterion = parse_field(data.get("criterion", {}), lambda key: CriterionRegistry.get_criterion_params(key))
parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key))
parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key))
parsed_model = parse_field(
data.get("model", {}),
lambda key: ModelRegistry.get_model_params(key),
)
parsed_criterion = parse_field(
data.get("criterion", {}),
lambda key: CriterionRegistry.get_criterion_params(key),
)
parsed_optimizer = parse_field(
data.get("optimizer", {}),
lambda key: OptimizerRegistry.get_optimizer_params(key),
)
parsed_scheduler = parse_field(
data.get("scheduler", {}),
lambda key: SchedulerRegistry.get_scheduler_params(key),
)
if parsed_model is None:
raise ValueError('Failed to load model information')
raise ValueError("Failed to load model information")
return cls(
model=parsed_model,
@ -124,5 +135,5 @@ class Config(BaseModel):
criterion=parsed_criterion,
optimizer=parsed_optimizer,
scheduler=parsed_scheduler,
wandb_config=wandb_config
wandb_config=wandb_config,
)

@ -1,5 +1,5 @@
from pydantic import BaseModel, model_validator, field_validator
from typing import Any, Dict, Optional, Union
from typing import Any
import os
@ -7,7 +7,7 @@ class DatasetCommonConfig(BaseModel):
"""
Common configuration fields shared by both training and testing.
"""
seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
seed: int | None = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
roi_size: int = 512 # The size of the square window for cropping
@ -65,9 +65,9 @@ class DatasetTrainingConfig(BaseModel):
pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo()
split: TrainingSplitInfo = TrainingSplitInfo()
train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
valid_size: Union[int, float] = 0.1 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: Union[int, float] = 0.2 # Testing data size (int for static, float in (0,1] for dynamic)
train_size: int | float = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
valid_size: int | float = 0.1 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: int | float = 0.2 # Testing data size (int for static, float in (0,1] for dynamic)
train_offset: int = 0 # Offset for training data
valid_offset: int = 0 # Offset for validation data
test_offset: int = 0 # Offset for testing data
@ -78,7 +78,7 @@ class DatasetTrainingConfig(BaseModel):
@field_validator("train_size", "valid_size", "test_size", mode="before")
def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
def validate_sizes(cls, v: int | float) -> int | float:
"""
Validates size values:
- If provided as a float, must be in the range (0, 1].
@ -145,12 +145,12 @@ class DatasetTestingConfig(BaseModel):
Configuration fields used only in testing mode.
"""
test_dir: str = "." # Test data directory; must be non-empty
test_size: Union[int, float] = 1.0 # Testing data size (int for static, float in (0,1] for dynamic)
test_size: int | float = 1.0 # Testing data size (int for static, float in (0,1] for dynamic)
test_offset: int = 0 # Offset for testing data
shuffle: bool = True # Shuffle data
@field_validator("test_size", mode="before")
def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]:
def validate_test_size(cls, v: int | float) -> int | float:
"""
Validates the test_size value.
"""
@ -224,7 +224,7 @@ class DatasetConfig(BaseModel):
raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}")
return self
def model_dump(self, **kwargs) -> Dict[str, Any]:
def model_dump(self, **kwargs) -> dict[str, Any]:
"""
Dumps only the relevant configuration depending on the is_training flag.
Only the nested configuration (training or testing) along with common fields is returned.

@ -1,5 +1,5 @@
from pydantic import BaseModel, model_validator
from typing import Any, Dict, Optional
from typing import Any
class WandbConfig(BaseModel):
@ -7,12 +7,12 @@ class WandbConfig(BaseModel):
Configuration for Weights & Biases logging.
"""
use_wandb: bool = False # Whether to enable WandB logging
project: Optional[str] = None # WandB project name
group: Optional[str] = None # WandB group name
entity: Optional[str] = None # WandB entity (user or team)
name: Optional[str] = None # Name of the run
tags: Optional[list[str]] = None # List of tags for the run
notes: Optional[str] = None # Notes or description for the run
project: str | None = None # WandB project name
group: str | None = None # WandB group name
entity: str | None = None # WandB entity (user or team)
name: str | None = None # Name of the run
tags: list[str] | None = None # List of tags for the run
notes: str | None = None # Notes or description for the run
save_code: bool = True # Whether to save the code to WandB
@model_validator(mode="after")
@ -22,7 +22,7 @@ class WandbConfig(BaseModel):
raise ValueError("When use_wandb=True, 'project' must be provided")
return self
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""
Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
"""

@ -1,6 +1,6 @@
from .cell_aware import IntensityDiversification
from .load_image import CustomLoadImage, CustomLoadImaged
from .normalize_image import CustomNormalizeImage, CustomNormalizeImaged
from .load_image import CustomLoadImaged
from .normalize_image import CustomNormalizeImaged
from monai.transforms import * # type: ignore

@ -1,7 +1,7 @@
import copy
import torch
import numpy as np
from typing import Dict, Sequence, Tuple, Union
from typing import Sequence
from skimage.segmentation import find_boundaries
from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore
@ -26,14 +26,14 @@ class BoundaryExclusion(MapTransform):
def __init__(self, keys: Sequence[str] = ("mask",), allow_missing_keys: bool = False) -> None:
"""
Args:
keys (Sequence[str]): Keys in the input dictionary corresponding to the label image.
keys (Sequence(str)): Keys in the input dictionary corresponding to the label image.
Default is ("mask",).
allow_missing_keys (bool): If True, missing keys in the input will be ignored.
Default is False.
"""
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
def __call__(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
"""
Apply the boundary exclusion transform to the label image.
@ -46,10 +46,10 @@ class BoundaryExclusion(MapTransform):
6. Assigning the transformed label back into the input dictionary.
Args:
data (Dict[str, np.ndarray]): Dictionary containing at least the "mask" key with a label image.
data (dict[str, np.ndarray]): Dictionary containing at least the "mask" key with a label image.
Returns:
Dict[str, np.ndarray]: The input dictionary with the "mask" key updated after boundary exclusion.
dict[str, np.ndarray]: The input dictionary with the "mask" key updated after boundary exclusion.
"""
# Retrieve the original label image.
label_original: np.ndarray = data["mask"]
@ -100,17 +100,17 @@ class IntensityDiversification(MapTransform):
self,
keys: Sequence[str] = ("image",),
change_cell_ratio: float = 0.4,
scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7),
scale_factors: tuple[float, float] | float = (0.0, 0.7),
allow_missing_keys: bool = False,
) -> None:
"""
Args:
keys (Sequence[str]): Keys in the input dictionary corresponding to the image.
keys (Sequence(str)): Keys in the input dictionary corresponding to the image.
Default is ("image",).
change_cell_ratio (float): Ratio of cells to apply the intensity scaling.
For example, 0.4 means 40% of the cells will be transformed.
Default is 0.4.
scale_factors (Sequence[float]): Factors used for random intensity scaling.
scale_factors (tuple[float, float] | float): Factors used for random intensity scaling.
Default is (0.0, 0.7).
allow_missing_keys (bool): If True, missing keys in the input will be ignored.
Default is False.
@ -120,7 +120,7 @@ class IntensityDiversification(MapTransform):
# Compose a random intensity scaling transform with 100% probability.
self.randscale_intensity = Compose([RandScaleIntensity(prob=1.0, factors=scale_factors)])
def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
def __call__(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
"""
Apply a cell-wise intensity diversification transform to an input image.
@ -141,12 +141,12 @@ class IntensityDiversification(MapTransform):
9. Combine the unchanged and modified parts to update the image for that channel.
Args:
data (Dict[str, np.ndarray]): A dictionary containing:
data (dict[str, np.ndarray]): A dictionary containing:
- "image": The original image array.
- "mask": The corresponding cell label image array.
Returns:
Dict[str, np.ndarray]: The updated dictionary with the "image" key modified after applying
dict[str, np.ndarray]: The updated dictionary with the "image" key modified after applying
the intensity transformation.
Raises:

@ -1,7 +1,7 @@
import numpy as np
import tifffile as tif
import skimage.io as io
from typing import Final, List, Optional, Sequence, Type, Union
from typing import Final, Sequence, Type
from monai.utils.enums import PostFix
from monai.utils.module import optional_import
@ -45,7 +45,7 @@ class CustomLoadImage(LoadImage):
"""
def __init__(
self,
reader: Optional[Union[ImageReader, Type[ImageReader], str]] = None,
reader: ImageReader | Type[ImageReader] | str | None = None,
image_only: bool = False,
dtype: DtypeLike = np.float32,
ensure_channel_first: bool = False,
@ -75,9 +75,9 @@ class CustomLoadImaged(LoadImaged):
def __init__(
self,
keys: KeysCollection,
reader: Optional[Union[Type[ImageReader], str]] = None,
reader: Type[ImageReader] | str | None = None,
dtype: DtypeLike = np.float32,
meta_keys: Optional[KeysCollection] = None,
meta_keys: KeysCollection | None = None,
meta_key_postfix: str = DEFAULT_POST_FIX,
overwriting: bool = False,
image_only: bool = False,
@ -141,13 +141,13 @@ class UniversalImageReader(NumpyReader):
(e.g., repeating or cropping channels).
"""
def __init__(
self, channel_dim: Optional[int] = None, **kwargs,
):
self, channel_dim: int | None = None, **kwargs,
) -> None:
super().__init__(channel_dim=channel_dim, **kwargs)
self.kwargs = kwargs
self.channel_dim = channel_dim
def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
"""
Check if the file format is supported for reading.
@ -155,7 +155,7 @@ class UniversalImageReader(NumpyReader):
"""
return has_itk or is_supported_format(filename, SUPPORTED_IMAGE_FORMATS)
def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs):
def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
"""
Read image(s) from the given path.
@ -166,7 +166,7 @@ class UniversalImageReader(NumpyReader):
Returns:
A single image or a list of images depending on the number of paths provided.
"""
images: List[np.ndarray] = [] # List to store the loaded images
images: list[np.ndarray] = [] # List to store the loaded images
# Convert data to a tuple to support multiple files
filenames: Sequence[PathLike] = ensure_tuple(data)

@ -2,7 +2,7 @@ import numpy as np
from skimage import exposure
from monai.config.type_definitions import KeysCollection
from monai.transforms.transform import Transform, MapTransform
from typing import Dict, Hashable, Mapping, Sequence
from typing import Hashable, Mapping, Sequence
__all__ = [
"CustomNormalizeImage",
@ -23,7 +23,7 @@ class CustomNormalizeImage(Transform):
def __init__(self, percentiles: Sequence[float] = (0, 99), channel_wise: bool = False) -> None:
"""
Args:
percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
percentiles (Sequence(float)): Lower and upper percentiles used for intensity scaling.
Default is (0, 99).
channel_wise (bool): Whether to apply normalization on each channel individually.
Default is False.
@ -106,7 +106,7 @@ class CustomNormalizeImaged(MapTransform):
"""
Args:
keys (KeysCollection): Keys identifying the image entries in the dictionary.
percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
percentiles (Sequence(float)): Lower and upper percentiles used for intensity scaling.
Default is (1, 99).
channel_wise (bool): Whether to apply normalization on each channel individually.
Default is False.
@ -117,7 +117,7 @@ class CustomNormalizeImaged(MapTransform):
# Create an instance of the normalization transform with specified parameters.
self.normalizer: CustomNormalizeImage = CustomNormalizeImage(percentiles, channel_wise)
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
"""
Apply the normalization transform to each image in the input dictionary.
@ -125,10 +125,10 @@ class CustomNormalizeImaged(MapTransform):
data (Mapping[Hashable, np.ndarray]): A dictionary mapping keys to numpy arrays representing images.
Returns:
Dict[Hashable, np.ndarray]: A dictionary with the same keys where the images have been normalized.
dict[Hashable, np.ndarray]: A dictionary with the same keys where the images have been normalized.
"""
# Copy the input dictionary to avoid modifying the original data.
d: Dict[Hashable, np.ndarray] = dict(data)
d: dict[Hashable, np.ndarray] = dict(data)
# Iterate over each key specified in the transform and normalize the corresponding image.
for key in self.keys:
d[key] = self.normalizer(d[key])

@ -1,6 +1,6 @@
import torch
import numpy as np
from typing import Hashable, List, Sequence, Optional, Tuple
from typing import Sequence
from monai.utils.misc import fall_back_tuple
from monai.data.meta_tensor import MetaTensor
@ -14,7 +14,7 @@ logger = get_logger(__name__)
def _compute_multilabel_bbox(
mask: np.ndarray
) -> Optional[Tuple[List[int], List[int], List[int], List[int]]]:
) -> tuple[list[int], list[int], list[int], list[int]] | None:
"""
Compute per-channel bounding-box constraints and return lists of limits for each axis.
@ -33,10 +33,10 @@ def _compute_multilabel_bbox(
if channels.size == 0:
return None
top_mins: List[int] = []
top_maxs: List[int] = []
left_mins: List[int] = []
left_maxs: List[int] = []
top_mins: list[int] = []
top_maxs: list[int] = []
left_mins: list[int] = []
left_maxs: list[int] = []
C = mask.shape[0]
for ch in range(C):
rs, cs = np.nonzero(mask[ch])
@ -74,7 +74,7 @@ class SpatialCropAllClasses(Randomizable, Crop):
super().__init__(lazy=lazy)
self.roi_size = tuple(roi_size)
self.num_candidates = num_candidates
self._slices: Optional[Tuple[slice, ...]] = None
self._slices: tuple[slice, ...] | None = None
def randomize(self, img_size: Sequence[int]) -> None: # type: ignore
"""
@ -139,7 +139,7 @@ class SpatialCropAllClasses(Randomizable, Crop):
slice(left, left + crop_w),
)
def __call__(self, img: torch.Tensor, lazy: Optional[bool] = None) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore
"""
On first call (mask), computes crop. On subsequent (image), just applies.
Raises if mask not provided first.

@ -1,4 +1,4 @@
from typing import Dict, Final, Tuple, Type, List, Any, Union
from typing import Final, Type, Any
from pydantic import BaseModel
from .base import BaseLoss
@ -16,7 +16,7 @@ __all__ = [
class CriterionRegistry:
"""Registry of loss functions and their parameter classes with case-insensitive lookup."""
__CRITERIONS: Final[Dict[str, Dict[str, Any]]] = {
__CRITERIONS: Final[dict[str, dict[str, Any]]] = {
"CrossEntropyLoss": {
"class": CrossEntropyLoss,
"params": CrossEntropyLossParams,
@ -36,7 +36,7 @@ class CriterionRegistry:
}
@classmethod
def __get_entry(cls, name: str) -> Dict[str, Any]:
def __get_entry(cls, name: str) -> dict[str, Any]:
"""
Private method to retrieve the criterion entry from the registry using case-insensitive lookup.
@ -44,7 +44,7 @@ class CriterionRegistry:
name (str): The name of the loss function.
Returns:
Dict[str, Any]: A dictionary containing the keys 'class' and 'params'.
dict[str, Any]: A dictionary containing the keys 'class' and 'params'.
Raises:
ValueError: If the loss function is not found.
@ -67,7 +67,7 @@ class CriterionRegistry:
name (str): Name of the loss function.
Returns:
Type[BaseLoss]: The loss function class.
Type(BaseLoss): The loss function class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@ -81,17 +81,17 @@ class CriterionRegistry:
name (str): Name of the loss function.
Returns:
Type[BaseModel]: The loss function parameter class.
Type(BaseModel): The loss function parameter class.
"""
entry = cls.__get_entry(name)
return entry["params"]
@classmethod
def get_available_criterions(cls) -> Tuple[str, ...]:
def get_available_criterions(cls) -> tuple[str, ...]:
"""
Returns a tuple of available loss function names in their original case.
Returns:
Tuple[str]: Tuple of available loss function names.
tuple[str, ...]: Tuple of available loss function names.
"""
return tuple(cls.__CRITERIONS.keys())

@ -1,15 +1,12 @@
import abc
import torch
import torch.nn as nn
from pydantic import BaseModel
from typing import Dict, Any, Optional
from monai.metrics.cumulative_average import CumulativeAverage
class BaseLoss(nn.Module, abc.ABC):
class BaseLoss(torch.nn.Module, abc.ABC):
"""Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
def __init__(self, params: Optional[BaseModel] = None):
def __init__(self, params: BaseModel | None = None) -> None:
super().__init__()
@ -28,16 +25,16 @@ class BaseLoss(nn.Module, abc.ABC):
@abc.abstractmethod
def get_loss_metrics(self) -> Dict[str, float]:
def get_loss_metrics(self) -> dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the loss name and average loss value.
dict[str, float]: A dictionary containing the loss name and average loss value.
"""
@abc.abstractmethod
def reset_metrics(self):
def reset_metrics(self) -> None:
"""Resets the stored loss metrics."""

@ -1,6 +1,8 @@
from .base import *
from typing import List, Literal, Union
from .base import BaseLoss
import torch
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict
from monai.metrics.cumulative_average import CumulativeAverage
class BCELossParams(BaseModel):
@ -11,11 +13,11 @@ class BCELossParams(BaseModel):
with_logits: bool = False
weight: Optional[List[Union[int, float]]] = None # Sample weights
weight: list[int | float] | None = None # Sample weights
reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method
pos_weight: Optional[List[Union[int, float]]] = None # Used only for BCEWithLogitsLoss
pos_weight: list[int | float] | None = None # Used only for BCEWithLogitsLoss
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`.
@ -23,7 +25,7 @@ class BCELossParams(BaseModel):
- Ensures only the valid parameters are passed based on the loss function.
Returns:
Dict[str, Any]: Filtered dictionary of parameters.
dict[str, Any]: Filtered dictionary of parameters.
"""
loss_kwargs = self.model_dump()
if not self.with_logits:
@ -47,19 +49,23 @@ class BCELoss(BaseLoss):
Custom loss function wrapper for `nn.BCELoss` and `nn.BCEWithLogitsLoss` with tracking of loss metrics.
"""
def __init__(self, params: Optional[BCELossParams] = None):
def __init__(self, params: BCELossParams | None = None) -> None:
"""
Initializes the loss function with optional BCELoss parameters.
Args:
params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None).
params (BCELossParams | None): Parameters for nn.BCELoss (default: None).
"""
super().__init__(params=params)
with_logits = params.with_logits if params is not None else False
_bce_params = params.asdict() if params is not None else {}
# Initialize loss functions with user-provided parameters or PyTorch defaults
self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params)
self.bce_loss = (
torch.nn.BCEWithLogitsLoss(**_bce_params)
if with_logits
else torch.nn.BCELoss(**_bce_params)
)
# Using CumulativeAverage from MONAI to track loss metrics
self.loss_bce_metric = CumulativeAverage()
@ -90,18 +96,18 @@ class BCELoss(BaseLoss):
return loss
def get_loss_metrics(self) -> Dict[str, float]:
def get_loss_metrics(self) -> dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the average BCE loss.
dict[str, float]: A dictionary containing the average BCE loss.
"""
return {
"loss": round(self.loss_bce_metric.aggregate().item(), 4),
}
def reset_metrics(self):
def reset_metrics(self) -> None:
"""Resets the stored loss metrics."""
self.loss_bce_metric.reset()

@ -1,6 +1,8 @@
from .base import *
from typing import List, Literal, Union
from .base import BaseLoss
import torch
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict
from monai.metrics.cumulative_average import CumulativeAverage
class CrossEntropyLossParams(BaseModel):
@ -9,17 +11,17 @@ class CrossEntropyLossParams(BaseModel):
"""
model_config = ConfigDict(frozen=True)
weight: Optional[List[Union[int, float]]] = None
weight: list[int | float] | None = None
ignore_index: int = -100
reduction: Literal["none", "mean", "sum"] = "mean"
label_smoothing: float = 0.0
def asdict(self):
def asdict(self) -> dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.CrossEntropyLoss`.
Returns:
Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss.
dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss.
"""
loss_kwargs = self.model_dump()
@ -36,18 +38,18 @@ class CrossEntropyLoss(BaseLoss):
Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics.
"""
def __init__(self, params: Optional[CrossEntropyLossParams] = None):
def __init__(self, params: CrossEntropyLossParams | None = None) -> None:
"""
Initializes the loss function with optional CrossEntropyLoss parameters.
Args:
params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None).
params (CrossEntropyLossParams | None): Parameters for nn.CrossEntropyLoss (default: None).
"""
super().__init__(params=params)
_ce_params = params.asdict() if params is not None else {}
# Initialize loss functions with user-provided parameters or PyTorch defaults
self.ce_loss = nn.CrossEntropyLoss(**_ce_params)
self.ce_loss = torch.nn.CrossEntropyLoss(**_ce_params)
# Using CumulativeAverage from MONAI to track loss metrics
self.loss_ce_metric = CumulativeAverage()
@ -78,18 +80,18 @@ class CrossEntropyLoss(BaseLoss):
return loss
def get_loss_metrics(self) -> Dict[str, float]:
def get_loss_metrics(self) -> dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the average CrossEntropy loss.
dict[str, float]: A dictionary containing the average CrossEntropy loss.
"""
return {
"loss": round(self.loss_ce_metric.aggregate().item(), 4),
}
def reset_metrics(self):
def reset_metrics(self) -> None:
"""Resets the stored loss metrics."""
self.loss_ce_metric.reset()

@ -1,6 +1,8 @@
from .base import *
from typing import Literal
from .base import BaseLoss
import torch
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict
from monai.metrics.cumulative_average import CumulativeAverage
class MSELossParams(BaseModel):
@ -11,12 +13,12 @@ class MSELossParams(BaseModel):
reduction: Literal["none", "mean", "sum"] = "mean"
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.MSELoss`.
Returns:
Dict[str, Any]: Dictionary of parameters for `nn.MSELoss`.
dict[str, Any]: Dictionary of parameters for `nn.MSELoss`.
"""
loss_kwargs = self.model_dump()
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
@ -27,18 +29,18 @@ class MSELoss(BaseLoss):
Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics.
"""
def __init__(self, params: Optional[MSELossParams] = None):
def __init__(self, params: MSELossParams | None = None):
"""
Initializes the loss function with optional MSELoss parameters.
Args:
params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None).
params (MSELossParams | None): Parameters for `nn.MSELoss` (default: None).
"""
super().__init__(params=params)
_mse_params = params.asdict() if params is not None else {}
# Initialize MSE loss with user-provided parameters or PyTorch defaults
self.mse_loss = nn.MSELoss(**_mse_params)
self.mse_loss = torch.nn.MSELoss(**_mse_params)
# Using CumulativeAverage from MONAI to track loss metrics
self.loss_mse_metric = CumulativeAverage()
@ -67,12 +69,12 @@ class MSELoss(BaseLoss):
return loss
def get_loss_metrics(self) -> Dict[str, float]:
def get_loss_metrics(self) -> dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the average MSE loss.
dict[str, float]: A dictionary containing the average MSE loss.
"""
return {
"loss": round(self.loss_mse_metric.aggregate().item(), 4),

@ -1,8 +1,11 @@
from .base import *
from .base import BaseLoss
from .bce import BCELossParams
from .mse import MSELossParams
import torch
from typing import Any
from pydantic import BaseModel, ConfigDict
from monai.metrics.cumulative_average import CumulativeAverage
class BCE_MSE_LossParams(BaseModel):
@ -15,12 +18,12 @@ class BCE_MSE_LossParams(BaseModel):
bce_params: BCELossParams = BCELossParams()
mse_params: MSELossParams = MSELossParams()
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.BCELoss` and `nn.MSELoss`.
Returns:
Dict[str, Any]: Dictionary of parameters.
dict[str, Any]: Dictionary of parameters.
"""
return {
@ -35,7 +38,7 @@ class BCE_MSE_Loss(BaseLoss):
Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction.
"""
def __init__(self, params: Optional[BCE_MSE_LossParams]):
def __init__(self, params: BCE_MSE_LossParams | None = None):
"""
Initializes the loss function with optional BCE and MSE parameters.
"""
@ -50,14 +53,16 @@ class BCE_MSE_Loss(BaseLoss):
# Choose BCE loss function
self.bce_loss = (
nn.BCEWithLogitsLoss(**_bce_params) if _params.bce_params.with_logits else nn.BCELoss(**_bce_params)
torch.nn.BCEWithLogitsLoss(**_bce_params)
if _params.bce_params.with_logits
else torch.nn.BCELoss(**_bce_params)
)
# Process MSE parameters
_mse_params = _params.mse_params.asdict()
# Initialize MSE loss
self.mse_loss = nn.MSELoss(**_mse_params)
self.mse_loss = torch.nn.MSELoss(**_mse_params)
# Using CumulativeAverage from MONAI to track loss metrics
self.loss_bce_metric = CumulativeAverage()
@ -101,12 +106,12 @@ class BCE_MSE_Loss(BaseLoss):
return total_loss
def get_loss_metrics(self) -> Dict[str, float]:
def get_loss_metrics(self) -> dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the average BCE and MSE loss.
dict[str, float]: A dictionary containing the average BCE and MSE loss.
"""
return {
"bce_loss": round(self.loss_bce_metric.aggregate().item(), 4),

@ -1,5 +1,5 @@
import torch.nn as nn
from typing import Dict, Final, Tuple, Type, Any, List, Union
from torch import nn
from typing import Final, Type, Any
from pydantic import BaseModel
from .model_v import ModelV, ModelVParams
@ -16,7 +16,7 @@ class ModelRegistry:
"""Registry for models and their parameter classes with case-insensitive lookup."""
# Single dictionary storing both model classes and parameter classes.
__MODELS: Final[Dict[str, Dict[str, Type[Any]]]] = {
__MODELS: Final[dict[str, dict[str, Type[Any]]]] = {
"ModelV": {
"class": ModelV,
"params": ModelVParams,
@ -24,7 +24,7 @@ class ModelRegistry:
}
@classmethod
def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
"""
Private method to retrieve the model entry from the registry using case-insensitive lookup.
@ -32,7 +32,7 @@ class ModelRegistry:
name (str): The name of the model.
Returns:
Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
Raises:
ValueError: If the model is not found.
@ -55,7 +55,7 @@ class ModelRegistry:
name (str): Name of the model.
Returns:
Type[nn.Module]: The model class.
Type[torch.nn.Module]: The model class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@ -69,17 +69,17 @@ class ModelRegistry:
name (str): Name of the model.
Returns:
Type[BaseModel]: The model parameter class.
Type(BaseModel): The model parameter class.
"""
entry = cls.__get_entry(name)
return entry["params"]
@classmethod
def get_available_models(cls) -> Tuple[str, ...]:
def get_available_models(cls) -> tuple[str, ...]:
"""
Returns a tuple of available model names in their original case.
Returns:
Tuple[str]: Tuple of available model names.
Tuple(str): Tuple of available model names.
"""
return tuple(cls.__MODELS.keys())

@ -1,8 +1,7 @@
from typing import List, Optional
import torch
import torch.nn as nn
from torch import nn
from typing import Any
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation
@ -15,18 +14,18 @@ class ModelVParams(BaseModel):
model_config = ConfigDict(frozen=True)
encoder_name: str = "mit_b5" # Default encoder
encoder_weights: Optional[str] = "imagenet" # Pre-trained weights
decoder_channels: List[int] = [1024, 512, 256, 128, 64] # Decoder configuration
encoder_weights: str | None = "imagenet" # Pre-trained weights
decoder_channels: list[int] = [1024, 512, 256, 128, 64] # Decoder configuration
decoder_pab_channels: int = 256 # Decoder Pyramid Attention Block channels
in_channels: int = 3 # Number of input channels
out_classes: int = 1 # Number of output classes
def asdict(self):
def asdict(self) -> dict[str, Any]:
"""
Returns a dictionary of valid parameters for `ModelV`.
Returns:
Dict[str, Any]: Dictionary of parameters for nn.ModelV.
dict[str, Any]: Dictionary of parameters for ModelV.
"""
loss_kwargs = self.model_dump()
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
@ -84,11 +83,11 @@ class DeepSegmentationHead(nn.Sequential):
in_channels: int,
out_channels: int,
kernel_size: int = 3,
activation: Optional[str] = None,
activation: str | None = None,
upsampling: int = 1,
) -> None:
# Define a sequence of layers for the segmentation head
layers: List[nn.Module] = [
layers: list[nn.Module] = [
nn.Conv2d(
in_channels,
in_channels // 2,

@ -1,5 +1,5 @@
from pydantic import BaseModel
from typing import Dict, Final, Tuple, Type, List, Any, Union
from typing import Final, Type, Any
from .base import BaseOptimizer
from .adam import AdamParams, AdamOptimizer
@ -16,7 +16,7 @@ class OptimizerRegistry:
"""Registry for optimizers and their parameter classes with case-insensitive lookup."""
# Single dictionary storing both optimizer classes and parameter classes.
__OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
__OPTIMIZERS: Final[dict[str, dict[str, Type[Any]]]] = {
"SGD": {
"class": SGDOptimizer,
"params": SGDParams,
@ -32,7 +32,7 @@ class OptimizerRegistry:
}
@classmethod
def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
"""
Private method to retrieve the optimizer entry from the registry using case-insensitive lookup.
@ -40,7 +40,7 @@ class OptimizerRegistry:
name (str): The name of the optimizer.
Returns:
Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
Raises:
ValueError: If the optimizer is not found.
@ -63,7 +63,7 @@ class OptimizerRegistry:
name (str): Name of the optimizer.
Returns:
Type[BaseOptimizer]: The optimizer class.
Type(BaseOptimizer): The optimizer class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@ -77,17 +77,17 @@ class OptimizerRegistry:
name (str): Name of the optimizer.
Returns:
Type[BaseModel]: The optimizer parameter class.
Type(BaseModel): The optimizer parameter class.
"""
entry = cls.__get_entry(name)
return entry["params"]
@classmethod
def get_available_optimizers(cls) -> Tuple[str, ...]:
def get_available_optimizers(cls) -> tuple[str, ...]:
"""
Returns a tuple of available optimizer names in their original case.
Returns:
Tuple[str]: Tuple of available optimizer names.
Tuple(str): Tuple of available optimizer names.
"""
return tuple(cls.__OPTIMIZERS.keys())

@ -1,6 +1,6 @@
import torch
from torch import optim
from typing import Any, Dict, Iterable, Optional, Tuple
from typing import Any, Iterable
from pydantic import BaseModel, ConfigDict
from .base import BaseOptimizer
@ -10,12 +10,12 @@ class AdamParams(BaseModel):
model_config = ConfigDict(frozen=True)
lr: float = 1e-3 # Learning rate
betas: Tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages
betas: tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages
eps: float = 1e-8 # Term added to denominator for numerical stability
weight_decay: float = 0.0 # L2 regularization
amsgrad: bool = False # Whether to use the AMSGrad variant
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.Adam`."""
return self.model_dump()
@ -25,7 +25,7 @@ class AdamOptimizer(BaseOptimizer):
Wrapper around torch.optim.Adam.
"""
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams):
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams) -> None:
"""
Initializes the Adam optimizer with given parameters.

@ -1,6 +1,6 @@
import torch
from torch import optim
from typing import Any, Dict, Iterable, Optional, Tuple
from typing import Any, Iterable
from pydantic import BaseModel, ConfigDict
from .base import BaseOptimizer
@ -10,12 +10,12 @@ class AdamWParams(BaseModel):
model_config = ConfigDict(frozen=True)
lr: float = 1e-3 # Learning rate
betas: Tuple[float, ...] = (0.9, 0.999) # Adam coefficients
betas: tuple[float, ...] = (0.9, 0.999) # Adam coefficients
eps: float = 1e-8 # Numerical stability
weight_decay: float = 1e-2 # L2 penalty (AdamW uses decoupled weight decay)
amsgrad: bool = False # Whether to use the AMSGrad variant
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.AdamW`."""
return self.model_dump()
@ -25,7 +25,7 @@ class AdamWOptimizer(BaseOptimizer):
Wrapper around torch.optim.AdamW.
"""
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams):
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams) -> None:
"""
Initializes the AdamW optimizer with given parameters.

@ -1,15 +1,15 @@
import torch
import torch.optim as optim
from torch import optim
from pydantic import BaseModel
from typing import Any, Iterable, Optional
from typing import Any, Iterable
class BaseOptimizer:
"""Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel):
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel) -> None:
super().__init__()
self.optim: Optional[optim.Optimizer] = None
self.optim: optim.Optimizer | None = None
def zero_grad(self, set_to_none: bool = True) -> None:
@ -25,12 +25,12 @@ class BaseOptimizer:
self.optim.zero_grad(set_to_none=set_to_none)
def step(self, closure: Optional[Any] = None) -> Any:
def step(self, closure: Any | None = None) -> Any:
"""
Performs a single optimization step (parameter update).
Args:
closure (Optional[Callable]): A closure that reevaluates the model and returns the loss.
closure (Any | None): A closure that reevaluates the model and returns the loss.
This is required for optimizers like LBFGS that need multiple forward passes.
Returns:

@ -1,6 +1,6 @@
import torch
from torch import optim
from typing import Any, Dict, Iterable, Optional
from typing import Any, Iterable
from pydantic import BaseModel, ConfigDict
from .base import BaseOptimizer
@ -16,7 +16,7 @@ class SGDParams(BaseModel):
weight_decay: float = 0.0 # L2 penalty
nesterov: bool = False # Enables Nesterov momentum
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.SGD`."""
return self.model_dump()
@ -26,7 +26,7 @@ class SGDOptimizer(BaseOptimizer):
Wrapper around torch.optim.SGD.
"""
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams):
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams) -> None:
"""
Initializes the SGD optimizer with given parameters.

@ -1,5 +1,4 @@
import torch.optim.lr_scheduler as lr_scheduler
from typing import Dict, Final, Tuple, Type, List, Any, Union
from typing import Final, Type, Any
from pydantic import BaseModel
from .base import BaseScheduler
@ -17,7 +16,7 @@ __all__ = [
class SchedulerRegistry:
"""Registry for learning rate schedulers and their parameter classes with case-insensitive lookup."""
__SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
__SCHEDULERS: Final[dict[str, dict[str, Type[Any]]]] = {
"Step": {
"class": StepLRScheduler,
"params": StepLRParams,
@ -37,7 +36,7 @@ class SchedulerRegistry:
}
@classmethod
def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
"""
Private method to retrieve the scheduler entry from the registry using case-insensitive lookup.
@ -45,7 +44,7 @@ class SchedulerRegistry:
name (str): The name of the scheduler.
Returns:
Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
Raises:
ValueError: If the scheduler is not found.
@ -68,7 +67,7 @@ class SchedulerRegistry:
name (str): Name of the scheduler.
Returns:
Type[BaseScheduler]: The scheduler class.
Type(BaseScheduler): The scheduler class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@ -82,17 +81,17 @@ class SchedulerRegistry:
name (str): Name of the scheduler.
Returns:
Type[BaseModel]: The scheduler parameter class.
Type(BaseModel): The scheduler parameter class.
"""
entry = cls.__get_entry(name)
return entry["params"]
@classmethod
def get_available_schedulers(cls) -> Tuple[str, ...]:
def get_available_schedulers(cls) -> tuple[str, ...]:
"""
Returns a tuple of available scheduler names in their original case.
Returns:
Tuple[str]: Tuple of available scheduler names.
Tuple(str): Tuple of available scheduler names.
"""
return tuple(cls.__SCHEDULERS.keys())

@ -1,6 +1,5 @@
import torch.optim as optim
from torch import optim
from pydantic import BaseModel
from typing import List, Optional
class BaseScheduler:
@ -9,8 +8,8 @@ class BaseScheduler:
Wraps a PyTorch LR scheduler and provides a unified interface.
"""
def __init__(self, optimizer: optim.Optimizer, params: BaseModel):
self.scheduler: Optional[optim.lr_scheduler.LRScheduler] = None
def __init__(self, optimizer: optim.Optimizer, params: BaseModel) -> None:
self.scheduler: optim.lr_scheduler.LRScheduler | None = None
def step(self) -> None:
"""
@ -20,7 +19,7 @@ class BaseScheduler:
if self.scheduler is not None:
self.scheduler.step()
def get_last_lr(self) -> List[float]:
def get_last_lr(self) -> list[float]:
"""
Returns the most recent learning rate(s).
"""

@ -1,10 +1,9 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from .base import BaseScheduler
from typing import Any
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from pydantic import BaseModel, ConfigDict
class CosineAnnealingLRParams(BaseModel):
@ -16,7 +15,7 @@ class CosineAnnealingLRParams(BaseModel):
last_epoch: int = -1
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`."""
return self.model_dump()
@ -26,7 +25,7 @@ class CosineAnnealingLRScheduler(BaseScheduler):
Wrapper around torch.optim.lr_scheduler.CosineAnnealingLR.
"""
def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams):
def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams) -> None:
"""
Args:
optimizer (Optimizer): Wrapped optimizer.

@ -1,9 +1,9 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
from .base import BaseScheduler
from typing import Any
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR
from .base import BaseScheduler
from pydantic import BaseModel, ConfigDict
class ExponentialLRParams(BaseModel):
@ -13,7 +13,7 @@ class ExponentialLRParams(BaseModel):
gamma: float = 0.95 # Multiplicative factor of learning rate decay
last_epoch: int = -1
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`."""
return self.model_dump()
@ -23,7 +23,7 @@ class ExponentialLRScheduler(BaseScheduler):
Wrapper around torch.optim.lr_scheduler.ExponentialLR.
"""
def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams):
def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams) -> None:
"""
Args:
optimizer (Optimizer): Wrapped optimizer.

@ -1,20 +1,20 @@
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict
from .base import BaseScheduler
from typing import Any
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from .base import BaseScheduler
from pydantic import BaseModel, ConfigDict
class MultiStepLRParams(BaseModel):
"""Configuration for `torch.optim.lr_scheduler.MultiStepLR`."""
model_config = ConfigDict(frozen=True)
milestones: Tuple[int, ...] = (30, 80) # List of epoch indices for LR decay
milestones: tuple[int, ...] = (30, 80) # List of epoch indices for LR decay
gamma: float = 0.1 # Multiplicative factor of learning rate decay
last_epoch: int = -1
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`."""
return self.model_dump()
@ -24,7 +24,7 @@ class MultiStepLRScheduler(BaseScheduler):
Wrapper around torch.optim.lr_scheduler.MultiStepLR.
"""
def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams):
def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams) -> None:
"""
Args:
optimizer (Optimizer): Wrapped optimizer.

@ -1,9 +1,9 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
from .base import BaseScheduler
from typing import Any
from torch import optim
from torch.optim.lr_scheduler import StepLR
from .base import BaseScheduler
from pydantic import BaseModel, ConfigDict
class StepLRParams(BaseModel):
@ -14,7 +14,7 @@ class StepLRParams(BaseModel):
gamma: float = 0.1 # Multiplicative factor of learning rate decay
last_epoch: int = -1
def asdict(self) -> Dict[str, Any]:
def asdict(self) -> dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`."""
return self.model_dump()
@ -25,7 +25,7 @@ class StepLRScheduler(BaseScheduler):
Wrapper around torch.optim.lr_scheduler.StepLR.
"""
def __init__(self, optimizer: optim.Optimizer, params: StepLRParams):
def __init__(self, optimizer: optim.Optimizer, params: StepLRParams) -> None:
"""
Args:
optimizer (Optimizer): Wrapped optimizer.

@ -19,15 +19,14 @@ from torch.utils.data import DataLoader
import fastremap
import fill_voids
from skimage import morphology
# from skimage import morphology
from skimage.segmentation import find_boundaries
from scipy.special import expit
from scipy.ndimage import mean, find_objects
from monai.data.dataset import Dataset
from monai.transforms import * # type: ignore
from monai.transforms.compose import Compose
from monai.inferers.utils import sliding_window_inference
from monai.metrics.cumulative_average import CumulativeAverage
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
@ -42,16 +41,16 @@ from itertools import chain
from pprint import pformat
from tabulate import tabulate
from typing import Any, Dict, Literal, Optional, Tuple, List, Union
from typing import Any, Literal
from tqdm import tqdm
import wandb
from config import Config
from core.models import *
from core.losses import *
from core.optimizers import *
from core.schedulers import *
from core.models import ModelRegistry
from core.losses import CriterionRegistry
from core.optimizers import OptimizerRegistry
from core.schedulers import SchedulerRegistry
from core.utils import (
compute_batch_segmentation_tp_fp_fn,
compute_f1_score,
@ -78,30 +77,30 @@ class CellSegmentator:
else None
)
self._train_dataloader: Optional[DataLoader] = None
self._valid_dataloader: Optional[DataLoader] = None
self._test_dataloader: Optional[DataLoader] = None
self._predict_dataloader: Optional[DataLoader] = None
self._train_dataloader: DataLoader | None = None
self._valid_dataloader: DataLoader | None = None
self._test_dataloader: DataLoader | None = None
self._predict_dataloader: DataLoader | None = None
self._best_weights = None
def create_dataloaders(
self,
train_transforms: Optional[Compose] = None,
valid_transforms: Optional[Compose] = None,
test_transforms: Optional[Compose] = None,
predict_transforms: Optional[Compose] = None
train_transforms: Compose | None = None,
valid_transforms: Compose | None = None,
test_transforms: Compose | None = None,
predict_transforms: Compose | None = None
) -> None:
"""
Creates train, validation, test, and prediction dataloaders based on dataset configuration
and provided transforms.
Args:
train_transforms (Optional[Compose]): Transformations for training data.
valid_transforms (Optional[Compose]): Transformations for validation data.
test_transforms (Optional[Compose]): Transformations for testing data.
predict_transforms (Optional[Compose]): Transformations for prediction data.
train_transforms (Compose | None): Transformations for training data.
valid_transforms (Compose | None): Transformations for validation data.
test_transforms (Compose | None): Transformations for testing data.
predict_transforms (Compose | None): Transformations for prediction data.
Raises:
ValueError: If required transforms are missing.
@ -257,7 +256,7 @@ class CellSegmentator:
def print_data_info(
self,
loader_type: Literal["train", "valid", "test", "predict"],
index: Optional[int] = None
index: int | None = None
) -> None:
"""
Prints statistics for a single sample from the specified dataloader.
@ -267,7 +266,7 @@ class CellSegmentator:
index: The sample index; if None, a random index is selected.
"""
# Retrieve the dataloader attribute, e.g., self._train_dataloader
loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None)
loader: DataLoader | None = getattr(self, f"_{loader_type}_dataloader", None)
if loader is None:
logger.error(f"Dataloader '{loader_type}' is not initialized.")
return
@ -326,8 +325,8 @@ class CellSegmentator:
lines.append("=" * 40)
# Output via logger
for l in lines:
logger.info(l)
for line in lines:
logger.info(line)
def train(self, save_results: bool = True, only_masks: bool = False) -> None:
@ -661,16 +660,16 @@ class CellSegmentator:
logger.info(f"├─ Validation frequency: {training.val_freq}")
if training.is_split:
logger.info(f"├─ Using pre-split directories:")
logger.info( "├─ Using pre-split directories:")
logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}")
logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}")
logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}")
else:
logger.info(f"├─ Using unified dataset with splits:")
logger.info(f"│ ├─ All data dir: {training.split.all_data_dir}")
logger.info( "├─ Using unified dataset with splits:")
logger.info( "│ ├─ All data dir: {training.split.all_data_dir}")
logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}")
logger.info(f"└─ Dataset split:")
logger.info( "└─ Dataset split:")
logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}")
logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}")
logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}")
@ -703,12 +702,12 @@ class CellSegmentator:
logger.info("===================================")
def __set_seed(self, seed: Optional[int]) -> None:
def __set_seed(self, seed: int | None) -> None:
"""
Sets the random seed for reproducibility across Python, NumPy, and PyTorch.
Args:
seed (Optional[int]): Seed value. If None, no seeding is performed.
seed (int | None): Seed value. If None, no seeding is performed.
"""
if seed is not None:
random.seed(seed)
@ -724,9 +723,9 @@ class CellSegmentator:
def __get_dataset(
self,
images_dir: str,
masks_dir: Optional[str],
masks_dir: str | None,
transforms: Compose,
size: Union[int, float],
size: int | float,
offset: int,
shuffle: bool
) -> Dataset:
@ -735,9 +734,9 @@ class CellSegmentator:
Args:
images_dir (str): Path to directory or glob pattern for input images.
masks_dir (Optional[str]): Path to directory or glob pattern for masks.
masks_dir (str | None): Path to directory or glob pattern for masks.
transforms (Compose): Transformations to apply to each image or pair.
size (Union[int, float]): Either an integer or a fraction of the dataset.
size (int | float): Either an integer or a fraction of the dataset.
offset (int): Number of images to skip from the start.
shuffle (bool): Whether to shuffle the dataset before slicing.
@ -806,12 +805,12 @@ class CellSegmentator:
return Dataset(data, transforms)
def __print_with_logging(self, metrics: Dict[str, Union[float, np.ndarray]], step: int) -> None:
def __print_with_logging(self, metrics: dict[str, float | np.ndarray], step: int) -> None:
"""
Print metrics in a tabular format and log to W&B.
Args:
metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
metrics (dict[str, float | np.ndarray]): Mapping from metric names
to either a float or an N-D numpy array.
step (int): epoch index.
"""
@ -846,14 +845,14 @@ class CellSegmentator:
def __save_metrics_to_csv(
self,
metrics: Dict[str, Union[float, np.ndarray]],
metrics: dict[str, float | np.ndarray],
output_path: str
) -> None:
"""
Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'.
Args:
metrics (Dict[str, Union[float, np.ndarray]]):
metrics (dict[str, float | np.ndarray]):
Mapping from metric names to scalar values or numpy arrays.
output_path (str):
Path to the output CSV file.
@ -874,22 +873,22 @@ class CellSegmentator:
def __run_epoch(self,
mode: Literal["train", "valid", "test"],
epoch: Optional[int] = None,
epoch: int | None = None,
save_results: bool = True,
only_masks: bool = False
) -> Dict[str, Union[float, np.ndarray]]:
) -> dict[str, float | np.ndarray]:
"""
Execute one epoch of training, validation, or testing.
Args:
mode (str): One of 'train', 'valid', or 'test'.
epoch (int, optional): Current epoch number for logging.
epoch (int | None): Current epoch number for logging.
save_results (bool): If True, the predicted masks and test metrics will be saved.
only_masks (bool): If True and save_results is True, only raw predicted masks are saved,
without visualization overlays.
Returns:
Dict[str, Union[float, np.ndarray]]: Metrics for valid/test.
dict[str, float | np.ndarray]: Metrics for valid/test.
"""
# Ensure required components are available
if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
@ -988,7 +987,7 @@ class CellSegmentator:
if self._criterion is not None:
# Collect loss metrics
epoch_metrics: Dict[str, Union[float, np.ndarray]] = {
epoch_metrics: dict[str, float | np.ndarray] = {
f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()
}
# Reset internal loss metrics accumulator
@ -1051,17 +1050,17 @@ class CellSegmentator:
def __post_process_predictions(
self,
raw_outputs: torch.Tensor,
ground_truth: Optional[torch.Tensor] = None
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
ground_truth: torch.Tensor | None = None
) -> tuple[np.ndarray, np.ndarray | None]:
"""
Post-process raw network outputs to extract instance segmentation masks.
Args:
raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W).
ground_truth (torch.Tensor): Ground truth masks of shape (B, С, H, W).
ground_truth (torch.Tensor | None): Ground truth masks of shape (B, C, H, W).
Returns:
Tuple[np.ndarray, Optional[np.ndarray]]:
tuple[np.ndarray, np.ndarray | None]:
- instance_masks: Instance-wise masks array of shape (B, C, H, W).
- labels_np: Converted ground truth of shape (B, C, H, W) or None if
ground_truth was not provided.
@ -1097,8 +1096,8 @@ class CellSegmentator:
ground_truth_masks: np.ndarray,
iou_threshold: float = 0.5,
return_error_masks: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray,
Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray,
np.ndarray | None, np.ndarray | None, np.ndarray | None]:
"""
Compute batch-wise true positives, false positives, and false negatives
for instance segmentation, using a configurable IoU threshold.
@ -1111,7 +1110,7 @@ class CellSegmentator:
return_error_masks (bool): Whether to also return binary error masks.
Returns:
Tuple(np.ndarray, np.ndarray, np.ndarray,
tuple[np.ndarray, np.ndarray, np.ndarray,
np.ndarray | None, np.ndarray | None, np.ndarray | None]:
- tp: True positives per batch and class, shape (B, C)
- fp: False positives per batch and class, shape (B, C)
@ -1143,7 +1142,7 @@ class CellSegmentator:
false_positives: np.ndarray,
false_negatives: np.ndarray,
reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro"
) -> Union[float, np.ndarray]:
) -> float | np.ndarray:
"""
Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes.
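For reference, an illustrative sketch (not the project's code) of the "micro" and "macro" reductions named here, given tp/fp/fn count arrays of shape (B, C):

import numpy as np

def f1_micro_macro(tp: np.ndarray, fp: np.ndarray, fn: np.ndarray) -> tuple[float, float]:
    # "micro": pool every count first, then compute one F1 over the pool
    micro = 2 * tp.sum() / max(2 * tp.sum() + fp.sum() + fn.sum(), 1)
    # "macro": per-class F1 over pooled images, then an unweighted class mean
    tp_c, fp_c, fn_c = tp.sum(axis=0), fp.sum(axis=0), fn.sum(axis=0)
    per_class = 2 * tp_c / np.maximum(2 * tp_c + fp_c + fn_c, 1)
    return float(micro), float(per_class.mean())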
@ -1266,7 +1265,7 @@ class CellSegmentator:
false_positives: np.ndarray,
false_negatives: np.ndarray,
reduction: Literal["micro", "macro", "weighted", "imagewise", 'per_class', "none"] = "micro"
) -> Union[float, np.ndarray]:
) -> float | np.ndarray:
"""
Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes.
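As a reference point, the per-threshold average-precision definition commonly used for instance segmentation is AP = TP / (TP + FP + FN); whether every reduction mode here uses exactly this form is an assumption of the sketch below:

import numpy as np

def average_precision(tp: np.ndarray, fp: np.ndarray, fn: np.ndarray) -> float:
    tp_s, fp_s, fn_s = tp.sum(), fp.sum(), fn.sum()
    return float(tp_s / max(tp_s + fp_s + fn_s, 1))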
@ -1399,23 +1398,23 @@ class CellSegmentator:
def __save_prediction_masks(
self,
sample: Dict[str, Any],
predicted_mask: Union[np.ndarray, torch.Tensor],
sample: dict[str, Any],
predicted_mask: np.ndarray | torch.Tensor,
start_index: int = 0,
only_masks: bool = False,
masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
) -> None:
"""
Save multi-channel predicted masks as TIFFs and
corresponding visualizations as PNGs in separate folders.
Args:
sample (Dict[str, Any]): Batch sample from MONAI
sample (dict[str, Any]): Batch sample from MONAI
LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict').
predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
predicted_mask (np.ndarray | torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
start_index (int): Starting index for naming when metadata is missing.
only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations.
masks (Tuple[np.ndarray, np.ndarray, np.ndarray] | None):
masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None):
A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None.
"""
# Base directories (created once per call)
@ -1428,14 +1427,14 @@ class CellSegmentator:
os.makedirs(evaluate_dir, exist_ok=True)
# Convert tensors to numpy if necessary
def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
def to_numpy(x: np.ndarray | torch.Tensor) -> np.ndarray:
return x.cpu().numpy() if isinstance(x, torch.Tensor) else x
pred_array = to_numpy(predicted_mask).astype(np.uint16)
# Handle batch dimension
for idx in range(pred_array.shape[0]):
batch_sample: Dict[str, Any] = {}
batch_sample: dict[str, Any] = {}
# copy per-sample image and meta
img = to_numpy(sample["image"])
if img.ndim == 4:
@ -1467,21 +1466,21 @@ class CellSegmentator:
def __save_single_prediction_mask(
self,
sample: Dict[str, Any],
sample: dict[str, Any],
pred_array: np.ndarray,
start_index: int,
masks_dir: str,
plots_dir: str,
evaluate_dir: str,
only_masks: bool = False,
masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None
masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
) -> None:
"""
Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations.
Assumes output directories already exist.
Args:
sample (Dict[str, Any]): Dictionary containing 'image', 'mask',
sample (dict[str, Any]): Dictionary containing 'image', 'mask',
and optional 'image_meta_dict' for metadata.
pred_array (np.ndarray): Predicted mask array of shape (C,H,W).
start_index (int): Base index for generating filenames when metadata is missing.
@ -1489,7 +1488,7 @@ class CellSegmentator:
plots_dir (str): Directory for saving PNG visualizations.
evaluate_dir (str): Directory for saving PNG visualizations of evaluation results.
only_masks (bool): If True, saves only TIFF mask files; skips PNG plots.
masks (Tuple[np.ndarray, np.ndarray, np.ndarray], optional): A tuple of
masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None): A tuple of
true-positive, false-positive, and false-negative mask arrays,
each of shape (C,H,W). Defaults to None.
"""
@ -1510,7 +1509,7 @@ class CellSegmentator:
"Expected 2D (H,W) or 3D (C,H,W)."
)
true_mask_array: Optional[np.ndarray] = sample.get("mask")
true_mask_array: np.ndarray | None = sample.get("mask")
if isinstance(true_mask_array, np.ndarray):
if true_mask_array.ndim == 2:
true_mask_array = np.expand_dims(true_mask_array, axis=0)
@ -1562,7 +1561,7 @@ class CellSegmentator:
file_path: str,
image_data: np.ndarray,
predicted_mask: np.ndarray,
true_mask: Optional[np.ndarray] = None,
true_mask: np.ndarray | None = None,
) -> None:
"""
Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided.
@ -1572,7 +1571,7 @@ class CellSegmentator:
image_data (np.ndarray): The original input image array, expected shape (C, H, W).
predicted_mask (np.ndarray): The predicted mask array, shape (H, W),
depending on the task.
true_mask (Optional[np.ndarray], optional): The ground-truth mask array.
true_mask (np.ndarray | None): The ground-truth mask array.
If provided, an additional row with true mask and overlap visualization
will be added to the plot. Default is None.
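A hedged matplotlib sketch of the 1x3 / 2x3 layout described here; the real method draws boundary contours, which this sketch approximates with a transparent overlay:

import numpy as np
import matplotlib.pyplot as plt

def save_grid(file_path: str, image: np.ndarray, pred: np.ndarray,
              true: np.ndarray | None = None) -> None:
    rows = 1 if true is None else 2
    fig, axes = plt.subplots(rows, 3, figsize=(15, 5 * rows), squeeze=False)
    img = np.moveaxis(image, 0, -1) if image.ndim == 3 else image   # (C, H, W) -> (H, W, C)
    for r, mask in enumerate([pred] if true is None else [pred, true]):
        axes[r][0].imshow(img, cmap="gray" if img.ndim == 2 else None)
        axes[r][1].imshow(mask)
        axes[r][2].imshow(img, cmap="gray" if img.ndim == 2 else None)
        axes[r][2].imshow(np.ma.masked_where(mask == 0, mask), alpha=0.5)   # overlay
        for ax in axes[r]:
            ax.axis("off")
    fig.savefig(file_path, bbox_inches="tight")
    plt.close(fig)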
@ -1603,7 +1602,7 @@ class CellSegmentator:
img: np.ndarray,
mask: np.ndarray,
contour_color: str,
titles: Tuple[str, ...]
titles: tuple[str, ...]
):
"""
Plot a row of three panels: original image, mask, and mask boundaries on image.
@ -1618,7 +1617,8 @@ class CellSegmentator:
# Panel 1: Original image
ax0, ax1, ax2 = axes
ax0.imshow(img, cmap='gray' if img.ndim == 2 else None)
ax0.set_title(titles[0]); ax0.axis('off')
ax0.set_title(titles[0])
ax0.axis('off')
# Compute boundaries once
boundaries = find_boundaries(mask, mode='thick')
@ -1793,7 +1793,8 @@ class CellSegmentator:
# Get coordinates of all non-zero pixels in the padded mask
y, x = torch.nonzero(masks_padded, as_tuple=True)
y = y.int(); x = x.int() # ensure integer type
y = y.int()
x = x.int() # ensure integer type
# Generate 8-connected neighbors (including center) via broadcasted offsets
offsets = torch.tensor([
@ -1830,9 +1831,12 @@ class CellSegmentator:
], dtype=np.int16)
# Compute centers (pixel indices) and extents via the provided helper
centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr)
centers, ext = self.__get_mask_centers_and_extents(
mask_channel, slices_arr
)
# Move centers to GPU and shift by +1 for padding
meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding
# (M, 2); +1 for padding
meds_p = torch.from_numpy(centers).to(self._device).long() + 1
# Determine number of diffusion iterations
n_iter = 2 * ext.max()
@ -1865,7 +1869,7 @@ class CellSegmentator:
def __get_mask_centers_and_extents(
label_map: np.ndarray,
slices_arr: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the centroids and extents of labeled regions in a 2D mask array.
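A plain-NumPy/SciPy sketch of one plausible reading of "centroids and extents" (per-label pixel centroid and bounding-box size); the real helper works from precomputed slices, so this is illustrative only:

import numpy as np
from scipy import ndimage

def centers_and_extents(label_map: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    labels = np.arange(1, int(label_map.max()) + 1)
    centers = np.asarray(ndimage.center_of_mass(label_map > 0, label_map, labels))
    extents = np.array([
        max(sl[0].stop - sl[0].start, sl[1].stop - sl[1].start) if sl is not None else 0
        for sl in ndimage.find_objects(label_map)
    ])
    return np.round(centers).astype(np.int64), extents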
@ -1923,7 +1927,7 @@ class CellSegmentator:
neighbor_indices: torch.Tensor,
center_indices: torch.Tensor,
valid_neighbor_mask: torch.Tensor,
output_shape: Tuple[int, int],
output_shape: tuple[int, int],
num_iterations: int = 200
) -> np.ndarray:
"""
@ -1933,7 +1937,7 @@ class CellSegmentator:
neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel.
center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers.
valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid.
output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W).
output_shape (tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W).
num_iterations (int, optional): Number of diffusion iterations. Defaults to 200.
Returns:
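A rough sketch of the center-seeded diffusion loop the Args describe (cellpose-style heat diffusion); the exact update rule and the assumption that offset index 4 is the central (0, 0) offset are mine, not confirmed by the source:

import torch

def diffuse_from_centers(
    neighbor_indices: torch.Tensor,      # (2, 9, N), long
    center_indices: torch.Tensor,        # (2, M), long
    valid_neighbor_mask: torch.Tensor,   # (9, N), bool
    output_shape: tuple[int, int],
    num_iterations: int = 200,
) -> torch.Tensor:
    heat = torch.zeros(output_shape, dtype=torch.float64)
    for _ in range(num_iterations):
        heat[center_indices[0], center_indices[1]] += 1.0        # re-seed the centers
        vals = heat[neighbor_indices[0], neighbor_indices[1]]    # (9, N) neighbour values
        vals = vals * valid_neighbor_mask                        # drop out-of-mask neighbours
        heat[neighbor_indices[0, 4], neighbor_indices[1, 4]] = vals.mean(dim=0)
    return heat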
@ -2242,7 +2246,7 @@ class CellSegmentator:
flow_field: np.ndarray,
initial_coords: np.ndarray,
num_iters: int = 200
) -> Union[np.ndarray, torch.Tensor]:
) -> np.ndarray | torch.Tensor:
"""
Trace pixel positions through a flow field via iterative interpolation.
@ -2252,7 +2256,7 @@ class CellSegmentator:
num_iters (int): Number of integration steps.
Returns:
np.ndarray or torch.Tensor: Final (y, x) positions of each point.
(np.ndarray | torch.Tensor): Final (y, x) positions of each point.
"""
dims = 2
# Extract spatial dimensions
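Conceptually this is Euler integration through the flow field; a hedged nearest-neighbour sketch (the real method interpolates the field and may run on GPU):

import numpy as np

def follow_flows(flow_field: np.ndarray, coords: np.ndarray, num_iters: int = 200) -> np.ndarray:
    # flow_field: (2, H, W) per-pixel (dy, dx); coords: (2, N) starting positions
    H, W = flow_field.shape[1:]
    pos = coords.astype(np.float64).copy()
    for _ in range(num_iters):
        yi = np.clip(np.round(pos[0]).astype(int), 0, H - 1)   # nearest-neighbour lookup
        xi = np.clip(np.round(pos[1]).astype(int), 0, W - 1)
        pos[0] = np.clip(pos[0] + flow_field[0, yi, xi], 0, H - 1)
        pos[1] = np.clip(pos[1] + flow_field[1, yi, xi], 0, W - 1)
    return pos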
@ -2383,7 +2387,7 @@ class CellSegmentator:
self,
pixel_positions: torch.Tensor,
valid_indices: np.ndarray,
original_shape: Tuple[int, ...],
original_shape: tuple[int, ...],
pad_radius: int = 20,
max_size_fraction: float = 0.4
) -> np.ndarray:
@ -2534,7 +2538,7 @@ class CellSegmentator:
input_tensor: Tensor,
kernel_size: int = 5,
axis: int = 1,
output_tensor: Optional[Tensor] = None
output_tensor: Tensor | None = None
) -> Tensor:
"""
Memory-efficient 1D max pooling along a specified axis using in-place updates.
@ -2547,7 +2551,7 @@ class CellSegmentator:
input_tensor (Tensor): Source tensor for pooling.
kernel_size (int): Size of the pooling window (must be odd and >= 3).
axis (int): Axis along which to compute 1D max pooling.
output_tensor (Optional[Tensor]): Tensor to store the result.
output_tensor (Tensor | None): Tensor to store the result.
If None, a clone of input_tensor is used.
Returns:
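The idea can be sketched as repeated shifted element-wise maxima, so only one extra buffer is needed; how the real method handles the borders is an assumption:

import torch

def max_pool1d_inplace(x: torch.Tensor, kernel_size: int = 5, axis: int = 1) -> torch.Tensor:
    assert kernel_size >= 3 and kernel_size % 2 == 1
    out = x.clone()
    n = x.size(axis)
    for shift in range(1, kernel_size // 2 + 1):
        lead = out.narrow(axis, 0, n - shift)      # out[..., :-shift] along `axis`
        lag = out.narrow(axis, shift, n - shift)   # out[..., shift:] along `axis`
        lead.copy_(torch.maximum(lead, x.narrow(axis, shift, n - shift)))
        lag.copy_(torch.maximum(lag, x.narrow(axis, 0, n - shift)))
    return out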
@ -2691,7 +2695,7 @@ class CellSegmentator:
self,
mask: np.ndarray,
flow_network: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute mean squared error between network-predicted flows and flows derived from masks.
@ -2700,7 +2704,7 @@ class CellSegmentator:
flow_network (np.ndarray): Network predicted flows of shape [axis, ...].
Returns:
Tuple[np.ndarray, np.ndarray]:
tuple[np.ndarray, np.ndarray]:
- flow_errors: 1D array (length = max label) of mean squared error per label.
- computed_flows: Array of flows derived from the mask, same shape as flow_network.
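A sketch of the per-label error described above (mean squared flow difference aggregated per instance label); normalisation details and the second return value are left out as assumptions:

import numpy as np

def flow_mse_per_label(mask: np.ndarray, predicted_flows: np.ndarray, mask_flows: np.ndarray) -> np.ndarray:
    squared = ((predicted_flows - mask_flows) ** 2).mean(axis=0)   # (H, W) per-pixel error
    num_labels = int(mask.max())
    errors = np.zeros(num_labels, dtype=np.float64)
    for label in range(1, num_labels + 1):
        region = mask == label
        if region.any():
            errors[label - 1] = squared[region].mean()
    return errors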

@ -8,7 +8,7 @@ from numpy.typing import NDArray
from numba import jit
from skimage import segmentation
from scipy.optimize import linear_sum_assignment
from typing import Dict, List, Tuple, Any, Union
from typing import Any
from core.logger import get_logger
@ -27,7 +27,7 @@ def compute_f1_score(
true_positives: int,
false_positives: int,
false_negatives: int
) -> Tuple[float, float, float]:
) -> tuple[float, float, float]:
"""
Computes the precision, recall, and F1-score given the numbers of
true positives, false positives, and false negatives.
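The three returned values follow the standard definitions; a sketch with zero-division guards (the guards themselves are an assumption of the sketch):

def f1_from_counts(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1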
@ -76,7 +76,7 @@ def compute_confusion_matrix(
ground_truth_mask: np.ndarray,
predicted_mask: np.ndarray,
iou_threshold: float = 0.5
) -> Tuple[int, int, int]:
) -> tuple[int, int, int]:
"""
Computes the confusion matrix elements (true positives, false positives, false negatives)
for a single image given the ground truth and predicted masks.
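A minimal sketch of IoU-threshold matching between instances; since this module imports linear_sum_assignment, a Hungarian matching of this form is plausible, though the exact tie-breaking in the real function is an assumption:

import numpy as np
from scipy.optimize import linear_sum_assignment

def confusion_counts(iou_matrix: np.ndarray, iou_threshold: float = 0.5) -> tuple[int, int, int]:
    # iou_matrix: (num_gt, num_pred) pairwise IoU between instance masks
    gt_idx, pred_idx = linear_sum_assignment(-iou_matrix)      # maximise total IoU
    tp = int((iou_matrix[gt_idx, pred_idx] >= iou_threshold).sum())
    fp = iou_matrix.shape[1] - tp                              # unmatched predictions
    fn = iou_matrix.shape[0] - tp                              # unmatched ground-truth objects
    return tp, fp, fn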
@ -114,7 +114,7 @@ def compute_segmentation_tp_fp_fn(
iou_threshold: float = 0.5,
return_error_masks: bool = False,
remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""
Computes TP, FP and FN for segmentation on a single image.
@ -176,7 +176,7 @@ def compute_segmentation_tp_fp_fn(
false_positive_mask_list.append(results.get('fp_mask')) # type: ignore
false_negative_mask_list.append(results.get('fn_mask')) # type: ignore
output: Dict[str, np.ndarray] = {
output: dict[str, np.ndarray] = {
'tp': np.array(true_positive_list),
'fp': np.array(false_positive_list),
'fn': np.array(false_negative_list)
@ -194,7 +194,7 @@ def compute_segmentation_f1_metrics(
iou_threshold: float = 0.5,
return_error_masks: bool = False,
remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""
Computes F1 metrics (precision, recall, F1-score) for segmentation on a single image.
@ -240,7 +240,7 @@ def compute_segmentation_f1_metrics(
recall_list.append(recall)
f1_score_list.append(f1_score)
output: Dict[str, np.ndarray] = {
output: dict[str, np.ndarray] = {
'precision': np.array(precision_list),
'recall': np.array(recall_list),
'f1_score': np.array(f1_score_list),
@ -255,7 +255,7 @@ def compute_segmentation_average_precision_metrics(
iou_threshold: float = 0.5,
return_error_masks: bool = False,
remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""
Computes the average precision (AP) for segmentation on a single image.
@ -298,7 +298,7 @@ def compute_segmentation_average_precision_metrics(
)
avg_precision_list.append(avg_precision)
output: Dict[str, np.ndarray] = {
output: dict[str, np.ndarray] = {
'avg_precision': np.array(avg_precision_list)
}
output.update(results)
@ -311,7 +311,7 @@ def compute_batch_segmentation_tp_fp_fn(
iou_threshold: float = 0.5,
return_error_masks: bool = False,
remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""
Computes segmentation TP, FP and FN for a batch of images.
@ -361,7 +361,7 @@ def compute_batch_segmentation_tp_fp_fn(
fp_mask_list.append(result.get('fp_mask')) # type: ignore
fn_mask_list.append(result.get('fn_mask')) # type: ignore
output: Dict[str, np.ndarray] = {
output: dict[str, np.ndarray] = {
'tp': np.stack(tp_list, axis=0),
'fp': np.stack(fp_list, axis=0),
'fn': np.stack(fn_list, axis=0)
@ -379,7 +379,7 @@ def compute_batch_segmentation_f1_metrics(
iou_threshold: float = 0.5,
return_error_masks: bool = False,
remove_boundary_objects: bool = True
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""
Computes segmentation F1 metrics for a batch of images.
@ -435,7 +435,7 @@ def compute_batch_segmentation_f1_metrics(
fp_mask_list.append(result.get('fp_mask')) # type: ignore
fn_mask_list.append(result.get('fn_mask')) # type: ignore
output: Dict[str, np.ndarray] = {
output: dict[str, np.ndarray] = {
'precision': np.stack(precision_list, axis=0),
'recall': np.stack(recall_list, axis=0),
'f1_score': np.stack(f1_score_list, axis=0),
@ -456,7 +456,7 @@ def compute_batch_segmentation_average_precision_metrics(
iou_threshold: float = 0.5,
return_error_masks: bool = False,
remove_boundary_objects: bool = True
) -> Dict[str, NDArray]:
) -> dict[str, np.ndarray]:
"""
Computes segmentation average precision metrics for a batch of images.
@ -508,7 +508,7 @@ def compute_batch_segmentation_average_precision_metrics(
fp_mask_list.append(result.get('fp_mask')) # type: ignore
fn_mask_list.append(result.get('fn_mask')) # type: ignore
output: Dict[str, NDArray] = {
output: dict[str, np.ndarray] = {
'avg_precision': np.stack(avg_precision_list, axis=0),
'tp': np.stack(tp_list, axis=0),
'fp': np.stack(fp_list, axis=0),
@ -555,7 +555,7 @@ def _process_instance_matching(
iou_threshold: float = 0.5,
return_masks: bool = False,
without_boundary_objects: bool = True
) -> Dict[str, Union[int, NDArray[np.uint8]]]:
) -> dict[str, int | NDArray[np.uint8]]:
"""
Processes instance matching on a full image by performing the following steps:
- Removes objects that touch the image boundary and reindexes the masks.
@ -597,8 +597,8 @@ def _process_instance_matching(
fn_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8)
# Mark all ground truth objects as false negatives.
fn_mask[ground_truth_mask > 0] = 1
result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask})
return result
result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) # type: ignore
return result # type: ignore
# Compute the IoU matrix for the processed masks.
iou_matrix = _calculate_iou(processed_ground_truth, processed_prediction)
@ -640,11 +640,11 @@ def _process_instance_matching(
for pred_label in (all_prediction_labels - matched_prediction_labels):
fp_mask[processed_prediction == pred_label] = 1
result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask})
return result
result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) # type: ignore
return result # type: ignore
def _compute_optimal_matching_pairs(iou_matrix: np.ndarray, iou_threshold: float) -> List[Any]:
def _compute_optimal_matching_pairs(iou_matrix: np.ndarray, iou_threshold: float) -> list[Any]:
"""
Computes the optimal matching pairs between ground truth and predicted masks using the IoU matrix.
@ -687,7 +687,7 @@ def _compute_patch_based_metrics(
iou_threshold: float = 0.5,
return_masks: bool = False,
without_boundary_objects: bool = True
) -> Dict[str, Union[int, NDArray[np.uint8]]]:
) -> dict[str, int | NDArray[np.uint8]]:
"""
Computes segmentation metrics using a patch-based approach for very large images.
@ -747,7 +747,7 @@ def _compute_patch_based_metrics(
padded_fp_mask[y_start:y_end, x_start:x_end] = patch_results.get('fp_mask', 0) # type: ignore
padded_fn_mask[y_start:y_end, x_start:x_end] = patch_results.get('fn_mask', 0) # type: ignore
results: Dict[str, Union[int, np.ndarray]] = {'tp': total_tp, 'fp': total_fp, 'fn': total_fn}
results: dict[str, int | np.ndarray] = {'tp': total_tp, 'fp': total_fp, 'fn': total_fn}
if return_masks:
# Crop the padded masks back to the original image size.
results.update({

@ -1,5 +1,4 @@
import os
from typing import Tuple
from config import Config, WandbConfig, DatasetConfig, ComponentConfig
@ -8,7 +7,7 @@ from core import (
)
def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str:
def prompt_choice(prompt_message: str, options: tuple[str, ...]) -> str:
"""
Prompt the user with a list of options and return the selected option.
"""

@ -1,20 +1,25 @@
import os
import sys
import argparse
import wandb
from config import Config
from core.data import *
from core.data import (
get_train_transforms,
get_valid_transforms,
get_test_transforms,
get_predict_transforms
)
from core.segmentator import CellSegmentator
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Train or predict cell segmentator with specified config file."
)
parser.add_argument(
'-c', '--config',
type=str,
default='config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json',
help='Path to the JSON config file'
)
parser.add_argument(
@ -36,6 +41,10 @@ def main():
' masks without additional visualizations')
)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(0)
args = parser.parse_args()
mode = args.mode
@ -44,7 +53,7 @@ def main():
if mode == 'train' and not config.dataset_config.is_training:
raise ValueError(
f"Config is not set for training (is_training=False), but mode 'train' was requested."
"Config is not set for training (is_training=False), but mode 'train' was requested."
)
if mode in ('test', 'predict') and config.dataset_config.is_training:
raise ValueError(
