From b0e7b21a21163387a6a3abcfbc868a45bfe48173 Mon Sep 17 00:00:00 2001 From: laynholt Date: Sat, 17 May 2025 18:08:04 +0000 Subject: [PATCH] refactor: migrate all type annotations from Python 3.9 to 3.10 syntax Replaced typing module generics like List, Dict, Tuple with built-in alternatives (list, dict, tuple). Updated code to use new union syntax (X | Y) instead of Union[X, Y]. --- config/config.py | 57 ++++---- config/dataset_config.py | 18 +-- config/wandb_config.py | 20 +-- core/data/transforms/__init__.py | 4 +- core/data/transforms/cell_aware.py | 22 ++-- core/data/transforms/load_image.py | 18 +-- core/data/transforms/normalize_image.py | 12 +- core/data/transforms/random_crop.py | 16 +-- core/losses/__init__.py | 16 +-- core/losses/base.py | 13 +- core/losses/bce.py | 30 +++-- core/losses/ce.py | 24 ++-- core/losses/mse.py | 20 +-- core/losses/mse_with_bce.py | 21 +-- core/models/__init__.py | 18 +-- core/models/model_v.py | 17 ++- core/optimizers/__init__.py | 16 +-- core/optimizers/adam.py | 8 +- core/optimizers/adamw.py | 8 +- core/optimizers/base.py | 12 +- core/optimizers/sgd.py | 6 +- core/schedulers/__init__.py | 17 ++- core/schedulers/base.py | 9 +- core/schedulers/cosine_annealing.py | 13 +- core/schedulers/exponential.py | 12 +- core/schedulers/multi_step.py | 14 +- core/schedulers/step.py | 12 +- core/segmentator.py | 168 ++++++++++++------------ core/utils/measures.py | 46 +++---- generate_config.py | 3 +- main.py | 17 ++- 31 files changed, 359 insertions(+), 328 deletions(-) diff --git a/config/config.py b/config/config.py index df017c6..a33f54a 100644 --- a/config/config.py +++ b/config/config.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel from .wandb_config import WandbConfig @@ -13,7 +13,7 @@ class ComponentConfig(BaseModel): name: str params: BaseModel - def dump(self) -> Dict[str, Any]: + def dump(self) -> dict[str, Any]: """ Recursively serializes the component into a dictionary. @@ -24,22 +24,18 @@ class ComponentConfig(BaseModel): params_dump = self.params.model_dump() else: params_dump = self.params - return { - "name": self.name, - "params": params_dump - } - + return {"name": self.name, "params": params_dump} class Config(BaseModel): model: ComponentConfig dataset_config: DatasetConfig wandb_config: WandbConfig - criterion: Optional[ComponentConfig] = None - optimizer: Optional[ComponentConfig] = None - scheduler: Optional[ComponentConfig] = None + criterion: ComponentConfig | None = None + optimizer: ComponentConfig | None = None + scheduler: ComponentConfig | None = None - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """ Produce a JSON‐serializable dict of this config, including nested ComponentConfig and DatasetConfig entries. Useful for saving to file @@ -49,7 +45,7 @@ class Config(BaseModel): A dict with keys 'model', 'dataset_config', and (if set) 'criterion', 'optimizer', 'scheduler'. """ - data: Dict[str, Any] = { + data: dict[str, Any] = { "model": self.model.dump(), "dataset_config": self.dataset_config.model_dump(), } @@ -62,7 +58,6 @@ class Config(BaseModel): data["wandb"] = self.wandb_config.model_dump() return data - def save_json(self, file_path: str, indent: int = 4) -> None: """ Save this config to a JSON file. 
@@ -75,7 +70,6 @@ class Config(BaseModel): with open(file_path, "w", encoding="utf-8") as f: f.write(json.dumps(config_dict, indent=indent)) - @classmethod def load_json(cls, file_path: str) -> "Config": """ @@ -96,10 +90,12 @@ class Config(BaseModel): wandb_config = WandbConfig(**data.get("wandb", {})) # Helper function to parse registry fields. - def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]: + def parse_field( + component_data: dict[str, Any], registry_getter + ) -> ComponentConfig | None: name = component_data.get("name") params_data = component_data.get("params", {}) - + if name is not None: expected = registry_getter(name) params = expected(**params_data) @@ -107,16 +103,31 @@ class Config(BaseModel): return None from core import ( - ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry + ModelRegistry, + CriterionRegistry, + OptimizerRegistry, + SchedulerRegistry, ) - parsed_model = parse_field(data.get("model", {}), lambda key: ModelRegistry.get_model_params(key)) - parsed_criterion = parse_field(data.get("criterion", {}), lambda key: CriterionRegistry.get_criterion_params(key)) - parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key)) - parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key)) + parsed_model = parse_field( + data.get("model", {}), + lambda key: ModelRegistry.get_model_params(key), + ) + parsed_criterion = parse_field( + data.get("criterion", {}), + lambda key: CriterionRegistry.get_criterion_params(key), + ) + parsed_optimizer = parse_field( + data.get("optimizer", {}), + lambda key: OptimizerRegistry.get_optimizer_params(key), + ) + parsed_scheduler = parse_field( + data.get("scheduler", {}), + lambda key: SchedulerRegistry.get_scheduler_params(key), + ) if parsed_model is None: - raise ValueError('Failed to load model information') + raise ValueError("Failed to load model information") return cls( model=parsed_model, @@ -124,5 +135,5 @@ class Config(BaseModel): criterion=parsed_criterion, optimizer=parsed_optimizer, scheduler=parsed_scheduler, - wandb_config=wandb_config + wandb_config=wandb_config, ) diff --git a/config/dataset_config.py b/config/dataset_config.py index 0b1a646..346c0aa 100644 --- a/config/dataset_config.py +++ b/config/dataset_config.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, model_validator, field_validator -from typing import Any, Dict, Optional, Union +from typing import Any import os @@ -7,7 +7,7 @@ class DatasetCommonConfig(BaseModel): """ Common configuration fields shared by both training and testing. 
""" - seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations) + seed: int | None = 0 # Seed for splitting if data is not pre-split (and all random operations) device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda') use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP) roi_size: int = 512 # The size of the square window for cropping @@ -65,9 +65,9 @@ class DatasetTrainingConfig(BaseModel): pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo() split: TrainingSplitInfo = TrainingSplitInfo() - train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic) - valid_size: Union[int, float] = 0.1 # Validation data size (int for static, float in (0,1] for dynamic) - test_size: Union[int, float] = 0.2 # Testing data size (int for static, float in (0,1] for dynamic) + train_size: int | float = 0.7 # Training data size (int for static, float in (0,1] for dynamic) + valid_size: int | float = 0.1 # Validation data size (int for static, float in (0,1] for dynamic) + test_size: int | float = 0.2 # Testing data size (int for static, float in (0,1] for dynamic) train_offset: int = 0 # Offset for training data valid_offset: int = 0 # Offset for validation data test_offset: int = 0 # Offset for testing data @@ -78,7 +78,7 @@ class DatasetTrainingConfig(BaseModel): @field_validator("train_size", "valid_size", "test_size", mode="before") - def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]: + def validate_sizes(cls, v: int | float) -> int | float: """ Validates size values: - If provided as a float, must be in the range (0, 1]. @@ -145,12 +145,12 @@ class DatasetTestingConfig(BaseModel): Configuration fields used only in testing mode. """ test_dir: str = "." # Test data directory; must be non-empty - test_size: Union[int, float] = 1.0 # Testing data size (int for static, float in (0,1] for dynamic) + test_size: int | float = 1.0 # Testing data size (int for static, float in (0,1] for dynamic) test_offset: int = 0 # Offset for testing data shuffle: bool = True # Shuffle data @field_validator("test_size", mode="before") - def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]: + def validate_test_size(cls, v: int | float) -> int | float: """ Validates the test_size value. """ @@ -224,7 +224,7 @@ class DatasetConfig(BaseModel): raise ValueError(f"Path for pretrained_weights does not exist: {self.common.pretrained_weights}") return self - def model_dump(self, **kwargs) -> Dict[str, Any]: + def model_dump(self, **kwargs) -> dict[str, Any]: """ Dumps only the relevant configuration depending on the is_training flag. Only the nested configuration (training or testing) along with common fields is returned. diff --git a/config/wandb_config.py b/config/wandb_config.py index 82d10c5..d217cc0 100644 --- a/config/wandb_config.py +++ b/config/wandb_config.py @@ -1,19 +1,19 @@ from pydantic import BaseModel, model_validator -from typing import Any, Dict, Optional +from typing import Any class WandbConfig(BaseModel): """ Configuration for Weights & Biases logging. 
""" - use_wandb: bool = False # Whether to enable WandB logging - project: Optional[str] = None # WandB project name - group: Optional[str] = None # WandB group name - entity: Optional[str] = None # WandB entity (user or team) - name: Optional[str] = None # Name of the run - tags: Optional[list[str]] = None # List of tags for the run - notes: Optional[str] = None # Notes or description for the run - save_code: bool = True # Whether to save the code to WandB + use_wandb: bool = False # Whether to enable WandB logging + project: str | None = None # WandB project name + group: str | None = None # WandB group name + entity: str | None = None # WandB entity (user or team) + name: str | None = None # Name of the run + tags: list[str] | None = None # List of tags for the run + notes: str | None = None # Notes or description for the run + save_code: bool = True # Whether to save the code to WandB @model_validator(mode="after") def validate_wandb(self) -> "WandbConfig": @@ -22,7 +22,7 @@ class WandbConfig(BaseModel): raise ValueError("When use_wandb=True, 'project' must be provided") return self - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """ Return a dict of all W&B parameters, excluding 'use_wandb' and any None values. """ diff --git a/core/data/transforms/__init__.py b/core/data/transforms/__init__.py index 03e538a..ad321f6 100644 --- a/core/data/transforms/__init__.py +++ b/core/data/transforms/__init__.py @@ -1,6 +1,6 @@ from .cell_aware import IntensityDiversification -from .load_image import CustomLoadImage, CustomLoadImaged -from .normalize_image import CustomNormalizeImage, CustomNormalizeImaged +from .load_image import CustomLoadImaged +from .normalize_image import CustomNormalizeImaged from monai.transforms import * # type: ignore diff --git a/core/data/transforms/cell_aware.py b/core/data/transforms/cell_aware.py index d0d14d6..ed5e867 100644 --- a/core/data/transforms/cell_aware.py +++ b/core/data/transforms/cell_aware.py @@ -1,7 +1,7 @@ import copy import torch import numpy as np -from typing import Dict, Sequence, Tuple, Union +from typing import Sequence from skimage.segmentation import find_boundaries from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore @@ -26,14 +26,14 @@ class BoundaryExclusion(MapTransform): def __init__(self, keys: Sequence[str] = ("mask",), allow_missing_keys: bool = False) -> None: """ Args: - keys (Sequence[str]): Keys in the input dictionary corresponding to the label image. + keys (Sequence(str)): Keys in the input dictionary corresponding to the label image. Default is ("mask",). allow_missing_keys (bool): If True, missing keys in the input will be ignored. Default is False. """ super().__init__(keys=keys, allow_missing_keys=allow_missing_keys) - def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + def __call__(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]: """ Apply the boundary exclusion transform to the label image. @@ -46,10 +46,10 @@ class BoundaryExclusion(MapTransform): 6. Assigning the transformed label back into the input dictionary. Args: - data (Dict[str, np.ndarray]): Dictionary containing at least the "mask" key with a label image. + data (Dict(str, np.ndarray)): Dictionary containing at least the "mask" key with a label image. Returns: - Dict[str, np.ndarray]: The input dictionary with the "mask" key updated after boundary exclusion. + Dict(str, np.ndarray): The input dictionary with the "mask" key updated after boundary exclusion. 
""" # Retrieve the original label image. label_original: np.ndarray = data["mask"] @@ -100,17 +100,17 @@ class IntensityDiversification(MapTransform): self, keys: Sequence[str] = ("image",), change_cell_ratio: float = 0.4, - scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7), + scale_factors: tuple[float, float] | float = (0.0, 0.7), allow_missing_keys: bool = False, ) -> None: """ Args: - keys (Sequence[str]): Keys in the input dictionary corresponding to the image. + keys (Sequence(str)): Keys in the input dictionary corresponding to the image. Default is ("image",). change_cell_ratio (float): Ratio of cells to apply the intensity scaling. For example, 0.4 means 40% of the cells will be transformed. Default is 0.4. - scale_factors (Sequence[float]): Factors used for random intensity scaling. + scale_factors (tuple(float, float) | float): Factors used for random intensity scaling. Default is (0.0, 0.7). allow_missing_keys (bool): If True, missing keys in the input will be ignored. Default is False. @@ -120,7 +120,7 @@ class IntensityDiversification(MapTransform): # Compose a random intensity scaling transform with 100% probability. self.randscale_intensity = Compose([RandScaleIntensity(prob=1.0, factors=scale_factors)]) - def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + def __call__(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]: """ Apply a cell-wise intensity diversification transform to an input image. @@ -141,12 +141,12 @@ class IntensityDiversification(MapTransform): 9. Combine the unchanged and modified parts to update the image for that channel. Args: - data (Dict[str, np.ndarray]): A dictionary containing: + data (dict(str, np.ndarray)): A dictionary containing: - "image": The original image array. - "mask": The corresponding cell label image array. Returns: - Dict[str, np.ndarray]: The updated dictionary with the "image" key modified after applying + dict(str, np.ndarray): The updated dictionary with the "image" key modified after applying the intensity transformation. Raises: diff --git a/core/data/transforms/load_image.py b/core/data/transforms/load_image.py index 2ff9f55..b3964ec 100644 --- a/core/data/transforms/load_image.py +++ b/core/data/transforms/load_image.py @@ -1,7 +1,7 @@ import numpy as np import tifffile as tif import skimage.io as io -from typing import Final, List, Optional, Sequence, Type, Union +from typing import Final, Sequence, Type from monai.utils.enums import PostFix from monai.utils.module import optional_import @@ -45,7 +45,7 @@ class CustomLoadImage(LoadImage): """ def __init__( self, - reader: Optional[Union[ImageReader, Type[ImageReader], str]] = None, + reader: ImageReader | Type[ImageReader] | str | None = None, image_only: bool = False, dtype: DtypeLike = np.float32, ensure_channel_first: bool = False, @@ -75,9 +75,9 @@ class CustomLoadImaged(LoadImaged): def __init__( self, keys: KeysCollection, - reader: Optional[Union[Type[ImageReader], str]] = None, + reader: Type[ImageReader] | str | None = None, dtype: DtypeLike = np.float32, - meta_keys: Optional[KeysCollection] = None, + meta_keys: KeysCollection | None = None, meta_key_postfix: str = DEFAULT_POST_FIX, overwriting: bool = False, image_only: bool = False, @@ -141,13 +141,13 @@ class UniversalImageReader(NumpyReader): (e.g., repeating or cropping channels). 
""" def __init__( - self, channel_dim: Optional[int] = None, **kwargs, - ): + self, channel_dim: int | None = None, **kwargs, + ) -> None: super().__init__(channel_dim=channel_dim, **kwargs) self.kwargs = kwargs self.channel_dim = channel_dim - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Check if the file format is supported for reading. @@ -155,7 +155,7 @@ class UniversalImageReader(NumpyReader): """ return has_itk or is_supported_format(filename, SUPPORTED_IMAGE_FORMATS) - def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): """ Read image(s) from the given path. @@ -166,7 +166,7 @@ class UniversalImageReader(NumpyReader): Returns: A single image or a list of images depending on the number of paths provided. """ - images: List[np.ndarray] = [] # List to store the loaded images + images: list[np.ndarray] = [] # List to store the loaded images # Convert data to a tuple to support multiple files filenames: Sequence[PathLike] = ensure_tuple(data) diff --git a/core/data/transforms/normalize_image.py b/core/data/transforms/normalize_image.py index 42d0d1c..63c0e5e 100644 --- a/core/data/transforms/normalize_image.py +++ b/core/data/transforms/normalize_image.py @@ -2,7 +2,7 @@ import numpy as np from skimage import exposure from monai.config.type_definitions import KeysCollection from monai.transforms.transform import Transform, MapTransform -from typing import Dict, Hashable, Mapping, Sequence +from typing import Hashable, Mapping, Sequence __all__ = [ "CustomNormalizeImage", @@ -23,7 +23,7 @@ class CustomNormalizeImage(Transform): def __init__(self, percentiles: Sequence[float] = (0, 99), channel_wise: bool = False) -> None: """ Args: - percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling. + percentiles (Sequence(float)): Lower and upper percentiles used for intensity scaling. Default is (0, 99). channel_wise (bool): Whether to apply normalization on each channel individually. Default is False. @@ -106,7 +106,7 @@ class CustomNormalizeImaged(MapTransform): """ Args: keys (KeysCollection): Keys identifying the image entries in the dictionary. - percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling. + percentiles (Sequence(float)): Lower and upper percentiles used for intensity scaling. Default is (1, 99). channel_wise (bool): Whether to apply normalization on each channel individually. Default is False. @@ -117,7 +117,7 @@ class CustomNormalizeImaged(MapTransform): # Create an instance of the normalization transform with specified parameters. self.normalizer: CustomNormalizeImage = CustomNormalizeImage(percentiles, channel_wise) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]: """ Apply the normalization transform to each image in the input dictionary. @@ -125,10 +125,10 @@ class CustomNormalizeImaged(MapTransform): data (Mapping[Hashable, np.ndarray]): A dictionary mapping keys to numpy arrays representing images. Returns: - Dict[Hashable, np.ndarray]: A dictionary with the same keys where the images have been normalized. + dict(Hashable, np.ndarray): A dictionary with the same keys where the images have been normalized. """ # Copy the input dictionary to avoid modifying the original data. 
- d: Dict[Hashable, np.ndarray] = dict(data) + d: dict[Hashable, np.ndarray] = dict(data) # Iterate over each key specified in the transform and normalize the corresponding image. for key in self.keys: d[key] = self.normalizer(d[key]) diff --git a/core/data/transforms/random_crop.py b/core/data/transforms/random_crop.py index 634b864..01cfb99 100644 --- a/core/data/transforms/random_crop.py +++ b/core/data/transforms/random_crop.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import Hashable, List, Sequence, Optional, Tuple +from typing import Sequence from monai.utils.misc import fall_back_tuple from monai.data.meta_tensor import MetaTensor @@ -14,7 +14,7 @@ logger = get_logger(__name__) def _compute_multilabel_bbox( mask: np.ndarray -) -> Optional[Tuple[List[int], List[int], List[int], List[int]]]: +) -> tuple[list[int], list[int], list[int], list[int]] | None: """ Compute per-channel bounding-box constraints and return lists of limits for each axis. @@ -33,10 +33,10 @@ def _compute_multilabel_bbox( if channels.size == 0: return None - top_mins: List[int] = [] - top_maxs: List[int] = [] - left_mins: List[int] = [] - left_maxs: List[int] = [] + top_mins: list[int] = [] + top_maxs: list[int] = [] + left_mins: list[int] = [] + left_maxs: list[int] = [] C = mask.shape[0] for ch in range(C): rs, cs = np.nonzero(mask[ch]) @@ -74,7 +74,7 @@ class SpatialCropAllClasses(Randomizable, Crop): super().__init__(lazy=lazy) self.roi_size = tuple(roi_size) self.num_candidates = num_candidates - self._slices: Optional[Tuple[slice, ...]] = None + self._slices: tuple[slice, ...] | None = None def randomize(self, img_size: Sequence[int]) -> None: # type: ignore """ @@ -139,7 +139,7 @@ class SpatialCropAllClasses(Randomizable, Crop): slice(left, left + crop_w), ) - def __call__(self, img: torch.Tensor, lazy: Optional[bool] = None) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore """ On first call (mask), computes crop. On subsequent (image), just applies. Raises if mask not provided first. diff --git a/core/losses/__init__.py b/core/losses/__init__.py index 203d9ba..0e18f5e 100644 --- a/core/losses/__init__.py +++ b/core/losses/__init__.py @@ -1,4 +1,4 @@ -from typing import Dict, Final, Tuple, Type, List, Any, Union +from typing import Final, Type, Any from pydantic import BaseModel from .base import BaseLoss @@ -16,7 +16,7 @@ __all__ = [ class CriterionRegistry: """Registry of loss functions and their parameter classes with case-insensitive lookup.""" - __CRITERIONS: Final[Dict[str, Dict[str, Any]]] = { + __CRITERIONS: Final[dict[str, dict[str, Any]]] = { "CrossEntropyLoss": { "class": CrossEntropyLoss, "params": CrossEntropyLossParams, @@ -36,7 +36,7 @@ class CriterionRegistry: } @classmethod - def __get_entry(cls, name: str) -> Dict[str, Any]: + def __get_entry(cls, name: str) -> dict[str, Any]: """ Private method to retrieve the criterion entry from the registry using case-insensitive lookup. @@ -44,7 +44,7 @@ class CriterionRegistry: name (str): The name of the loss function. Returns: - Dict[str, Any]: A dictionary containing the keys 'class' and 'params'. + dict(str, Any): A dictionary containing the keys 'class' and 'params'. Raises: ValueError: If the loss function is not found. @@ -67,7 +67,7 @@ class CriterionRegistry: name (str): Name of the loss function. Returns: - Type[BaseLoss]: The loss function class. + Type(BaseLoss): The loss function class. 
""" entry = cls.__get_entry(name) return entry["class"] @@ -81,17 +81,17 @@ class CriterionRegistry: name (str): Name of the loss function. Returns: - Type[BaseModel]: The loss function parameter class. + Type(BaseModel): The loss function parameter class. """ entry = cls.__get_entry(name) return entry["params"] @classmethod - def get_available_criterions(cls) -> Tuple[str, ...]: + def get_available_criterions(cls) -> tuple[str, ...]: """ Returns a tuple of available loss function names in their original case. Returns: - Tuple[str]: Tuple of available loss function names. + tuple(str): Tuple of available loss function names. """ return tuple(cls.__CRITERIONS.keys()) diff --git a/core/losses/base.py b/core/losses/base.py index 0b014e2..2a6ef83 100644 --- a/core/losses/base.py +++ b/core/losses/base.py @@ -1,15 +1,12 @@ import abc import torch -import torch.nn as nn from pydantic import BaseModel -from typing import Dict, Any, Optional -from monai.metrics.cumulative_average import CumulativeAverage -class BaseLoss(nn.Module, abc.ABC): +class BaseLoss(torch.nn.Module, abc.ABC): """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction.""" - def __init__(self, params: Optional[BaseModel] = None): + def __init__(self, params: BaseModel | None = None) -> None: super().__init__() @@ -28,16 +25,16 @@ class BaseLoss(nn.Module, abc.ABC): @abc.abstractmethod - def get_loss_metrics(self) -> Dict[str, float]: + def get_loss_metrics(self) -> dict[str, float]: """ Retrieves the tracked loss metrics. Returns: - Dict[str, float]: A dictionary containing the loss name and average loss value. + dict(str, float): A dictionary containing the loss name and average loss value. """ @abc.abstractmethod - def reset_metrics(self): + def reset_metrics(self) -> None: """Resets the stored loss metrics.""" \ No newline at end of file diff --git a/core/losses/bce.py b/core/losses/bce.py index f7c9484..d55d163 100644 --- a/core/losses/bce.py +++ b/core/losses/bce.py @@ -1,6 +1,8 @@ -from .base import * -from typing import List, Literal, Union +from .base import BaseLoss +import torch +from typing import Any, Literal from pydantic import BaseModel, ConfigDict +from monai.metrics.cumulative_average import CumulativeAverage class BCELossParams(BaseModel): @@ -11,11 +13,11 @@ class BCELossParams(BaseModel): with_logits: bool = False - weight: Optional[List[Union[int, float]]] = None # Sample weights + weight: list[int | float] | None = None # Sample weights reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method - pos_weight: Optional[List[Union[int, float]]] = None # Used only for BCEWithLogitsLoss + pos_weight: list[int | float] | None = None # Used only for BCEWithLogitsLoss - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """ Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`. @@ -23,7 +25,7 @@ class BCELossParams(BaseModel): - Ensures only the valid parameters are passed based on the loss function. Returns: - Dict[str, Any]: Filtered dictionary of parameters. + dict(str, Any): Filtered dictionary of parameters. """ loss_kwargs = self.model_dump() if not self.with_logits: @@ -47,19 +49,23 @@ class BCELoss(BaseLoss): Custom loss function wrapper for `nn.BCELoss and nn.BCEWithLogitsLoss` with tracking of loss metrics. 
""" - def __init__(self, params: Optional[BCELossParams] = None): + def __init__(self, params: BCELossParams | None = None) -> None: """ Initializes the loss function with optional BCELoss parameters. Args: - params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None). + params (BCELossParams | None): Parameters for nn.BCELoss (default: None). """ super().__init__(params=params) with_logits = params.with_logits if params is not None else False _bce_params = params.asdict() if params is not None else {} # Initialize loss functions with user-provided parameters or PyTorch defaults - self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params) + self.bce_loss = ( + torch.nn.BCEWithLogitsLoss(**_bce_params) + if with_logits + else torch.nn.BCELoss(**_bce_params) + ) # Using CumulativeAverage from MONAI to track loss metrics self.loss_bce_metric = CumulativeAverage() @@ -90,18 +96,18 @@ class BCELoss(BaseLoss): return loss - def get_loss_metrics(self) -> Dict[str, float]: + def get_loss_metrics(self) -> dict[str, float]: """ Retrieves the tracked loss metrics. Returns: - Dict[str, float]: A dictionary containing the average BCE loss. + dict(str, float): A dictionary containing the average BCE loss. """ return { "loss": round(self.loss_bce_metric.aggregate().item(), 4), } - def reset_metrics(self): + def reset_metrics(self) -> None: """Resets the stored loss metrics.""" self.loss_bce_metric.reset() diff --git a/core/losses/ce.py b/core/losses/ce.py index f1e2db7..28616ae 100644 --- a/core/losses/ce.py +++ b/core/losses/ce.py @@ -1,6 +1,8 @@ -from .base import * -from typing import List, Literal, Union +from .base import BaseLoss +import torch +from typing import Any, Literal from pydantic import BaseModel, ConfigDict +from monai.metrics.cumulative_average import CumulativeAverage class CrossEntropyLossParams(BaseModel): @@ -9,17 +11,17 @@ class CrossEntropyLossParams(BaseModel): """ model_config = ConfigDict(frozen=True) - weight: Optional[List[Union[int, float]]] = None + weight: list[int | float] | None = None ignore_index: int = -100 reduction: Literal["none", "mean", "sum"] = "mean" label_smoothing: float = 0.0 - def asdict(self): + def asdict(self) -> dict[str, Any]: """ Returns a dictionary of valid parameters for `nn.CrossEntropyLoss`. Returns: - Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss. + dict(str, Any): Dictionary of parameters for nn.CrossEntropyLoss. """ loss_kwargs = self.model_dump() @@ -36,18 +38,18 @@ class CrossEntropyLoss(BaseLoss): Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics. """ - def __init__(self, params: Optional[CrossEntropyLossParams] = None): + def __init__(self, params: CrossEntropyLossParams | None = None) -> None: """ Initializes the loss function with optional CrossEntropyLoss parameters. Args: - params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None). + params (CrossEntropyLossParams | None): Parameters for nn.CrossEntropyLoss (default: None). 
""" super().__init__(params=params) _ce_params = params.asdict() if params is not None else {} # Initialize loss functions with user-provided parameters or PyTorch defaults - self.ce_loss = nn.CrossEntropyLoss(**_ce_params) + self.ce_loss = torch.nn.CrossEntropyLoss(**_ce_params) # Using CumulativeAverage from MONAI to track loss metrics self.loss_ce_metric = CumulativeAverage() @@ -78,18 +80,18 @@ class CrossEntropyLoss(BaseLoss): return loss - def get_loss_metrics(self) -> Dict[str, float]: + def get_loss_metrics(self) -> dict[str, float]: """ Retrieves the tracked loss metrics. Returns: - Dict[str, float]: A dictionary containing the average CrossEntropy loss. + dict(str, float): A dictionary containing the average CrossEntropy loss. """ return { "loss": round(self.loss_ce_metric.aggregate().item(), 4), } - def reset_metrics(self): + def reset_metrics(self) -> None: """Resets the stored loss metrics.""" self.loss_ce_metric.reset() diff --git a/core/losses/mse.py b/core/losses/mse.py index 2b5ec98..f357e8c 100644 --- a/core/losses/mse.py +++ b/core/losses/mse.py @@ -1,6 +1,8 @@ -from .base import * -from typing import Literal +from .base import BaseLoss +import torch +from typing import Any, Literal from pydantic import BaseModel, ConfigDict +from monai.metrics.cumulative_average import CumulativeAverage class MSELossParams(BaseModel): @@ -11,12 +13,12 @@ class MSELossParams(BaseModel): reduction: Literal["none", "mean", "sum"] = "mean" - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """ Returns a dictionary of valid parameters for `nn.MSELoss`. Returns: - Dict[str, Any]: Dictionary of parameters for `nn.MSELoss`. + dict(str, Any): Dictionary of parameters for `nn.MSELoss`. """ loss_kwargs = self.model_dump() return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values @@ -27,18 +29,18 @@ class MSELoss(BaseLoss): Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics. """ - def __init__(self, params: Optional[MSELossParams] = None): + def __init__(self, params: MSELossParams | None = None): """ Initializes the loss function with optional MSELoss parameters. Args: - params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None). + params (MSELossParams | None): Parameters for `nn.MSELoss` (default: None). """ super().__init__(params=params) _mse_params = params.asdict() if params is not None else {} # Initialize MSE loss with user-provided parameters or PyTorch defaults - self.mse_loss = nn.MSELoss(**_mse_params) + self.mse_loss = torch.nn.MSELoss(**_mse_params) # Using CumulativeAverage from MONAI to track loss metrics self.loss_mse_metric = CumulativeAverage() @@ -67,12 +69,12 @@ class MSELoss(BaseLoss): return loss - def get_loss_metrics(self) -> Dict[str, float]: + def get_loss_metrics(self) -> dict[str, float]: """ Retrieves the tracked loss metrics. Returns: - Dict[str, float]: A dictionary containing the average MSE loss. + dict(str, float): A dictionary containing the average MSE loss. 
""" return { "loss": round(self.loss_mse_metric.aggregate().item(), 4), diff --git a/core/losses/mse_with_bce.py b/core/losses/mse_with_bce.py index acd8ee9..f5f89ba 100644 --- a/core/losses/mse_with_bce.py +++ b/core/losses/mse_with_bce.py @@ -1,8 +1,11 @@ -from .base import * +from .base import BaseLoss from .bce import BCELossParams from .mse import MSELossParams +import torch +from typing import Any from pydantic import BaseModel, ConfigDict +from monai.metrics.cumulative_average import CumulativeAverage class BCE_MSE_LossParams(BaseModel): @@ -15,12 +18,12 @@ class BCE_MSE_LossParams(BaseModel): bce_params: BCELossParams = BCELossParams() mse_params: MSELossParams = MSELossParams() - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """ Returns a dictionary of valid parameters for `nn.BCELoss` and `nn.MSELoss`. Returns: - Dict[str, Any]: Dictionary of parameters. + dict(str, Any): Dictionary of parameters. """ return { @@ -35,7 +38,7 @@ class BCE_MSE_Loss(BaseLoss): Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction. """ - def __init__(self, params: Optional[BCE_MSE_LossParams]): + def __init__(self, params: BCE_MSE_LossParams | None = None): """ Initializes the loss function with optional BCE and MSE parameters. """ @@ -50,14 +53,16 @@ class BCE_MSE_Loss(BaseLoss): # Choose BCE loss function self.bce_loss = ( - nn.BCEWithLogitsLoss(**_bce_params) if _params.bce_params.with_logits else nn.BCELoss(**_bce_params) + torch.nn.BCEWithLogitsLoss(**_bce_params) + if _params.bce_params.with_logits + else torch.nn.BCELoss(**_bce_params) ) # Process MSE parameters _mse_params = _params.mse_params.asdict() # Initialize MSE loss - self.mse_loss = nn.MSELoss(**_mse_params) + self.mse_loss = torch.nn.MSELoss(**_mse_params) # Using CumulativeAverage from MONAI to track loss metrics self.loss_bce_metric = CumulativeAverage() @@ -101,12 +106,12 @@ class BCE_MSE_Loss(BaseLoss): return total_loss - def get_loss_metrics(self) -> Dict[str, float]: + def get_loss_metrics(self) -> dict[str, float]: """ Retrieves the tracked loss metrics. Returns: - Dict[str, float]: A dictionary containing the average BCE and MSE loss. + dict(str, float): A dictionary containing the average BCE and MSE loss. """ return { "bce_loss": round(self.loss_bce_metric.aggregate().item(), 4), diff --git a/core/models/__init__.py b/core/models/__init__.py index c39baa5..f29508f 100644 --- a/core/models/__init__.py +++ b/core/models/__init__.py @@ -1,5 +1,5 @@ -import torch.nn as nn -from typing import Dict, Final, Tuple, Type, Any, List, Union +from torch import nn +from typing import Final, Type, Any from pydantic import BaseModel from .model_v import ModelV, ModelVParams @@ -16,7 +16,7 @@ class ModelRegistry: """Registry for models and their parameter classes with case-insensitive lookup.""" # Single dictionary storing both model classes and parameter classes. - __MODELS: Final[Dict[str, Dict[str, Type[Any]]]] = { + __MODELS: Final[dict[str, dict[str, Type[Any]]]] = { "ModelV": { "class": ModelV, "params": ModelVParams, @@ -24,7 +24,7 @@ class ModelRegistry: } @classmethod - def __get_entry(cls, name: str) -> Dict[str, Type[Any]]: + def __get_entry(cls, name: str) -> dict[str, Type[Any]]: """ Private method to retrieve the model entry from the registry using case-insensitive lookup. @@ -32,7 +32,7 @@ class ModelRegistry: name (str): The name of the model. Returns: - Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'. 
+ dict(str, Type[Any]): A dictionary containing the keys 'class' and 'params'. Raises: ValueError: If the model is not found. @@ -55,7 +55,7 @@ class ModelRegistry: name (str): Name of the model. Returns: - Type[nn.Module]: The model class. + Type(torch.nn.Module): The model class. """ entry = cls.__get_entry(name) return entry["class"] @@ -69,17 +69,17 @@ class ModelRegistry: name (str): Name of the model. Returns: - Type[BaseModel]: The model parameter class. + Type(BaseModel): The model parameter class. """ entry = cls.__get_entry(name) return entry["params"] @classmethod - def get_available_models(cls) -> Tuple[str, ...]: + def get_available_models(cls) -> tuple[str, ...]: """ Returns a tuple of available model names in their original case. Returns: - Tuple[str]: Tuple of available model names. + Tuple(str): Tuple of available model names. """ return tuple(cls.__MODELS.keys()) diff --git a/core/models/model_v.py b/core/models/model_v.py index 02e11e8..35418ea 100644 --- a/core/models/model_v.py +++ b/core/models/model_v.py @@ -1,8 +1,7 @@ -from typing import List, Optional - import torch -import torch.nn as nn +from torch import nn +from typing import Any from segmentation_models_pytorch import MAnet from segmentation_models_pytorch.base.modules import Activation @@ -15,18 +14,18 @@ class ModelVParams(BaseModel): model_config = ConfigDict(frozen=True) encoder_name: str = "mit_b5" # Default encoder - encoder_weights: Optional[str] = "imagenet" # Pre-trained weights - decoder_channels: List[int] = [1024, 512, 256, 128, 64] # Decoder configuration + encoder_weights: str | None = "imagenet" # Pre-trained weights + decoder_channels: list[int] = [1024, 512, 256, 128, 64] # Decoder configuration decoder_pab_channels: int = 256 # Decoder Pyramid Attention Block channels in_channels: int = 3 # Number of input channels out_classes: int = 1 # Number of output classes - def asdict(self): + def asdict(self) -> dict[str, Any]: """ Returns a dictionary of valid parameters for `nn.ModelV`. Returns: - Dict[str, Any]: Dictionary of parameters for nn.ModelV. + dict(str, Any): Dictionary of parameters for nn.ModelV. """ loss_kwargs = self.model_dump() return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values @@ -84,11 +83,11 @@ class DeepSegmentationHead(nn.Sequential): in_channels: int, out_channels: int, kernel_size: int = 3, - activation: Optional[str] = None, + activation: str | None = None, upsampling: int = 1, ) -> None: # Define a sequence of layers for the segmentation head - layers: List[nn.Module] = [ + layers: list[nn.Module] = [ nn.Conv2d( in_channels, in_channels // 2, diff --git a/core/optimizers/__init__.py b/core/optimizers/__init__.py index f3a80d0..6f147b1 100644 --- a/core/optimizers/__init__.py +++ b/core/optimizers/__init__.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from typing import Dict, Final, Tuple, Type, List, Any, Union +from typing import Final, Type, Any from .base import BaseOptimizer from .adam import AdamParams, AdamOptimizer @@ -16,7 +16,7 @@ class OptimizerRegistry: """Registry for optimizers and their parameter classes with case-insensitive lookup.""" # Single dictionary storing both optimizer classes and parameter classes. 
- __OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = { + __OPTIMIZERS: Final[dict[str, dict[str, Type[Any]]]] = { "SGD": { "class": SGDOptimizer, "params": SGDParams, @@ -32,7 +32,7 @@ class OptimizerRegistry: } @classmethod - def __get_entry(cls, name: str) -> Dict[str, Type[Any]]: + def __get_entry(cls, name: str) -> dict[str, Type[Any]]: """ Private method to retrieve the optimizer entry from the registry using case-insensitive lookup. @@ -40,7 +40,7 @@ class OptimizerRegistry: name (str): The name of the optimizer. Returns: - Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'. + dict(str, Type(Any)): A dictionary containing the keys 'class' and 'params'. Raises: ValueError: If the optimizer is not found. @@ -63,7 +63,7 @@ class OptimizerRegistry: name (str): Name of the optimizer. Returns: - Type[BaseOptimizer]: The optimizer class. + Type(BaseOptimizer): The optimizer class. """ entry = cls.__get_entry(name) return entry["class"] @@ -77,17 +77,17 @@ class OptimizerRegistry: name (str): Name of the optimizer. Returns: - Type[BaseModel]: The optimizer parameter class. + Type(BaseModel): The optimizer parameter class. """ entry = cls.__get_entry(name) return entry["params"] @classmethod - def get_available_optimizers(cls) -> Tuple[str, ...]: + def get_available_optimizers(cls) -> tuple[str, ...]: """ Returns a tuple of available optimizer names in their original case. Returns: - Tuple[str]: Tuple of available optimizer names. + Tuple(str): Tuple of available optimizer names. """ return tuple(cls.__OPTIMIZERS.keys()) diff --git a/core/optimizers/adam.py b/core/optimizers/adam.py index 7ebe157..b7422be 100644 --- a/core/optimizers/adam.py +++ b/core/optimizers/adam.py @@ -1,6 +1,6 @@ import torch from torch import optim -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Iterable from pydantic import BaseModel, ConfigDict from .base import BaseOptimizer @@ -10,12 +10,12 @@ class AdamParams(BaseModel): model_config = ConfigDict(frozen=True) lr: float = 1e-3 # Learning rate - betas: Tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages + betas: tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages eps: float = 1e-8 # Term added to denominator for numerical stability weight_decay: float = 0.0 # L2 regularization amsgrad: bool = False # Whether to use the AMSGrad variant - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.Adam`.""" return self.model_dump() @@ -25,7 +25,7 @@ class AdamOptimizer(BaseOptimizer): Wrapper around torch.optim.Adam. """ - def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams): + def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams) -> None: """ Initializes the Adam optimizer with given parameters. diff --git a/core/optimizers/adamw.py b/core/optimizers/adamw.py index f2d7ddb..ba9db0e 100644 --- a/core/optimizers/adamw.py +++ b/core/optimizers/adamw.py @@ -1,6 +1,6 @@ import torch from torch import optim -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Iterable from pydantic import BaseModel, ConfigDict from .base import BaseOptimizer @@ -10,12 +10,12 @@ class AdamWParams(BaseModel): model_config = ConfigDict(frozen=True) lr: float = 1e-3 # Learning rate - betas: Tuple[float, ...] = (0.9, 0.999) # Adam coefficients + betas: tuple[float, ...] 
= (0.9, 0.999) # Adam coefficients eps: float = 1e-8 # Numerical stability weight_decay: float = 1e-2 # L2 penalty (AdamW uses decoupled weight decay) amsgrad: bool = False # Whether to use the AMSGrad variant - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.AdamW`.""" return self.model_dump() @@ -25,7 +25,7 @@ class AdamWOptimizer(BaseOptimizer): Wrapper around torch.optim.AdamW. """ - def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams): + def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams) -> None: """ Initializes the AdamW optimizer with given parameters. diff --git a/core/optimizers/base.py b/core/optimizers/base.py index 32cef7c..a252bc5 100644 --- a/core/optimizers/base.py +++ b/core/optimizers/base.py @@ -1,15 +1,15 @@ import torch -import torch.optim as optim +from torch import optim from pydantic import BaseModel -from typing import Any, Iterable, Optional +from typing import Any, Iterable class BaseOptimizer: """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction.""" - def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel): + def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel) -> None: super().__init__() - self.optim: Optional[optim.Optimizer] = None + self.optim: optim.Optimizer | None = None def zero_grad(self, set_to_none: bool = True) -> None: @@ -25,12 +25,12 @@ class BaseOptimizer: self.optim.zero_grad(set_to_none=set_to_none) - def step(self, closure: Optional[Any] = None) -> Any: + def step(self, closure: Any | None = None) -> Any: """ Performs a single optimization step (parameter update). Args: - closure (Optional[Callable]): A closure that reevaluates the model and returns the loss. + closure (Any | None): A closure that reevaluates the model and returns the loss. This is required for optimizers like LBFGS that need multiple forward passes. Returns: diff --git a/core/optimizers/sgd.py b/core/optimizers/sgd.py index 3870b9e..b4097d6 100644 --- a/core/optimizers/sgd.py +++ b/core/optimizers/sgd.py @@ -1,6 +1,6 @@ import torch from torch import optim -from typing import Any, Dict, Iterable, Optional +from typing import Any, Iterable from pydantic import BaseModel, ConfigDict from .base import BaseOptimizer @@ -16,7 +16,7 @@ class SGDParams(BaseModel): weight_decay: float = 0.0 # L2 penalty nesterov: bool = False # Enables Nesterov momentum - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.SGD`.""" return self.model_dump() @@ -26,7 +26,7 @@ class SGDOptimizer(BaseOptimizer): Wrapper around torch.optim.SGD. """ - def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams): + def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams) -> None: """ Initializes the SGD optimizer with given parameters. 
diff --git a/core/schedulers/__init__.py b/core/schedulers/__init__.py index f2ff6ca..04eb21d 100644 --- a/core/schedulers/__init__.py +++ b/core/schedulers/__init__.py @@ -1,5 +1,4 @@ -import torch.optim.lr_scheduler as lr_scheduler -from typing import Dict, Final, Tuple, Type, List, Any, Union +from typing import Final, Type, Any from pydantic import BaseModel from .base import BaseScheduler @@ -17,7 +16,7 @@ __all__ = [ class SchedulerRegistry: """Registry for learning rate schedulers and their parameter classes with case-insensitive lookup.""" - __SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = { + __SCHEDULERS: Final[dict[str, dict[str, Type[Any]]]] = { "Step": { "class": StepLRScheduler, "params": StepLRParams, @@ -37,7 +36,7 @@ class SchedulerRegistry: } @classmethod - def __get_entry(cls, name: str) -> Dict[str, Type[Any]]: + def __get_entry(cls, name: str) -> dict[str, Type[Any]]: """ Private method to retrieve the scheduler entry from the registry using case-insensitive lookup. @@ -45,7 +44,7 @@ class SchedulerRegistry: name (str): The name of the scheduler. Returns: - Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'. + dict(str, Type(Any)): A dictionary containing the keys 'class' and 'params'. Raises: ValueError: If the scheduler is not found. @@ -68,7 +67,7 @@ class SchedulerRegistry: name (str): Name of the scheduler. Returns: - Type[BaseScheduler]: The scheduler class. + Type(BaseScheduler): The scheduler class. """ entry = cls.__get_entry(name) return entry["class"] @@ -82,17 +81,17 @@ class SchedulerRegistry: name (str): Name of the scheduler. Returns: - Type[BaseModel]: The scheduler parameter class. + Type(BaseModel): The scheduler parameter class. """ entry = cls.__get_entry(name) return entry["params"] @classmethod - def get_available_schedulers(cls) -> Tuple[str, ...]: + def get_available_schedulers(cls) -> tuple[str, ...]: """ Returns a tuple of available scheduler names in their original case. Returns: - Tuple[str]: Tuple of available scheduler names. + Tuple(str): Tuple of available scheduler names. """ return tuple(cls.__SCHEDULERS.keys()) diff --git a/core/schedulers/base.py b/core/schedulers/base.py index 892f7fd..db81b01 100644 --- a/core/schedulers/base.py +++ b/core/schedulers/base.py @@ -1,6 +1,5 @@ -import torch.optim as optim +from torch import optim from pydantic import BaseModel -from typing import List, Optional class BaseScheduler: @@ -9,8 +8,8 @@ class BaseScheduler: Wraps a PyTorch LR scheduler and provides a unified interface. """ - def __init__(self, optimizer: optim.Optimizer, params: BaseModel): - self.scheduler: Optional[optim.lr_scheduler.LRScheduler] = None + def __init__(self, optimizer: optim.Optimizer, params: BaseModel) -> None: + self.scheduler: optim.lr_scheduler.LRScheduler | None = None def step(self) -> None: """ @@ -20,7 +19,7 @@ class BaseScheduler: if self.scheduler is not None: self.scheduler.step() - def get_last_lr(self) -> List[float]: + def get_last_lr(self) -> list[float]: """ Returns the most recent learning rate(s). 
""" diff --git a/core/schedulers/cosine_annealing.py b/core/schedulers/cosine_annealing.py index f0bba2b..bcdba8b 100644 --- a/core/schedulers/cosine_annealing.py +++ b/core/schedulers/cosine_annealing.py @@ -1,10 +1,9 @@ -from typing import Any, Dict -from pydantic import BaseModel, ConfigDict -from torch import optim -from torch.optim.lr_scheduler import CosineAnnealingLR - from .base import BaseScheduler +from typing import Any +from torch import optim +from torch.optim.lr_scheduler import CosineAnnealingLR +from pydantic import BaseModel, ConfigDict class CosineAnnealingLRParams(BaseModel): @@ -16,7 +15,7 @@ class CosineAnnealingLRParams(BaseModel): last_epoch: int = -1 - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`.""" return self.model_dump() @@ -26,7 +25,7 @@ class CosineAnnealingLRScheduler(BaseScheduler): Wrapper around torch.optim.lr_scheduler.CosineAnnealingLR. """ - def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams): + def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams) -> None: """ Args: optimizer (Optimizer): Wrapped optimizer. diff --git a/core/schedulers/exponential.py b/core/schedulers/exponential.py index 807aed7..778658e 100644 --- a/core/schedulers/exponential.py +++ b/core/schedulers/exponential.py @@ -1,9 +1,9 @@ -from typing import Any, Dict -from pydantic import BaseModel, ConfigDict +from .base import BaseScheduler + +from typing import Any from torch import optim from torch.optim.lr_scheduler import ExponentialLR - -from .base import BaseScheduler +from pydantic import BaseModel, ConfigDict class ExponentialLRParams(BaseModel): @@ -13,7 +13,7 @@ class ExponentialLRParams(BaseModel): gamma: float = 0.95 # Multiplicative factor of learning rate decay last_epoch: int = -1 - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`.""" return self.model_dump() @@ -23,7 +23,7 @@ class ExponentialLRScheduler(BaseScheduler): Wrapper around torch.optim.lr_scheduler.ExponentialLR. """ - def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams): + def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams) -> None: """ Args: optimizer (Optimizer): Wrapped optimizer. diff --git a/core/schedulers/multi_step.py b/core/schedulers/multi_step.py index 99d3858..d79944d 100644 --- a/core/schedulers/multi_step.py +++ b/core/schedulers/multi_step.py @@ -1,20 +1,20 @@ -from typing import Any, Dict, Tuple -from pydantic import BaseModel, ConfigDict +from .base import BaseScheduler + +from typing import Any from torch import optim from torch.optim.lr_scheduler import MultiStepLR - -from .base import BaseScheduler +from pydantic import BaseModel, ConfigDict class MultiStepLRParams(BaseModel): """Configuration for `torch.optim.lr_scheduler.MultiStepLR`.""" model_config = ConfigDict(frozen=True) - milestones: Tuple[int, ...] = (30, 80) # List of epoch indices for LR decay + milestones: tuple[int, ...] 
= (30, 80) # List of epoch indices for LR decay gamma: float = 0.1 # Multiplicative factor of learning rate decay last_epoch: int = -1 - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`.""" return self.model_dump() @@ -24,7 +24,7 @@ class MultiStepLRScheduler(BaseScheduler): Wrapper around torch.optim.lr_scheduler.MultiStepLR. """ - def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams): + def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams) -> None: """ Args: optimizer (Optimizer): Wrapped optimizer. diff --git a/core/schedulers/step.py b/core/schedulers/step.py index 0b9b609..d8da435 100644 --- a/core/schedulers/step.py +++ b/core/schedulers/step.py @@ -1,9 +1,9 @@ -from typing import Any, Dict -from pydantic import BaseModel, ConfigDict +from .base import BaseScheduler + +from typing import Any from torch import optim from torch.optim.lr_scheduler import StepLR - -from .base import BaseScheduler +from pydantic import BaseModel, ConfigDict class StepLRParams(BaseModel): @@ -14,7 +14,7 @@ class StepLRParams(BaseModel): gamma: float = 0.1 # Multiplicative factor of learning rate decay last_epoch: int = -1 - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`.""" return self.model_dump() @@ -25,7 +25,7 @@ class StepLRScheduler(BaseScheduler): Wrapper around torch.optim.lr_scheduler.StepLR. """ - def __init__(self, optimizer: optim.Optimizer, params: StepLRParams): + def __init__(self, optimizer: optim.Optimizer, params: StepLRParams) -> None: """ Args: optimizer (Optimizer): Wrapped optimizer. diff --git a/core/segmentator.py b/core/segmentator.py index f67ed5a..5cc9151 100644 --- a/core/segmentator.py +++ b/core/segmentator.py @@ -19,15 +19,14 @@ from torch.utils.data import DataLoader import fastremap import fill_voids -from skimage import morphology +# from skimage import morphology from skimage.segmentation import find_boundaries from scipy.special import expit from scipy.ndimage import mean, find_objects from monai.data.dataset import Dataset -from monai.transforms import * # type: ignore +from monai.transforms.compose import Compose from monai.inferers.utils import sliding_window_inference -from monai.metrics.cumulative_average import CumulativeAverage import matplotlib.pyplot as plt import matplotlib.colors as mcolors @@ -42,16 +41,16 @@ from itertools import chain from pprint import pformat from tabulate import tabulate -from typing import Any, Dict, Literal, Optional, Tuple, List, Union +from typing import Any, Literal from tqdm import tqdm import wandb from config import Config -from core.models import * -from core.losses import * -from core.optimizers import * -from core.schedulers import * +from core.models import ModelRegistry +from core.losses import CriterionRegistry +from core.optimizers import OptimizerRegistry +from core.schedulers import SchedulerRegistry from core.utils import ( compute_batch_segmentation_tp_fp_fn, compute_f1_score, @@ -78,30 +77,30 @@ class CellSegmentator: else None ) - self._train_dataloader: Optional[DataLoader] = None - self._valid_dataloader: Optional[DataLoader] = None - self._test_dataloader: Optional[DataLoader] = None - self._predict_dataloader: Optional[DataLoader] = None + self._train_dataloader: DataLoader | None = None + self._valid_dataloader: DataLoader | None = None + 
self._test_dataloader: DataLoader | None = None + self._predict_dataloader: DataLoader | None = None self._best_weights = None def create_dataloaders( self, - train_transforms: Optional[Compose] = None, - valid_transforms: Optional[Compose] = None, - test_transforms: Optional[Compose] = None, - predict_transforms: Optional[Compose] = None + train_transforms: Compose | None = None, + valid_transforms: Compose | None = None, + test_transforms: Compose | None = None, + predict_transforms: Compose | None = None ) -> None: """ Creates train, validation, test, and prediction dataloaders based on dataset configuration and provided transforms. Args: - train_transforms (Optional[Compose]): Transformations for training data. - valid_transforms (Optional[Compose]): Transformations for validation data. - test_transforms (Optional[Compose]): Transformations for testing data. - predict_transforms (Optional[Compose]): Transformations for prediction data. + train_transforms (Compose | None): Transformations for training data. + valid_transforms (Compose | None): Transformations for validation data. + test_transforms (Compose | None): Transformations for testing data. + predict_transforms (Compose | None): Transformations for prediction data. Raises: ValueError: If required transforms are missing. @@ -257,7 +256,7 @@ class CellSegmentator: def print_data_info( self, loader_type: Literal["train", "valid", "test", "predict"], - index: Optional[int] = None + index: int | None = None ) -> None: """ Prints statistics for a single sample from the specified dataloader. @@ -267,7 +266,7 @@ class CellSegmentator: index: The sample index; if None, a random index is selected. """ # Retrieve the dataloader attribute, e.g., self._train_dataloader - loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None) + loader: DataLoader | None = getattr(self, f"_{loader_type}_dataloader", None) if loader is None: logger.error(f"Dataloader '{loader_type}' is not initialized.") return @@ -326,8 +325,8 @@ class CellSegmentator: lines.append("=" * 40) # Output via logger - for l in lines: - logger.info(l) + for line in lines: + logger.info(line) def train(self, save_results: bool = True, only_masks: bool = False) -> None: @@ -661,16 +660,16 @@ class CellSegmentator: logger.info(f"├─ Validation frequency: {training.val_freq}") if training.is_split: - logger.info(f"├─ Using pre-split directories:") + logger.info( "├─ Using pre-split directories:") logger.info(f"│ ├─ Train dir: {training.pre_split.train_dir}") logger.info(f"│ ├─ Valid dir: {training.pre_split.valid_dir}") logger.info(f"│ └─ Test dir: {training.pre_split.test_dir}") else: - logger.info(f"├─ Using unified dataset with splits:") - logger.info(f"│ ├─ All data dir: {training.split.all_data_dir}") + logger.info( "├─ Using unified dataset with splits:") + logger.info( "│ ├─ All data dir: {training.split.all_data_dir}") logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}") - logger.info(f"└─ Dataset split:") + logger.info( "└─ Dataset split:") logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}") logger.info(f" ├─ Valid size: {training.valid_size}, offset: {training.valid_offset}") logger.info(f" └─ Test size: {training.test_size}, offset: {training.test_offset}") @@ -703,12 +702,12 @@ class CellSegmentator: logger.info("===================================") - def __set_seed(self, seed: Optional[int]) -> None: + def __set_seed(self, seed: int | None) -> None: """ Sets the random seed 
for reproducibility across Python, NumPy, and PyTorch. Args: - seed (Optional[int]): Seed value. If None, no seeding is performed. + seed (int | None): Seed value. If None, no seeding is performed. """ if seed is not None: random.seed(seed) @@ -724,9 +723,9 @@ class CellSegmentator: def __get_dataset( self, images_dir: str, - masks_dir: Optional[str], + masks_dir: str | None, transforms: Compose, - size: Union[int, float], + size: int | float, offset: int, shuffle: bool ) -> Dataset: @@ -735,9 +734,9 @@ class CellSegmentator: Args: images_dir (str): Path to directory or glob pattern for input images. - masks_dir (Optional[str]): Path to directory or glob pattern for masks. + masks_dir (str | None): Path to directory or glob pattern for masks. transforms (Compose): Transformations to apply to each image or pair. - size (Union[int, float]): Either an integer or a fraction of the dataset. + size (int | float): Either an integer or a fraction of the dataset. offset (int): Number of images to skip from the start. shuffle (bool): Whether to shuffle the dataset before slicing. @@ -806,12 +805,12 @@ class CellSegmentator: return Dataset(data, transforms) - def __print_with_logging(self, metrics: Dict[str, Union[float, np.ndarray]], step: int) -> None: + def __print_with_logging(self, metrics: dict[str, float | np.ndarray], step: int) -> None: """ Print metrics in a tabular format and log to W&B. Args: - metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names + metrics (dict(str, float | np.ndarray)): Mapping from metric names to either a float or a ND numpy array. step (int): epoch index. """ @@ -846,14 +845,14 @@ class CellSegmentator: def __save_metrics_to_csv( self, - metrics: Dict[str, Union[float, np.ndarray]], + metrics: dict[str, float | np.ndarray], output_path: str ) -> None: """ Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'. Args: - metrics (Dict[str, Union[float, np.ndarray]]): + metrics (dict(str, float | np.ndarray)): Mapping from metric names to scalar values or numpy arrays. output_path (str): Path to the output CSV file. @@ -874,22 +873,22 @@ class CellSegmentator: def __run_epoch(self, mode: Literal["train", "valid", "test"], - epoch: Optional[int] = None, + epoch: int | None = None, save_results: bool = True, only_masks: bool = False - ) -> Dict[str, Union[float, np.ndarray]]: + ) -> dict[str, float | np.ndarray]: """ Execute one epoch of training, validation, or testing. Args: mode (str): One of 'train', 'valid', or 'test'. - epoch (int, optional): Current epoch number for logging. + epoch (int | None): Current epoch number for logging. save_results (bool): If True, the predicted masks and test metrics will be saved. only_masks (bool): If True and save_results is True, only raw predicted masks are saved, without visualization overlays. Returns: - Dict[str, Union[float, np.ndarray]]: Metrics for valid/test. + dict(str, float | np.ndarray): Metrics for valid/test. 
""" # Ensure required components are available if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None): @@ -988,7 +987,7 @@ class CellSegmentator: if self._criterion is not None: # Collect loss metrics - epoch_metrics: Dict[str, Union[float, np.ndarray]] = { + epoch_metrics: dict[str, float | np.ndarray] = { f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items() } # Reset internal loss metrics accumulator @@ -1051,17 +1050,17 @@ class CellSegmentator: def __post_process_predictions( self, raw_outputs: torch.Tensor, - ground_truth: Optional[torch.Tensor] = None - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + ground_truth: torch.Tensor | None = None + ) -> tuple[np.ndarray, np.ndarray | None]: """ Post-process raw network outputs to extract instance segmentation masks. Args: raw_outputs (torch.Tensor): Raw model outputs of shape (B, С, H, W). - ground_truth (torch.Tensor): Ground truth masks of shape (B, С, H, W). + ground_truth (torch.Tensor | None): Ground truth masks of shape (B, С, H, W). Returns: - Tuple[np.ndarray, Optional[np.ndarray]]: + tuple(np.ndarray, np.ndarray | None): - instance_masks: Instance-wise masks array of shape (B, С, H, W). - labels_np: Converted ground truth of shape (B, С, H, W) or None if ground_truth was not provided. @@ -1097,8 +1096,8 @@ class CellSegmentator: ground_truth_masks: np.ndarray, iou_threshold: float = 0.5, return_error_masks: bool = False - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, - Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, + np.ndarray | None, np.ndarray | None, np.ndarray | None]: """ Compute batch-wise true positives, false positives, and false negatives for instance segmentation, using a configurable IoU threshold. @@ -1111,7 +1110,7 @@ class CellSegmentator: return_error_masks (bool): Whether to also return binary error masks. Returns: - Tuple(np.ndarray, np.ndarray, np.ndarray, + tuple(np.ndarray, np.ndarray, np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None): - tp: True positives per batch and class, shape (B, C) - fp: False positives per batch and class, shape (B, C) @@ -1143,7 +1142,7 @@ class CellSegmentator: false_positives: np.ndarray, false_negatives: np.ndarray, reduction: Literal["micro", "macro", "weighted", "imagewise", "per_class", "none"] = "micro" - ) -> Union[float, np.ndarray]: + ) -> float | np.ndarray: """ Compute F1-score from batch-wise TP/FP/FN using various aggregation schemes. @@ -1266,7 +1265,7 @@ class CellSegmentator: false_positives: np.ndarray, false_negatives: np.ndarray, reduction: Literal["micro", "macro", "weighted", "imagewise", 'per_class', "none"] = "micro" - ) -> Union[float, np.ndarray]: + ) -> float | np.ndarray: """ Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes. @@ -1399,23 +1398,23 @@ class CellSegmentator: def __save_prediction_masks( self, - sample: Dict[str, Any], - predicted_mask: Union[np.ndarray, torch.Tensor], + sample: dict[str, Any], + predicted_mask: np.ndarray | torch.Tensor, start_index: int = 0, only_masks: bool = False, - masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None + masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None ) -> None: """ Save multi-channel predicted masks as TIFFs and corresponding visualizations as PNGs in separate folders. 
Args: - sample (Dict[str, Any]): Batch sample from MONAI + sample (dict(str, Any)): Batch sample from MONAI LoadImaged (contains 'image', optional 'mask', and 'image_meta_dict'). - predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W). + predicted_mask (np.ndarray | torch.Tensor): Array of shape (C, H, W) or (B, C, H, W). start_index (int): Starting index for naming when metadata is missing. only_masks (bool): If True, save only the raw predicted mask TIFFs and skip PNG visualizations. - masks (Tuple[np.ndarray, np.ndarray, np.ndarray] | None): + masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None): A tuple (tp_masks, fp_masks, fn_masks), each of shape (B, C, H, W). Defaults to None. """ # Base directories (created once per call) @@ -1428,14 +1427,14 @@ class CellSegmentator: os.makedirs(evaluate_dir, exist_ok=True) # Convert tensors to numpy if necessary - def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: + def to_numpy(x: np.ndarray | torch.Tensor) -> np.ndarray: return x.cpu().numpy() if isinstance(x, torch.Tensor) else x pred_array = to_numpy(predicted_mask).astype(np.uint16) # Handle batch dimension for idx in range(pred_array.shape[0]): - batch_sample: Dict[str, Any] = {} + batch_sample: dict[str, Any] = {} # copy per-sample image and meta img = to_numpy(sample["image"]) if img.ndim == 4: @@ -1467,21 +1466,21 @@ class CellSegmentator: def __save_single_prediction_mask( self, - sample: Dict[str, Any], + sample: dict[str, Any], pred_array: np.ndarray, start_index: int, masks_dir: str, plots_dir: str, evaluate_dir: str, only_masks: bool = False, - masks: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None + masks: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None ) -> None: """ Save a single sample's predicted mask and optional TP/FP/FN masks and visualizations. Assumes output directories already exist. Args: - sample (Dict[str, Any]): Dictionary containing 'image', 'mask', + sample (dict(str, Any)): Dictionary containing 'image', 'mask', and optional 'image_meta_dict' for metadata. pred_array (np.ndarray): Predicted mask array of shape (C,H,W). start_index (int): Base index for generating filenames when metadata is missing. @@ -1489,7 +1488,7 @@ class CellSegmentator: plots_dir (str): Directory for saving PNG visualizations. evaluate_dir (str): Directory for saving PNG visualizations of evaluation results. only_masks (bool): If True, saves only TIFF mask files; skips PNG plots. - masks (Tuple[np.ndarray, np.ndarray, np.ndarray], optional): A tuple of + masks (tuple[np.ndarray, np.ndarray, np.ndarray] | None): A tuple of true-positive, false-positive, and false-negative mask arrays, each of shape (C,H,W). Defaults to None. """ @@ -1510,7 +1509,7 @@ class CellSegmentator: "Expected 2D (H,W) or 3D (C,H,W)." ) - true_mask_array: Optional[np.ndarray] = sample.get("mask") + true_mask_array: np.ndarray | None = sample.get("mask") if isinstance(true_mask_array, np.ndarray): if true_mask_array.ndim == 2: true_mask_array = np.expand_dims(true_mask_array, axis=0) @@ -1562,7 +1561,7 @@ class CellSegmentator: file_path: str, image_data: np.ndarray, predicted_mask: np.ndarray, - true_mask: Optional[np.ndarray] = None, + true_mask: np.ndarray | None = None, ) -> None: """ Create and save grid visualization: 1x3 if no true mask, or 2x3 if true mask provided. @@ -1572,7 +1571,7 @@ class CellSegmentator: image_data (np.ndarray): The original input image array, expected shape (C, H, W). 
predicted_mask (np.ndarray): The predicted mask array, shape (H, W), depending on the task. - true_mask (Optional[np.ndarray], optional): The ground-truth mask array. + true_mask (np.ndarray | None): The ground-truth mask array. If provided, an additional row with true mask and overlap visualization will be added to the plot. Default is None. @@ -1603,7 +1602,7 @@ class CellSegmentator: img: np.ndarray, mask: np.ndarray, contour_color: str, - titles: Tuple[str, ...] + titles: tuple[str, ...] ): """ Plot a row of three panels: original image, mask, and mask boundaries on image. @@ -1618,7 +1617,8 @@ class CellSegmentator: # Panel 1: Original image ax0, ax1, ax2 = axes ax0.imshow(img, cmap='gray' if img.ndim == 2 else None) - ax0.set_title(titles[0]); ax0.axis('off') + ax0.set_title(titles[0]) + ax0.axis('off') # Compute boundaries once boundaries = find_boundaries(mask, mode='thick') @@ -1793,7 +1793,8 @@ class CellSegmentator: # Get coordinates of all non-zero pixels in the padded mask y, x = torch.nonzero(masks_padded, as_tuple=True) - y = y.int(); x = x.int() # ensure integer type + y = y.int() + x = x.int() # ensure integer type # Generate 8-connected neighbors (including center) via broadcasted offsets offsets = torch.tensor([ @@ -1830,9 +1831,12 @@ class CellSegmentator: ], dtype=np.int16) # Compute centers (pixel indices) and extents via the provided helper - centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr) + centers, ext = self.__get_mask_centers_and_extents( + mask_channel, slices_arr + ) # Move centers to GPU and shift by +1 for padding - meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # (M, 2); +1 for padding + # (M, 2); +1 for padding + meds_p = torch.from_numpy(centers).to(self._device).long() + 1 # Determine number of diffusion iterations n_iter = 2 * ext.max() @@ -1865,7 +1869,7 @@ class CellSegmentator: def __get_mask_centers_and_extents( label_map: np.ndarray, slices_arr: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray]: """ Compute the centroids and extents of labeled regions in a 2D mask array. @@ -1923,7 +1927,7 @@ class CellSegmentator: neighbor_indices: torch.Tensor, center_indices: torch.Tensor, valid_neighbor_mask: torch.Tensor, - output_shape: Tuple[int, int], + output_shape: tuple[int, int], num_iterations: int = 200 ) -> np.ndarray: """ @@ -1933,7 +1937,7 @@ class CellSegmentator: neighbor_indices (torch.Tensor): Tensor of shape (2, 9, N) containing row and column indices for 9 neighbors per pixel. center_indices (torch.Tensor): Tensor of shape (2, N) with row and column indices of mask centers. valid_neighbor_mask (torch.Tensor): Boolean tensor of shape (9, N) indicating if each neighbor is valid. - output_shape (Tuple[int, int]): Desired 2D shape of the diffusion tensor, e.g., (H, W). + output_shape (tuple(int, int)): Desired 2D shape of the diffusion tensor, e.g., (H, W). num_iterations (int, optional): Number of diffusion iterations. Defaults to 200. Returns: @@ -2242,7 +2246,7 @@ class CellSegmentator: flow_field: np.ndarray, initial_coords: np.ndarray, num_iters: int = 200 - ) -> Union[np.ndarray, torch.Tensor]: + ) -> np.ndarray | torch.Tensor: """ Trace pixel positions through a flow field via iterative interpolation. @@ -2252,7 +2256,7 @@ class CellSegmentator: num_iters (int): Number of integration steps. Returns: - np.ndarray or torch.Tensor: Final (y, x) positions of each point. + (np.ndarray | torch.Tensor): Final (y, x) positions of each point. 
""" dims = 2 # Extract spatial dimensions @@ -2383,7 +2387,7 @@ class CellSegmentator: self, pixel_positions: torch.Tensor, valid_indices: np.ndarray, - original_shape: Tuple[int, ...], + original_shape: tuple[int, ...], pad_radius: int = 20, max_size_fraction: float = 0.4 ) -> np.ndarray: @@ -2534,7 +2538,7 @@ class CellSegmentator: input_tensor: Tensor, kernel_size: int = 5, axis: int = 1, - output_tensor: Optional[Tensor] = None + output_tensor: Tensor | None = None ) -> Tensor: """ Memory-efficient 1D max pooling along a specified axis using in-place updates. @@ -2547,7 +2551,7 @@ class CellSegmentator: input_tensor (Tensor): Source tensor for pooling. kernel_size (int): Size of the pooling window (must be odd and >= 3). axis (int): Axis along which to compute 1D max pooling. - output_tensor (Optional[Tensor]): Tensor to store the result. + output_tensor (Tensor | None): Tensor to store the result. If None, a clone of input_tensor is used. Returns: @@ -2691,7 +2695,7 @@ class CellSegmentator: self, mask: np.ndarray, flow_network: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray]: """ Compute mean squared error between network-predicted flows and flows derived from masks. @@ -2700,7 +2704,7 @@ class CellSegmentator: flow_network (np.ndarray): Network predicted flows of shape [axis, ...]. Returns: - Tuple[np.ndarray, np.ndarray]: + tuple(np.ndarray, np.ndarray): - flow_errors: 1D array (length = max label) of mean squared error per label. - computed_flows: Array of flows derived from the mask, same shape as flow_network. diff --git a/core/utils/measures.py b/core/utils/measures.py index c628a0c..af92a56 100644 --- a/core/utils/measures.py +++ b/core/utils/measures.py @@ -8,7 +8,7 @@ from numpy.typing import NDArray from numba import jit from skimage import segmentation from scipy.optimize import linear_sum_assignment -from typing import Dict, List, Tuple, Any, Union +from typing import Any from core.logger import get_logger @@ -27,7 +27,7 @@ def compute_f1_score( true_positives: int, false_positives: int, false_negatives: int -) -> Tuple[float, float, float]: +) -> tuple[float, float, float]: """ Computes the precision, recall, and F1-score given the numbers of true positives, false positives, and false negatives. @@ -76,7 +76,7 @@ def compute_confusion_matrix( ground_truth_mask: np.ndarray, predicted_mask: np.ndarray, iou_threshold: float = 0.5 -) -> Tuple[int, int, int]: +) -> tuple[int, int, int]: """ Computes the confusion matrix elements (true positives, false positives, false negatives) for a single image given the ground truth and predicted masks. @@ -114,7 +114,7 @@ def compute_segmentation_tp_fp_fn( iou_threshold: float = 0.5, return_error_masks: bool = False, remove_boundary_objects: bool = True -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Computes TP, FP and FN for segmentation on a single image. 
@@ -176,7 +176,7 @@ def compute_segmentation_tp_fp_fn( false_positive_mask_list.append(results.get('fp_mask')) # type: ignore false_negative_mask_list.append(results.get('fn_mask')) # type: ignore - output: Dict[str, np.ndarray] = { + output: dict[str, np.ndarray] = { 'tp': np.array(true_positive_list), 'fp': np.array(false_positive_list), 'fn': np.array(false_negative_list) @@ -194,7 +194,7 @@ def compute_segmentation_f1_metrics( iou_threshold: float = 0.5, return_error_masks: bool = False, remove_boundary_objects: bool = True -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Computes F1 metrics (precision, recall, F1-score) for segmentation on a single image. @@ -240,7 +240,7 @@ def compute_segmentation_f1_metrics( recall_list.append(recall) f1_score_list.append(f1_score) - output: Dict[str, np.ndarray] = { + output: dict[str, np.ndarray] = { 'precision': np.array(precision_list), 'recall': np.array(recall_list), 'f1_score': np.array(f1_score_list), @@ -255,7 +255,7 @@ def compute_segmentation_average_precision_metrics( iou_threshold: float = 0.5, return_error_masks: bool = False, remove_boundary_objects: bool = True -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Computes the average precision (AP) for segmentation on a single image. @@ -298,7 +298,7 @@ def compute_segmentation_average_precision_metrics( ) avg_precision_list.append(avg_precision) - output: Dict[str, np.ndarray] = { + output: dict[str, np.ndarray] = { 'avg_precision': np.array(avg_precision_list) } output.update(results) @@ -311,7 +311,7 @@ def compute_batch_segmentation_tp_fp_fn( iou_threshold: float = 0.5, return_error_masks: bool = False, remove_boundary_objects: bool = True -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Computes segmentation TP, FP and FN for a batch of images. @@ -361,7 +361,7 @@ def compute_batch_segmentation_tp_fp_fn( fp_mask_list.append(result.get('fp_mask')) # type: ignore fn_mask_list.append(result.get('fn_mask')) # type: ignore - output: Dict[str, np.ndarray] = { + output: dict[str, np.ndarray] = { 'tp': np.stack(tp_list, axis=0), 'fp': np.stack(fp_list, axis=0), 'fn': np.stack(fn_list, axis=0) @@ -379,7 +379,7 @@ def compute_batch_segmentation_f1_metrics( iou_threshold: float = 0.5, return_error_masks: bool = False, remove_boundary_objects: bool = True -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Computes segmentation F1 metrics for a batch of images. @@ -435,7 +435,7 @@ def compute_batch_segmentation_f1_metrics( fp_mask_list.append(result.get('fp_mask')) # type: ignore fn_mask_list.append(result.get('fn_mask')) # type: ignore - output: Dict[str, np.ndarray] = { + output: dict[str, np.ndarray] = { 'precision': np.stack(precision_list, axis=0), 'recall': np.stack(recall_list, axis=0), 'f1_score': np.stack(f1_score_list, axis=0), @@ -456,7 +456,7 @@ def compute_batch_segmentation_average_precision_metrics( iou_threshold: float = 0.5, return_error_masks: bool = False, remove_boundary_objects: bool = True -) -> Dict[str, NDArray]: +) -> dict[str, np.ndarray]: """ Computes segmentation average precision metrics for a batch of images. 
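The batch-level wrappers re-annotated in the hunks above all follow one convention: each per-image metric dict is computed independently, then every field is stacked along a new leading batch axis and returned as dict[str, np.ndarray]. A minimal sketch of that aggregation pattern, with an illustrative helper name that is not part of the module:

import numpy as np

def stack_batch_metrics(per_image: list[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
    """Stack per-image metric arrays of shape (C,) into batch arrays of shape (B, C)."""
    # Assumes every per-image dict carries the same keys, as the batch helpers above do.
    keys = per_image[0].keys()
    return {key: np.stack([metrics[key] for metrics in per_image], axis=0) for key in keys}
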
@@ -508,7 +508,7 @@ def compute_batch_segmentation_average_precision_metrics( fp_mask_list.append(result.get('fp_mask')) # type: ignore fn_mask_list.append(result.get('fn_mask')) # type: ignore - output: Dict[str, NDArray] = { + output: dict[str, np.ndarray] = { 'avg_precision': np.stack(avg_precision_list, axis=0), 'tp': np.stack(tp_list, axis=0), 'fp': np.stack(fp_list, axis=0), @@ -555,7 +555,7 @@ def _process_instance_matching( iou_threshold: float = 0.5, return_masks: bool = False, without_boundary_objects: bool = True -) -> Dict[str, Union[int, NDArray[np.uint8]]]: +) -> dict[str, int | NDArray[np.uint8]]: """ Processes instance matching on a full image by performing the following steps: - Removes objects that touch the image boundary and reindexes the masks. @@ -597,8 +597,8 @@ def _process_instance_matching( fn_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8) # Mark all ground truth objects as false negatives. fn_mask[ground_truth_mask > 0] = 1 - result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) - return result + result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) # type: ignore + return result # type: ignore # Compute the IoU matrix for the processed masks. iou_matrix = _calculate_iou(processed_ground_truth, processed_prediction) @@ -640,11 +640,11 @@ def _process_instance_matching( for pred_label in (all_prediction_labels - matched_prediction_labels): fp_mask[processed_prediction == pred_label] = 1 - result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) - return result + result.update({'tp_mask': tp_mask, 'fp_mask': fp_mask, 'fn_mask': fn_mask}) # type: ignore + return result # type: ignore -def _compute_optimal_matching_pairs(iou_matrix: np.ndarray, iou_threshold: float) -> List[Any]: +def _compute_optimal_matching_pairs(iou_matrix: np.ndarray, iou_threshold: float) -> list[Any]: """ Computes the optimal matching pairs between ground truth and predicted masks using the IoU matrix. @@ -687,7 +687,7 @@ def _compute_patch_based_metrics( iou_threshold: float = 0.5, return_masks: bool = False, without_boundary_objects: bool = True -) -> Dict[str, Union[int, NDArray[np.uint8]]]: +) -> dict[str, int | NDArray[np.uint8]]: """ Computes segmentation metrics using a patch-based approach for very large images. @@ -747,7 +747,7 @@ def _compute_patch_based_metrics( padded_fp_mask[y_start:y_end, x_start:x_end] = patch_results.get('fp_mask', 0) # type: ignore padded_fn_mask[y_start:y_end, x_start:x_end] = patch_results.get('fn_mask', 0) # type: ignore - results: Dict[str, Union[int, np.ndarray]] = {'tp': total_tp, 'fp': total_fp, 'fn': total_fn} + results: dict[str, int | np.ndarray] = {'tp': total_tp, 'fp': total_fp, 'fn': total_fn} if return_masks: # Crop the padded masks back to the original image size. results.update({ diff --git a/generate_config.py b/generate_config.py index 2bb91ec..62d45a5 100644 --- a/generate_config.py +++ b/generate_config.py @@ -1,5 +1,4 @@ import os -from typing import Tuple from config import Config, WandbConfig, DatasetConfig, ComponentConfig @@ -8,7 +7,7 @@ from core import ( ) -def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str: +def prompt_choice(prompt_message: str, options: tuple[str, ...]) -> str: """ Prompt the user with a list of options and return the selected option. 
""" diff --git a/main.py b/main.py index 5c2d091..4ddb2dd 100644 --- a/main.py +++ b/main.py @@ -1,20 +1,25 @@ import os +import sys import argparse import wandb from config import Config -from core.data import * +from core.data import ( + get_train_transforms, + get_valid_transforms, + get_test_transforms, + get_predict_transforms +) from core.segmentator import CellSegmentator -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Train or predict cell segmentator with specified config file." ) parser.add_argument( '-c', '--config', type=str, - default='config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json', help='Path to the JSON config file' ) parser.add_argument( @@ -36,6 +41,10 @@ def main(): ' masks without additional visualizations') ) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(0) + args = parser.parse_args() mode = args.mode @@ -44,7 +53,7 @@ def main(): if mode == 'train' and not config.dataset_config.is_training: raise ValueError( - f"Config is not set for training (is_training=False), but mode 'train' was requested." + "Config is not set for training (is_training=False), but mode 'train' was requested." ) if mode in ('test', 'predict') and config.dataset_config.is_training: raise ValueError(