1) Removed support for parameter lists;
2) added wrappers for optimizer and scheduler;
3) changed parameters for losses.
master
laynholt 3 weeks ago
parent 78f97a72a2
commit 5db2220917

@@ -1,29 +1,41 @@
 import json
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional
 from pydantic import BaseModel
 from .dataset_config import DatasetConfig

-class Config(BaseModel):
-    model: Dict[str, Union[BaseModel, List[BaseModel]]]
-    dataset_config: DatasetConfig
-    criterion: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
-    optimizer: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
-    scheduler: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
+__all__ = ["Config", "ComponentConfig"]

-    @staticmethod
-    def __dump_field(value: Any) -> Any:
+
+class ComponentConfig(BaseModel):
+    name: str
+    params: BaseModel
+
+    def dump(self) -> Dict[str, Any]:
         """
-        Recursively dumps a field if it is a BaseModel or a list/dict of BaseModels.
+        Recursively serializes the component into a dictionary.
+
+        Returns:
+            dict: A dictionary containing the component name and its serialized parameters.
         """
-        if isinstance(value, BaseModel):
-            return value.model_dump()
-        elif isinstance(value, list):
-            return [Config.__dump_field(item) for item in value]
-        elif isinstance(value, dict):
-            return {k: Config.__dump_field(v) for k, v in value.items()}
+        if isinstance(self.params, BaseModel):
+            params_dump = self.params.model_dump()
         else:
-            return value
+            params_dump = self.params
+
+        return {
+            "name": self.name,
+            "params": params_dump
+        }
+
+
+class Config(BaseModel):
+    model: ComponentConfig
+    dataset_config: DatasetConfig
+    criterion: Optional[ComponentConfig] = None
+    optimizer: Optional[ComponentConfig] = None
+    scheduler: Optional[ComponentConfig] = None

     def save_json(self, file_path: str, indent: int = 4) -> None:
         """
@@ -34,15 +46,15 @@ class Config(BaseModel):
             indent (int): Indentation level for the JSON file.
         """
         config_dump = {
-            "model": self.__dump_field(self.model),
+            "model": self.model.dump(),
             "dataset_config": self.dataset_config.model_dump()
         }

         if self.criterion is not None:
-            config_dump.update({"criterion": self.__dump_field(self.criterion)})
+            config_dump["criterion"] = self.criterion.dump()
         if self.optimizer is not None:
-            config_dump.update({"optimizer": self.__dump_field(self.optimizer)})
+            config_dump["optimizer"] = self.optimizer.dump()
         if self.scheduler is not None:
-            config_dump.update({"scheduler": self.__dump_field(self.scheduler)})
+            config_dump["scheduler"] = self.scheduler.dump()

         with open(file_path, "w", encoding="utf-8") as f:
             f.write(json.dumps(config_dump, indent=indent))
@@ -67,16 +79,15 @@ class Config(BaseModel):
        dataset_config = DatasetConfig(**data.get("dataset_config", {}))

         # Helper function to parse registry fields.
-        def parse_field(field_data: Dict[str, Any], registry_getter) -> Dict[str, Union[BaseModel, List[BaseModel]]]:
-            result = {}
-            for key, value in field_data.items():
-                expected = registry_getter(key)
-                # If the registry returns a tuple, then we expect a list of dictionaries.
-                if isinstance(expected, tuple):
-                    result[key] = [cls_param(**item) for cls_param, item in zip(expected, value)]
-                else:
-                    result[key] = expected(**value)
-            return result
+        def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]:
+            name = component_data.get("name")
+            params_data = component_data.get("params", {})
+
+            if name is not None:
+                expected = registry_getter(name)
+                params = expected(**params_data)
+                return ComponentConfig(name=name, params=params)
+            return None

         from core import (
             ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
@@ -87,6 +98,9 @@ class Config(BaseModel):
         parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key))
         parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key))

+        if parsed_model is None:
+            raise ValueError('Failed to load model information')
+
         return cls(
             model=parsed_model,
             dataset_config=dataset_config,
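Taken together, the config.py changes mean each component is now described by a single name/params pair rather than a dict that could hold lists of parameter models. A rough usage sketch of the new API (a sketch only; "ModelV" as a registered model name and a default-constructible DatasetConfig are assumptions, not shown in this diff):

from config.config import Config, ComponentConfig
from config.dataset_config import DatasetConfig
from core import ModelRegistry

# Assumption: "ModelV" is a registered model name and DatasetConfig() accepts defaults.
model_params = ModelRegistry.get_model_params("ModelV")()

config = Config(
    model=ComponentConfig(name="ModelV", params=model_params),
    dataset_config=DatasetConfig()
)

# save_json() now writes {"model": {"name": ..., "params": {...}}, "dataset_config": {...}}
config.save_json("ModelV_predict.json")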

@@ -98,6 +98,10 @@ class DatasetTrainingConfig(BaseModel):
         - If is_split is True, validates pre_split (train_dir must be non-empty and exist; if provided, valid_dir and test_dir must exist).
         - If is_split is False, validates split (all_data_dir must be non-empty and exist).
         """
+        if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)):
+            if (self.train_size + self.valid_size + self.test_size) > 1:
+                raise ValueError("The total sample size with dynamically defined sizes must be <= 1")
+
         if not self.is_split:
             if not self.split.all_data_dir:
                 raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in pre_split")

@@ -5,12 +5,12 @@ from .base import BaseLoss
 from .ce import CrossEntropyLoss, CrossEntropyLossParams
 from .bce import BCELoss, BCELossParams
 from .mse import MSELoss, MSELossParams
-from .mse_with_bce import BCE_MSE_Loss
+from .mse_with_bce import BCE_MSE_Loss, BCE_MSE_LossParams

 __all__ = [
-    "CriterionRegistry",
+    "CriterionRegistry", "BaseLoss",
     "CrossEntropyLoss", "BCELoss", "MSELoss", "BCE_MSE_Loss",
-    "CrossEntropyLossParams", "BCELossParams", "MSELossParams"
+    "CrossEntropyLossParams", "BCELossParams", "MSELossParams", "BCE_MSE_LossParams"
 ]

 class CriterionRegistry:
@@ -31,7 +31,7 @@ class CriterionRegistry:
         },
         "BCE_MSE_Loss": {
             "class": BCE_MSE_Loss,
-            "params": (BCELossParams, MSELossParams),
+            "params": BCE_MSE_LossParams,
         },
     }
@@ -73,7 +73,7 @@ class CriterionRegistry:
         return entry["class"]

     @classmethod
-    def get_criterion_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
+    def get_criterion_params(cls, name: str) -> Type[BaseModel]:
         """
         Retrieves the loss function parameter class (or classes) by name (case-insensitive).
@@ -81,7 +81,7 @@ class CriterionRegistry:
             name (str): Name of the loss function.

         Returns:
-            Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The loss function parameter class or a tuple of parameter classes.
+            Type[BaseModel]: The loss function parameter class.
         """
         entry = cls.__get_entry(name)
         return entry["params"]

@@ -1,6 +1,7 @@
 import abc
 import torch
 import torch.nn as nn
+from pydantic import BaseModel
 from typing import Dict, Any, Optional
 from monai.metrics.cumulative_average import CumulativeAverage

@@ -8,9 +9,7 @@ from monai.metrics.cumulative_average import CumulativeAverage
 class BaseLoss(abc.ABC):
     """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""

-    def __init__(self):
-        """
-        """
+    def __init__(self, params: Optional[BaseModel] = None):
         super().__init__()

@@ -9,27 +9,26 @@ class BCELossParams(BaseModel):
     """
     model_config = ConfigDict(frozen=True)

+    with_logits: bool = False
     weight: Optional[List[Union[int, float]]] = None  # Sample weights
     reduction: Literal["none", "mean", "sum"] = "mean"  # Reduction method
     pos_weight: Optional[List[Union[int, float]]] = None  # Used only for BCEWithLogitsLoss

-    def asdict(self, with_logits: bool = False) -> Dict[str, Any]:
+    def asdict(self) -> Dict[str, Any]:
         """
         Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`.

         - If `with_logits=False`, `pos_weight` is **removed** to avoid errors.
         - Ensures only the valid parameters are passed based on the loss function.

-        Args:
-            with_logits (bool): If `True`, includes `pos_weight` (for `nn.BCEWithLogitsLoss`).
-                                If `False`, removes `pos_weight` (for `nn.BCELoss`).
-
         Returns:
             Dict[str, Any]: Filtered dictionary of parameters.
         """
         loss_kwargs = self.model_dump()
-        if not with_logits:
+        if not self.with_logits:
             loss_kwargs.pop("pos_weight", None)  # Remove pos_weight if using BCELoss
+        loss_kwargs.pop("with_logits", None)

         weight = loss_kwargs.get("weight")
         pos_weight = loss_kwargs.get("pos_weight")
@@ -48,15 +47,16 @@ class BCELoss(BaseLoss):
     Custom loss function wrapper for `nn.BCELoss and nn.BCEWithLogitsLoss` with tracking of loss metrics.
     """

-    def __init__(self, bce_params: Optional[BCELossParams] = None, with_logits: bool = False):
+    def __init__(self, params: Optional[BCELossParams] = None):
         """
         Initializes the loss function with optional BCELoss parameters.

         Args:
-            bce_params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None).
+            params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None).
         """
-        super().__init__()
+        super().__init__(params=params)

-        _bce_params = bce_params.asdict(with_logits=with_logits) if bce_params is not None else {}
+        with_logits = params.with_logits if params is not None else False
+        _bce_params = params.asdict() if params is not None else {}

         # Initialize loss functions with user-provided parameters or PyTorch defaults
         self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params)
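Since with_logits now lives on BCELossParams, choosing BCEWithLogitsLoss over BCELoss is done through the params object alone. A minimal sketch of the new call pattern (the import path is an assumption; field values are illustrative):

from core.losses.bce import BCELoss, BCELossParams  # assumed import path

# The logits switch is a params field now, not a constructor argument.
loss_fn = BCELoss(params=BCELossParams(with_logits=True, reduction="mean"))
# Internally this builds nn.BCEWithLogitsLoss(reduction="mean"); asdict() strips
# the with_logits field before the kwargs reach torch.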

@@ -36,15 +36,15 @@ class CrossEntropyLoss(BaseLoss):
     Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics.
     """

-    def __init__(self, ce_params: Optional[CrossEntropyLossParams] = None):
+    def __init__(self, params: Optional[CrossEntropyLossParams] = None):
         """
         Initializes the loss function with optional CrossEntropyLoss parameters.

         Args:
-            ce_params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None).
+            params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None).
         """
-        super().__init__()
+        super().__init__(params=params)

-        _ce_params = ce_params.asdict() if ce_params is not None else {}
+        _ce_params = params.asdict() if params is not None else {}

         # Initialize loss functions with user-provided parameters or PyTorch defaults
         self.ce_loss = nn.CrossEntropyLoss(**_ce_params)

@@ -27,15 +27,15 @@ class MSELoss(BaseLoss):
     Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics.
     """

-    def __init__(self, mse_params: Optional[MSELossParams] = None):
+    def __init__(self, params: Optional[MSELossParams] = None):
         """
         Initializes the loss function with optional MSELoss parameters.

         Args:
-            mse_params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None).
+            params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None).
         """
-        super().__init__()
+        super().__init__(params=params)

-        _mse_params = mse_params.asdict() if mse_params is not None else {}
+        _mse_params = params.asdict() if params is not None else {}

         # Initialize MSE loss with user-provided parameters or PyTorch defaults
         self.mse_loss = nn.MSELoss(**_mse_params)

@@ -2,42 +2,59 @@ from .base import *
 from .bce import BCELossParams
 from .mse import MSELossParams

+from pydantic import BaseModel, ConfigDict
+
+
+class BCE_MSE_LossParams(BaseModel):
+    """
+    Class for handling parameters for `nn.MSELoss` with `nn.BCELoss`.
+    """
+    model_config = ConfigDict(frozen=True)
+
+    num_classes: int = 1
+    bce_params: BCELossParams = BCELossParams()
+    mse_params: MSELossParams = MSELossParams()
+
+    def asdict(self) -> Dict[str, Any]:
+        """
+        Returns a dictionary of valid parameters for `nn.BCELoss` and `nn.MSELoss`.
+
+        Returns:
+            Dict[str, Any]: Dictionary of parameters.
+        """
+        return {
+            "num_classes": self.num_classes,
+            "bce_params": self.bce_params.asdict(),
+            "mse_params": self.mse_params.asdict()
+        }
+
+
 class BCE_MSE_Loss(BaseLoss):
     """
     Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction.
     """

-    def __init__(
-        self,
-        num_classes: int,
-        bce_params: Optional[BCELossParams] = None,
-        mse_params: Optional[MSELossParams] = None,
-        bce_with_logits: bool = False,
-    ):
+    def __init__(self, params: Optional[BCE_MSE_LossParams]):
         """
         Initializes the loss function with optional BCE and MSE parameters.
-
-        Args:
-            num_classes (int): Number of output classes, used for target shifting.
-            bce_params (Optional[BCELossParams]): Parameters for BCEWithLogitsLoss or BCELoss (default: None).
-            mse_params (Optional[MSELossParams]): Parameters for MSELoss (default: None).
-            bce_with_logits (bool): If True, uses BCEWithLogitsLoss; otherwise, uses BCELoss.
         """
-        super().__init__()
+        super().__init__(params=params)

-        self.num_classes = num_classes
+        _params = params if params is not None else BCE_MSE_LossParams()
+
+        self.num_classes = _params.num_classes

         # Process BCE parameters
-        _bce_params = bce_params.asdict(bce_with_logits) if bce_params is not None else {}
+        _bce_params = _params.bce_params.asdict()

         # Choose BCE loss function
         self.bce_loss = (
-            nn.BCEWithLogitsLoss(**_bce_params) if bce_with_logits else nn.BCELoss(**_bce_params)
+            nn.BCEWithLogitsLoss(**_bce_params) if _params.bce_params.with_logits else nn.BCELoss(**_bce_params)
         )

         # Process MSE parameters
-        _mse_params = mse_params.asdict() if mse_params is not None else {}
+        _mse_params = _params.mse_params.asdict()

         # Initialize MSE loss
         self.mse_loss = nn.MSELoss(**_mse_params)
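The combined loss follows the same pattern: one frozen params model nests the BCE and MSE settings that used to be four separate constructor arguments. Roughly (import paths assumed, field values illustrative):

from core.losses.bce import BCELossParams                              # assumed import path
from core.losses.mse import MSELossParams                              # assumed import path
from core.losses.mse_with_bce import BCE_MSE_Loss, BCE_MSE_LossParams  # assumed import path

params = BCE_MSE_LossParams(
    num_classes=3,                               # illustrative value
    bce_params=BCELossParams(with_logits=True),
    mse_params=MSELossParams()
)
loss_fn = BCE_MSE_Loss(params=params)  # BCEWithLogitsLoss + MSELoss under the hood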

@@ -61,7 +61,7 @@ class ModelRegistry:
         return entry["class"]

     @classmethod
-    def get_model_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
+    def get_model_params(cls, name: str) -> Type[BaseModel]:
         """
         Retrieves the model parameter class by name (case-insensitive).
@@ -69,7 +69,7 @@ class ModelRegistry:
             name (str): Name of the model.

         Returns:
-            Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The model parameter class or a tuple of parameter classes.
+            Type[BaseModel]: The model parameter class.
         """
         entry = cls.__get_entry(name)
         return entry["params"]

@@ -1,14 +1,15 @@
-import torch.optim as optim
 from pydantic import BaseModel
 from typing import Dict, Final, Tuple, Type, List, Any, Union

-from .adam import AdamParams
-from .adamw import AdamWParams
-from .sgd import SGDParams
+from .base import BaseOptimizer
+from .adam import AdamParams, AdamOptimizer
+from .adamw import AdamWParams, AdamWOptimizer
+from .sgd import SGDParams, SGDOptimizer

 __all__ = [
-    "OptimizerRegistry",
-    "AdamParams", "AdamWParams", "SGDParams"
+    "OptimizerRegistry", "BaseOptimizer",
+    "AdamParams", "AdamWParams", "SGDParams",
+    "AdamOptimizer", "AdamWOptimizer", "SGDOptimizer"
 ]

 class OptimizerRegistry:
@@ -17,15 +18,15 @@ class OptimizerRegistry:
     # Single dictionary storing both optimizer classes and parameter classes.
     __OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
         "SGD": {
-            "class": optim.SGD,
+            "class": SGDOptimizer,
             "params": SGDParams,
         },
         "Adam": {
-            "class": optim.Adam,
+            "class": AdamOptimizer,
             "params": AdamParams,
         },
         "AdamW": {
-            "class": optim.AdamW,
+            "class": AdamWOptimizer,
             "params": AdamWParams,
         },
     }
@@ -54,7 +55,7 @@ class OptimizerRegistry:
         return cls.__OPTIMIZERS[original_key]

     @classmethod
-    def get_optimizer_class(cls, name: str) -> Type[optim.Optimizer]:
+    def get_optimizer_class(cls, name: str) -> Type[BaseOptimizer]:
         """
         Retrieves the optimizer class by name (case-insensitive).
@@ -62,13 +63,13 @@ class OptimizerRegistry:
             name (str): Name of the optimizer.

         Returns:
-            Type[optim.Optimizer]: The optimizer class.
+            Type[BaseOptimizer]: The optimizer class.
         """
         entry = cls.__get_entry(name)
         return entry["class"]

     @classmethod
-    def get_optimizer_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
+    def get_optimizer_params(cls, name: str) -> Type[BaseModel]:
         """
         Retrieves the optimizer parameter class by name (case-insensitive).
@@ -76,7 +77,7 @@ class OptimizerRegistry:
             name (str): Name of the optimizer.

         Returns:
-            Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The optimizer parameter class or a tuple of parameter classes.
+            Type[BaseModel]: The optimizer parameter class.
         """
         entry = cls.__get_entry(name)
         return entry["params"]

@@ -1,6 +1,9 @@
-from typing import Any, Dict, Tuple
+import torch
+from torch import optim
+from typing import Any, Dict, Iterable, Optional, Tuple
 from pydantic import BaseModel, ConfigDict

+from .base import BaseOptimizer

 class AdamParams(BaseModel):
     """Configuration for `torch.optim.Adam` optimizer."""
@@ -15,3 +18,20 @@ class AdamParams(BaseModel):
     def asdict(self) -> Dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.Adam`."""
         return self.model_dump()
+
+
+class AdamOptimizer(BaseOptimizer):
+    """
+    Wrapper around torch.optim.Adam.
+    """
+    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams):
+        """
+        Initializes the Adam optimizer with given parameters.
+
+        Args:
+            model_params (Iterable[Parameter]): Parameters to optimize.
+            optim_params (AdamParams): Optimizer parameters.
+        """
+        super().__init__(model_params, optim_params)
+
+        self.optim = optim.Adam(model_params, **optim_params.asdict())

@@ -1,6 +1,10 @@
-from typing import Any, Dict, Tuple
+import torch
+from torch import optim
+from typing import Any, Dict, Iterable, Optional, Tuple
 from pydantic import BaseModel, ConfigDict

+from .base import BaseOptimizer
+
 class AdamWParams(BaseModel):
     """Configuration for `torch.optim.AdamW` optimizer."""
     model_config = ConfigDict(frozen=True)
@@ -14,3 +18,20 @@ class AdamWParams(BaseModel):
     def asdict(self) -> Dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.AdamW`."""
         return self.model_dump()
+
+
+class AdamWOptimizer(BaseOptimizer):
+    """
+    Wrapper around torch.optim.AdamW.
+    """
+    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams):
+        """
+        Initializes the AdamW optimizer with given parameters.
+
+        Args:
+            model_params (Iterable[Parameter]): Parameters to optimize.
+            optim_params (AdamWParams): Optimizer parameters.
+        """
+        super().__init__(model_params, optim_params)
+
+        self.optim = optim.AdamW(model_params, **optim_params.asdict())

@@ -0,0 +1,40 @@
+import torch
+import torch.optim as optim
+
+from pydantic import BaseModel
+from typing import Any, Iterable, Optional
+
+
+class BaseOptimizer:
+    """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
+    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel):
+        super().__init__()
+        self.optim: Optional[optim.Optimizer] = None
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        """
+        Clears the gradients of all optimized tensors.
+
+        Args:
+            set_to_none (bool): If True, sets gradients to None instead of zero.
+                                This can reduce memory usage and improve performance.
+                                (Introduced in PyTorch 1.7+)
+        """
+        if self.optim is not None:
+            self.optim.zero_grad(set_to_none=set_to_none)
+
+    def step(self, closure: Optional[Any] = None) -> Any:
+        """
+        Performs a single optimization step (parameter update).
+
+        Args:
+            closure (Optional[Callable]): A closure that reevaluates the model and returns the loss.
+                                          This is required for optimizers like LBFGS that need multiple forward passes.
+
+        Returns:
+            Any: The return value depends on the specific optimizer implementation.
+        """
+        if self.optim is not None:
+            return self.optim.step(closure=closure)
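With the registry now handing back these wrappers instead of raw torch.optim classes, a training loop goes through the wrapper's zero_grad/step interface. A sketch under that assumption (the linear model is just a stand-in):

import torch
import torch.nn as nn
from core import OptimizerRegistry

model = nn.Linear(4, 2)  # stand-in model

optim_cls = OptimizerRegistry.get_optimizer_class("AdamW")        # -> AdamWOptimizer
optim_params = OptimizerRegistry.get_optimizer_params("AdamW")()  # AdamWParams with defaults
optimizer = optim_cls(model.parameters(), optim_params)

loss = model(torch.randn(8, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()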

@@ -1,6 +1,10 @@
-from typing import Any, Dict
+import torch
+from torch import optim
+from typing import Any, Dict, Iterable, Optional
 from pydantic import BaseModel, ConfigDict

+from .base import BaseOptimizer
+

 class SGDParams(BaseModel):
     """Configuration for `torch.optim.SGD` optimizer."""
@@ -15,3 +19,20 @@ class SGDParams(BaseModel):
     def asdict(self) -> Dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.SGD`."""
         return self.model_dump()
+
+
+class SGDOptimizer(BaseOptimizer):
+    """
+    Wrapper around torch.optim.SGD.
+    """
+    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams):
+        """
+        Initializes the SGD optimizer with given parameters.
+
+        Args:
+            model_params (Iterable[Parameter]): Parameters to optimize.
+            optim_params (SGDParams): Optimizer parameters.
+        """
+        super().__init__(model_params, optim_params)
+
+        self.optim = optim.SGD(model_params, **optim_params.asdict())

@@ -2,14 +2,16 @@ import torch.optim.lr_scheduler as lr_scheduler
 from typing import Dict, Final, Tuple, Type, List, Any, Union
 from pydantic import BaseModel

-from .step import StepLRParams
-from .multi_step import MultiStepLRParams
-from .exponential import ExponentialLRParams
-from .cosine_annealing import CosineAnnealingLRParams
+from .base import BaseScheduler
+from .step import StepLRParams, StepLRScheduler
+from .multi_step import MultiStepLRParams, MultiStepLRScheduler
+from .exponential import ExponentialLRParams, ExponentialLRScheduler
+from .cosine_annealing import CosineAnnealingLRParams, CosineAnnealingLRScheduler

 __all__ = [
-    "SchedulerRegistry",
-    "StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams"
+    "SchedulerRegistry", "BaseScheduler",
+    "StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams",
+    "StepLRScheduler", "MultiStepLRScheduler", "ExponentialLRScheduler", "CosineAnnealingLRScheduler"
 ]

 class SchedulerRegistry:
@@ -17,19 +19,19 @@ class SchedulerRegistry:
     __SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
         "Step": {
-            "class": lr_scheduler.StepLR,
+            "class": StepLRScheduler,
             "params": StepLRParams,
         },
         "Exponential": {
-            "class": lr_scheduler.ExponentialLR,
+            "class": ExponentialLRScheduler,
             "params": ExponentialLRParams,
         },
         "MultiStep": {
-            "class": lr_scheduler.MultiStepLR,
+            "class": MultiStepLRScheduler,
             "params": MultiStepLRParams,
         },
         "CosineAnnealing": {
-            "class": lr_scheduler.CosineAnnealingLR,
+            "class": CosineAnnealingLRScheduler,
             "params": CosineAnnealingLRParams,
         },
     }
@@ -58,7 +60,7 @@ class SchedulerRegistry:
         return cls.__SCHEDULERS[original_key]

     @classmethod
-    def get_scheduler_class(cls, name: str) -> Type[lr_scheduler.LRScheduler]:
+    def get_scheduler_class(cls, name: str) -> Type[BaseScheduler]:
         """
         Retrieves the scheduler class by name (case-insensitive).
@@ -66,13 +68,13 @@ class SchedulerRegistry:
             name (str): Name of the scheduler.

         Returns:
-            Type[lr_scheduler.LRScheduler]: The scheduler class.
+            Type[BaseScheduler]: The scheduler class.
         """
         entry = cls.__get_entry(name)
         return entry["class"]

     @classmethod
-    def get_scheduler_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
+    def get_scheduler_params(cls, name: str) -> Type[BaseModel]:
         """
         Retrieves the scheduler parameter class by name (case-insensitive).
@@ -80,7 +82,7 @@ class SchedulerRegistry:
             name (str): Name of the scheduler.

         Returns:
-            Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The scheduler parameter class or a tuple of parameter classes.
+            Type[BaseModel]: The scheduler parameter class.
         """
         entry = cls.__get_entry(name)
         return entry["params"]

@@ -0,0 +1,27 @@
+import torch.optim as optim
+
+from pydantic import BaseModel
+from typing import List, Optional
+
+
+class BaseScheduler:
+    """
+    Abstract base class for learning rate schedulers.
+    Wraps a PyTorch LR scheduler and provides a unified interface.
+    """
+    def __init__(self, optimizer: optim.Optimizer, params: BaseModel):
+        self.scheduler: Optional[optim.lr_scheduler.LRScheduler] = None
+
+    def step(self) -> None:
+        """
+        Performs a single scheduler step. This typically updates the learning rate
+        based on the current epoch or step count.
+        """
+        if self.scheduler is not None:
+            self.scheduler.step()
+
+    def get_last_lr(self) -> List[float]:
+        """
+        Returns the most recent learning rate(s).
+        """
+        return self.scheduler.get_last_lr() if self.scheduler else []
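How the trainer chains a scheduler wrapper onto an optimizer wrapper is not part of this diff; the sketch below assumes the underlying torch optimizer (the wrapper's .optim attribute) is what gets handed to the scheduler:

import torch.nn as nn
from core import OptimizerRegistry, SchedulerRegistry

model = nn.Linear(4, 2)  # stand-in model

opt = OptimizerRegistry.get_optimizer_class("SGD")(
    model.parameters(), OptimizerRegistry.get_optimizer_params("SGD")()
)
sched = SchedulerRegistry.get_scheduler_class("Step")(           # -> StepLRScheduler
    opt.optim, SchedulerRegistry.get_scheduler_params("Step")()  # assumption: pass the wrapped torch optimizer
)

for _ in range(3):
    opt.zero_grad()
    # ... forward / backward ...
    opt.step()
    sched.step()

print(sched.get_last_lr())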

@@ -1,5 +1,10 @@
 from typing import Any, Dict
 from pydantic import BaseModel, ConfigDict
+from torch import optim
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from .base import BaseScheduler
+


 class CosineAnnealingLRParams(BaseModel):
@@ -8,7 +13,24 @@ class CosineAnnealingLRParams(BaseModel):
     T_max: int = 100      # Maximum number of iterations
     eta_min: float = 0.0  # Minimum learning rate
+    last_epoch: int = -1

     def asdict(self) -> Dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`."""
         return self.model_dump()
+
+
+class CosineAnnealingLRScheduler(BaseScheduler):
+    """
+    Wrapper around torch.optim.lr_scheduler.CosineAnnealingLR.
+    """
+    def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams):
+        """
+        Args:
+            optimizer (Optimizer): Wrapped optimizer.
+            params (CosineAnnealingLRParams): Scheduler parameters.
+        """
+        super().__init__(optimizer, params)
+
+        self.scheduler = CosineAnnealingLR(optimizer, **params.asdict())

@@ -1,5 +1,9 @@
 from typing import Any, Dict
 from pydantic import BaseModel, ConfigDict
+from torch import optim
+from torch.optim.lr_scheduler import ExponentialLR
+
+from .base import BaseScheduler


 class ExponentialLRParams(BaseModel):
@@ -7,7 +11,23 @@ class ExponentialLRParams(BaseModel):
     model_config = ConfigDict(frozen=True)

     gamma: float = 0.95  # Multiplicative factor of learning rate decay
+    last_epoch: int = -1

     def asdict(self) -> Dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`."""
         return self.model_dump()
+
+
+class ExponentialLRScheduler(BaseScheduler):
+    """
+    Wrapper around torch.optim.lr_scheduler.ExponentialLR.
+    """
+    def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams):
+        """
+        Args:
+            optimizer (Optimizer): Wrapped optimizer.
+            params (ExponentialLRParams): Scheduler parameters.
+        """
+        super().__init__(optimizer, params)
+
+        self.scheduler = ExponentialLR(optimizer, **params.asdict())

@@ -1,5 +1,9 @@
 from typing import Any, Dict, Tuple
 from pydantic import BaseModel, ConfigDict
+from torch import optim
+from torch.optim.lr_scheduler import MultiStepLR
+
+from .base import BaseScheduler


 class MultiStepLRParams(BaseModel):
@@ -8,7 +12,23 @@ class MultiStepLRParams(BaseModel):
     milestones: Tuple[int, ...] = (30, 80)  # List of epoch indices for LR decay
     gamma: float = 0.1                      # Multiplicative factor of learning rate decay
+    last_epoch: int = -1

     def asdict(self) -> Dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`."""
         return self.model_dump()
+
+
+class MultiStepLRScheduler(BaseScheduler):
+    """
+    Wrapper around torch.optim.lr_scheduler.MultiStepLR.
+    """
+    def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams):
+        """
+        Args:
+            optimizer (Optimizer): Wrapped optimizer.
+            params (MultiStepLRParams): Scheduler parameters.
+        """
+        super().__init__(optimizer, params)
+
+        self.scheduler = MultiStepLR(optimizer, **params.asdict())

@@ -1,5 +1,9 @@
 from typing import Any, Dict
 from pydantic import BaseModel, ConfigDict
+from torch import optim
+from torch.optim.lr_scheduler import StepLR
+
+from .base import BaseScheduler


 class StepLRParams(BaseModel):
@@ -8,7 +12,24 @@ class StepLRParams(BaseModel):
     step_size: int = 30  # Period of learning rate decay
     gamma: float = 0.1   # Multiplicative factor of learning rate decay
+    last_epoch: int = -1

     def asdict(self) -> Dict[str, Any]:
         """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`."""
         return self.model_dump()
+
+
+class StepLRScheduler(BaseScheduler):
+    """
+    Wrapper around torch.optim.lr_scheduler.StepLR.
+    """
+    def __init__(self, optimizer: optim.Optimizer, params: StepLRParams):
+        """
+        Args:
+            optimizer (Optimizer): Wrapped optimizer.
+            params (StepLRParams): Scheduler parameters.
+        """
+        super().__init__(optimizer, params)
+
+        self.scheduler = StepLR(optimizer, **params.asdict())

@@ -1,8 +1,7 @@
 import os
-from pydantic import BaseModel
-from typing import Any, Dict, Tuple, Type, Union, List
+from typing import Tuple

-from config.config import Config
+from config.config import *
 from config.dataset_config import DatasetConfig

 from core import (
@@ -10,18 +9,6 @@ from core import (
 )

-
-def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]:
-    """
-    Instantiates the parameter class(es) with default values.
-    If 'param' is a tuple, instantiate each class and return a list of instances.
-    Otherwise, instantiate the single class and return the instance.
-    """
-    if isinstance(param, tuple):
-        return [cls() for cls in param]
-    else:
-        return param()

 def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str:
     """
     Prompt the user with a list of options and return the selected option.
@@ -55,11 +42,11 @@ def main():
     model_options = ModelRegistry.get_available_models()
     chosen_model = prompt_choice("\nSelect a model:", model_options)
     model_param_class = ModelRegistry.get_model_params(chosen_model)
-    model_instance = instantiate_params(model_param_class)
+    model_instance = model_param_class()

     if is_training is False:
         config = Config(
-            model={chosen_model: model_instance},
+            model=ComponentConfig(name=chosen_model, params=model_instance),
             dataset_config=dataset_config
         )
@@ -71,27 +58,27 @@ def main():
     criterion_options = CriterionRegistry.get_available_criterions()
     chosen_criterion = prompt_choice("\nSelect a criterion:", criterion_options)
     criterion_param_class = CriterionRegistry.get_criterion_params(chosen_criterion)
-    criterion_instance = instantiate_params(criterion_param_class)
+    criterion_instance = criterion_param_class()

     # Prompt the user to select an optimizer.
     optimizer_options = OptimizerRegistry.get_available_optimizers()
     chosen_optimizer = prompt_choice("\nSelect an optimizer:", optimizer_options)
     optimizer_param_class = OptimizerRegistry.get_optimizer_params(chosen_optimizer)
-    optimizer_instance = instantiate_params(optimizer_param_class)
+    optimizer_instance = optimizer_param_class()

     # Prompt the user to select a scheduler.
     scheduler_options = SchedulerRegistry.get_available_schedulers()
     chosen_scheduler = prompt_choice("\nSelect a scheduler:", scheduler_options)
     scheduler_param_class = SchedulerRegistry.get_scheduler_params(chosen_scheduler)
-    scheduler_instance = instantiate_params(scheduler_param_class)
+    scheduler_instance = scheduler_param_class()

     # Assemble the overall configuration using the registry names as keys.
     config = Config(
-        model={chosen_model: model_instance},
+        model=ComponentConfig(name=chosen_model, params=model_instance),
         dataset_config=dataset_config,
-        criterion={chosen_criterion: criterion_instance},
-        optimizer={chosen_optimizer: optimizer_instance},
-        scheduler={chosen_scheduler: scheduler_instance}
+        criterion=ComponentConfig(name=chosen_criterion, params=criterion_instance),
+        optimizer=ComponentConfig(name=chosen_optimizer, params=optimizer_instance),
+        scheduler=ComponentConfig(name=chosen_scheduler, params=scheduler_instance)
     )

     # Construct a base filename from the selected registry names.

@@ -8,7 +8,3 @@ pprint(config, indent=4)
 print('\n\n')
 config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV.json')
 pprint(config, indent=4)
-print('\n\n')
-
-config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV_1.json')
-pprint(config, indent=4)