From 5db22209170aa7793ac65de90fc67880f1f7e709 Mon Sep 17 00:00:00 2001 From: laynholt Date: Wed, 16 Apr 2025 12:40:36 +0000 Subject: [PATCH] 1) Removed support for parameter lists; 2) added wrappers for optimizer and schedule; 3) changed parameters for losses. --- config/config.py | 76 +++++++++++++++++------------ config/dataset_config.py | 4 ++ core/losses/__init__.py | 12 ++--- core/losses/base.py | 5 +- core/losses/bce.py | 20 ++++---- core/losses/ce.py | 8 +-- core/losses/mse.py | 8 +-- core/losses/mse_with_bce.py | 53 +++++++++++++------- core/models/__init__.py | 4 +- core/optimizers/__init__.py | 27 +++++----- core/optimizers/adam.py | 24 ++++++++- core/optimizers/adamw.py | 25 +++++++++- core/optimizers/base.py | 40 +++++++++++++++ core/optimizers/sgd.py | 25 +++++++++- core/schedulers/__init__.py | 30 ++++++------ core/schedulers/base.py | 27 ++++++++++ core/schedulers/cosine_annealing.py | 22 +++++++++ core/schedulers/exponential.py | 22 ++++++++- core/schedulers/multi_step.py | 22 ++++++++- core/schedulers/step.py | 23 ++++++++- generate_config.py | 35 +++++-------- train.py | 4 -- 22 files changed, 374 insertions(+), 142 deletions(-) create mode 100644 core/optimizers/base.py create mode 100644 core/schedulers/base.py diff --git a/config/config.py b/config/config.py index 6971ea6..6f8b4e0 100644 --- a/config/config.py +++ b/config/config.py @@ -1,29 +1,41 @@ import json -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional from pydantic import BaseModel from .dataset_config import DatasetConfig -class Config(BaseModel): - model: Dict[str, Union[BaseModel, List[BaseModel]]] - dataset_config: DatasetConfig - criterion: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None - optimizer: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None - scheduler: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None - @staticmethod - def __dump_field(value: Any) -> Any: +__all__ = ["Config", "ComponentConfig"] + + +class ComponentConfig(BaseModel): + name: str + params: BaseModel + + def dump(self) -> Dict[str, Any]: """ - Recursively dumps a field if it is a BaseModel or a list/dict of BaseModels. + Recursively serializes the component into a dictionary. + + Returns: + dict: A dictionary containing the component name and its serialized parameters. """ - if isinstance(value, BaseModel): - return value.model_dump() - elif isinstance(value, list): - return [Config.__dump_field(item) for item in value] - elif isinstance(value, dict): - return {k: Config.__dump_field(v) for k, v in value.items()} + if isinstance(self.params, BaseModel): + params_dump = self.params.model_dump() else: - return value + params_dump = self.params + return { + "name": self.name, + "params": params_dump + } + + + +class Config(BaseModel): + model: ComponentConfig + dataset_config: DatasetConfig + criterion: Optional[ComponentConfig] = None + optimizer: Optional[ComponentConfig] = None + scheduler: Optional[ComponentConfig] = None def save_json(self, file_path: str, indent: int = 4) -> None: """ @@ -34,15 +46,15 @@ class Config(BaseModel): indent (int): Indentation level for the JSON file. 
""" config_dump = { - "model": self.__dump_field(self.model), + "model": self.model.dump(), "dataset_config": self.dataset_config.model_dump() } if self.criterion is not None: - config_dump.update({"criterion": self.__dump_field(self.criterion)}) + config_dump["criterion"] = self.criterion.dump() if self.optimizer is not None: - config_dump.update({"optimizer": self.__dump_field(self.optimizer)}) + config_dump["optimizer"] = self.optimizer.dump() if self.scheduler is not None: - config_dump.update({"scheduler": self.__dump_field(self.scheduler)}) + config_dump["scheduler"] = self.scheduler.dump() with open(file_path, "w", encoding="utf-8") as f: f.write(json.dumps(config_dump, indent=indent)) @@ -67,16 +79,15 @@ class Config(BaseModel): dataset_config = DatasetConfig(**data.get("dataset_config", {})) # Helper function to parse registry fields. - def parse_field(field_data: Dict[str, Any], registry_getter) -> Dict[str, Union[BaseModel, List[BaseModel]]]: - result = {} - for key, value in field_data.items(): - expected = registry_getter(key) - # If the registry returns a tuple, then we expect a list of dictionaries. - if isinstance(expected, tuple): - result[key] = [cls_param(**item) for cls_param, item in zip(expected, value)] - else: - result[key] = expected(**value) - return result + def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]: + name = component_data.get("name") + params_data = component_data.get("params", {}) + + if name is not None: + expected = registry_getter(name) + params = expected(**params_data) + return ComponentConfig(name=name, params=params) + return None from core import ( ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry @@ -87,6 +98,9 @@ class Config(BaseModel): parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key)) parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key)) + if parsed_model is None: + raise ValueError('Failed to load model information') + return cls( model=parsed_model, dataset_config=dataset_config, diff --git a/config/dataset_config.py b/config/dataset_config.py index 8ef5901..393d313 100644 --- a/config/dataset_config.py +++ b/config/dataset_config.py @@ -98,6 +98,10 @@ class DatasetTrainingConfig(BaseModel): - If is_split is True, validates pre_split (train_dir must be non-empty and exist; if provided, valid_dir and test_dir must exist). - If is_split is False, validates split (all_data_dir must be non-empty and exist). 
""" + if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)): + if (self.train_size + self.valid_size + self.test_size) > 1: + raise ValueError("The total sample size with dynamically defined sizes must be <= 1") + if not self.is_split: if not self.split.all_data_dir: raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in pre_split") diff --git a/core/losses/__init__.py b/core/losses/__init__.py index d8e1cdc..203d9ba 100644 --- a/core/losses/__init__.py +++ b/core/losses/__init__.py @@ -5,12 +5,12 @@ from .base import BaseLoss from .ce import CrossEntropyLoss, CrossEntropyLossParams from .bce import BCELoss, BCELossParams from .mse import MSELoss, MSELossParams -from .mse_with_bce import BCE_MSE_Loss +from .mse_with_bce import BCE_MSE_Loss, BCE_MSE_LossParams __all__ = [ - "CriterionRegistry", + "CriterionRegistry", "BaseLoss", "CrossEntropyLoss", "BCELoss", "MSELoss", "BCE_MSE_Loss", - "CrossEntropyLossParams", "BCELossParams", "MSELossParams" + "CrossEntropyLossParams", "BCELossParams", "MSELossParams", "BCE_MSE_LossParams" ] class CriterionRegistry: @@ -31,7 +31,7 @@ class CriterionRegistry: }, "BCE_MSE_Loss": { "class": BCE_MSE_Loss, - "params": (BCELossParams, MSELossParams), + "params": BCE_MSE_LossParams, }, } @@ -73,7 +73,7 @@ class CriterionRegistry: return entry["class"] @classmethod - def get_criterion_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]: + def get_criterion_params(cls, name: str) -> Type[BaseModel]: """ Retrieves the loss function parameter class (or classes) by name (case-insensitive). @@ -81,7 +81,7 @@ class CriterionRegistry: name (str): Name of the loss function. Returns: - Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The loss function parameter class or a tuple of parameter classes. + Type[BaseModel]: The loss function parameter class. """ entry = cls.__get_entry(name) return entry["params"] diff --git a/core/losses/base.py b/core/losses/base.py index 4d8196f..65ddbad 100644 --- a/core/losses/base.py +++ b/core/losses/base.py @@ -1,6 +1,7 @@ import abc import torch import torch.nn as nn +from pydantic import BaseModel from typing import Dict, Any, Optional from monai.metrics.cumulative_average import CumulativeAverage @@ -8,9 +9,7 @@ from monai.metrics.cumulative_average import CumulativeAverage class BaseLoss(abc.ABC): """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction.""" - def __init__(self): - """ - """ + def __init__(self, params: Optional[BaseModel] = None): super().__init__() diff --git a/core/losses/bce.py b/core/losses/bce.py index 451fb43..32b2a96 100644 --- a/core/losses/bce.py +++ b/core/losses/bce.py @@ -9,27 +9,26 @@ class BCELossParams(BaseModel): """ model_config = ConfigDict(frozen=True) + with_logits: bool = False + weight: Optional[List[Union[int, float]]] = None # Sample weights reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method pos_weight: Optional[List[Union[int, float]]] = None # Used only for BCEWithLogitsLoss - def asdict(self, with_logits: bool = False) -> Dict[str, Any]: + def asdict(self) -> Dict[str, Any]: """ Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`. - If `with_logits=False`, `pos_weight` is **removed** to avoid errors. - Ensures only the valid parameters are passed based on the loss function. - Args: - with_logits (bool): If `True`, includes `pos_weight` (for `nn.BCEWithLogitsLoss`). 
- If `False`, removes `pos_weight` (for `nn.BCELoss`). - Returns: Dict[str, Any]: Filtered dictionary of parameters. """ loss_kwargs = self.model_dump() - if not with_logits: + if not self.with_logits: loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss + loss_kwargs.pop("with_logits", None) weight = loss_kwargs.get("weight") pos_weight = loss_kwargs.get("pos_weight") @@ -48,15 +47,16 @@ class BCELoss(BaseLoss): Custom loss function wrapper for `nn.BCELoss and nn.BCEWithLogitsLoss` with tracking of loss metrics. """ - def __init__(self, bce_params: Optional[BCELossParams] = None, with_logits: bool = False): + def __init__(self, params: Optional[BCELossParams] = None): """ Initializes the loss function with optional BCELoss parameters. Args: - bce_params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None). + params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None). """ - super().__init__() - _bce_params = bce_params.asdict(with_logits=with_logits) if bce_params is not None else {} + super().__init__(params=params) + with_logits = params.with_logits if params is not None else False + _bce_params = params.asdict() if params is not None else {} # Initialize loss functions with user-provided parameters or PyTorch defaults self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params) diff --git a/core/losses/ce.py b/core/losses/ce.py index 00e641b..f1e2db7 100644 --- a/core/losses/ce.py +++ b/core/losses/ce.py @@ -36,15 +36,15 @@ class CrossEntropyLoss(BaseLoss): Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics. """ - def __init__(self, ce_params: Optional[CrossEntropyLossParams] = None): + def __init__(self, params: Optional[CrossEntropyLossParams] = None): """ Initializes the loss function with optional CrossEntropyLoss parameters. Args: - ce_params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None). + params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None). """ - super().__init__() - _ce_params = ce_params.asdict() if ce_params is not None else {} + super().__init__(params=params) + _ce_params = params.asdict() if params is not None else {} # Initialize loss functions with user-provided parameters or PyTorch defaults self.ce_loss = nn.CrossEntropyLoss(**_ce_params) diff --git a/core/losses/mse.py b/core/losses/mse.py index df9fd13..2b5ec98 100644 --- a/core/losses/mse.py +++ b/core/losses/mse.py @@ -27,15 +27,15 @@ class MSELoss(BaseLoss): Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics. """ - def __init__(self, mse_params: Optional[MSELossParams] = None): + def __init__(self, params: Optional[MSELossParams] = None): """ Initializes the loss function with optional MSELoss parameters. Args: - mse_params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None). + params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None). 
""" - super().__init__() - _mse_params = mse_params.asdict() if mse_params is not None else {} + super().__init__(params=params) + _mse_params = params.asdict() if params is not None else {} # Initialize MSE loss with user-provided parameters or PyTorch defaults self.mse_loss = nn.MSELoss(**_mse_params) diff --git a/core/losses/mse_with_bce.py b/core/losses/mse_with_bce.py index f36cf6f..90ecb44 100644 --- a/core/losses/mse_with_bce.py +++ b/core/losses/mse_with_bce.py @@ -2,42 +2,59 @@ from .base import * from .bce import BCELossParams from .mse import MSELossParams +from pydantic import BaseModel, ConfigDict + + +class BCE_MSE_LossParams(BaseModel): + """ + Class for handling parameters for `nn.MSELoss` with `nn.BCELoss`. + """ + model_config = ConfigDict(frozen=True) + + num_classes: int = 1 + bce_params: BCELossParams = BCELossParams() + mse_params: MSELossParams = MSELossParams() + + def asdict(self) -> Dict[str, Any]: + """ + Returns a dictionary of valid parameters for `nn.BCELoss` and `nn.MSELoss`. + + Returns: + Dict[str, Any]: Dictionary of parameters. + """ + + return { + "num_classes": self.num_classes, + "bce_params": self.bce_params.asdict(), + "mse_params": self.mse_params.asdict() + } + class BCE_MSE_Loss(BaseLoss): """ Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction. """ - def __init__( - self, - num_classes: int, - bce_params: Optional[BCELossParams] = None, - mse_params: Optional[MSELossParams] = None, - bce_with_logits: bool = False, - ): + def __init__(self, params: Optional[BCE_MSE_LossParams]): """ Initializes the loss function with optional BCE and MSE parameters. - - Args: - num_classes (int): Number of output classes, used for target shifting. - bce_params (Optional[BCELossParams]): Parameters for BCEWithLogitsLoss or BCELoss (default: None). - mse_params (Optional[MSELossParams]): Parameters for MSELoss (default: None). - bce_with_logits (bool): If True, uses BCEWithLogitsLoss; otherwise, uses BCELoss. """ - super().__init__() + super().__init__(params=params) + + _params = params if params is not None else BCE_MSE_LossParams() - self.num_classes = num_classes + self.num_classes = _params.num_classes # Process BCE parameters - _bce_params = bce_params.asdict(bce_with_logits) if bce_params is not None else {} + _bce_params = _params.bce_params.asdict() # Choose BCE loss function self.bce_loss = ( - nn.BCEWithLogitsLoss(**_bce_params) if bce_with_logits else nn.BCELoss(**_bce_params) + nn.BCEWithLogitsLoss(**_bce_params) if _params.bce_params.with_logits else nn.BCELoss(**_bce_params) ) # Process MSE parameters - _mse_params = mse_params.asdict() if mse_params is not None else {} + _mse_params = _params.mse_params.asdict() # Initialize MSE loss self.mse_loss = nn.MSELoss(**_mse_params) diff --git a/core/models/__init__.py b/core/models/__init__.py index 75ba2a8..c39baa5 100644 --- a/core/models/__init__.py +++ b/core/models/__init__.py @@ -61,7 +61,7 @@ class ModelRegistry: return entry["class"] @classmethod - def get_model_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]: + def get_model_params(cls, name: str) -> Type[BaseModel]: """ Retrieves the model parameter class by name (case-insensitive). @@ -69,7 +69,7 @@ class ModelRegistry: name (str): Name of the model. Returns: - Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The model parameter class or a tuple of parameter classes. + Type[BaseModel]: The model parameter class. 
""" entry = cls.__get_entry(name) return entry["params"] diff --git a/core/optimizers/__init__.py b/core/optimizers/__init__.py index 2841b3c..f3a80d0 100644 --- a/core/optimizers/__init__.py +++ b/core/optimizers/__init__.py @@ -1,14 +1,15 @@ -import torch.optim as optim from pydantic import BaseModel from typing import Dict, Final, Tuple, Type, List, Any, Union -from .adam import AdamParams -from .adamw import AdamWParams -from .sgd import SGDParams +from .base import BaseOptimizer +from .adam import AdamParams, AdamOptimizer +from .adamw import AdamWParams, AdamWOptimizer +from .sgd import SGDParams, SGDOptimizer __all__ = [ - "OptimizerRegistry", - "AdamParams", "AdamWParams", "SGDParams" + "OptimizerRegistry", "BaseOptimizer", + "AdamParams", "AdamWParams", "SGDParams", + "AdamOptimizer", "AdamWOptimizer", "SGDOptimizer" ] class OptimizerRegistry: @@ -17,15 +18,15 @@ class OptimizerRegistry: # Single dictionary storing both optimizer classes and parameter classes. __OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = { "SGD": { - "class": optim.SGD, + "class": SGDOptimizer, "params": SGDParams, }, "Adam": { - "class": optim.Adam, + "class": AdamOptimizer, "params": AdamParams, }, "AdamW": { - "class": optim.AdamW, + "class": AdamWOptimizer, "params": AdamWParams, }, } @@ -54,7 +55,7 @@ class OptimizerRegistry: return cls.__OPTIMIZERS[original_key] @classmethod - def get_optimizer_class(cls, name: str) -> Type[optim.Optimizer]: + def get_optimizer_class(cls, name: str) -> Type[BaseOptimizer]: """ Retrieves the optimizer class by name (case-insensitive). @@ -62,13 +63,13 @@ class OptimizerRegistry: name (str): Name of the optimizer. Returns: - Type[optim.Optimizer]: The optimizer class. + Type[BaseOptimizer]: The optimizer class. """ entry = cls.__get_entry(name) return entry["class"] @classmethod - def get_optimizer_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]: + def get_optimizer_params(cls, name: str) -> Type[BaseModel]: """ Retrieves the optimizer parameter class by name (case-insensitive). @@ -76,7 +77,7 @@ class OptimizerRegistry: name (str): Name of the optimizer. Returns: - Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The optimizer parameter class or a tuple of parameter classes. + Type[BaseModel]: The optimizer parameter class. """ entry = cls.__get_entry(name) return entry["params"] diff --git a/core/optimizers/adam.py b/core/optimizers/adam.py index 1919c77..7ebe157 100644 --- a/core/optimizers/adam.py +++ b/core/optimizers/adam.py @@ -1,6 +1,9 @@ -from typing import Any, Dict, Tuple +import torch +from torch import optim +from typing import Any, Dict, Iterable, Optional, Tuple from pydantic import BaseModel, ConfigDict +from .base import BaseOptimizer class AdamParams(BaseModel): """Configuration for `torch.optim.Adam` optimizer.""" @@ -14,4 +17,21 @@ class AdamParams(BaseModel): def asdict(self) -> Dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.Adam`.""" - return self.model_dump() \ No newline at end of file + return self.model_dump() + + +class AdamOptimizer(BaseOptimizer): + """ + Wrapper around torch.optim.Adam. + """ + + def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams): + """ + Initializes the Adam optimizer with given parameters. + + Args: + model_params (Iterable[Parameter]): Parameters to optimize. + optim_params (AdamParams): Optimizer parameters. 
+ """ + super().__init__(model_params, optim_params) + self.optim = optim.Adam(model_params, **optim_params.asdict()) \ No newline at end of file diff --git a/core/optimizers/adamw.py b/core/optimizers/adamw.py index 755e3c0..f2d7ddb 100644 --- a/core/optimizers/adamw.py +++ b/core/optimizers/adamw.py @@ -1,6 +1,10 @@ -from typing import Any, Dict, Tuple +import torch +from torch import optim +from typing import Any, Dict, Iterable, Optional, Tuple from pydantic import BaseModel, ConfigDict +from .base import BaseOptimizer + class AdamWParams(BaseModel): """Configuration for `torch.optim.AdamW` optimizer.""" model_config = ConfigDict(frozen=True) @@ -13,4 +17,21 @@ class AdamWParams(BaseModel): def asdict(self) -> Dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.AdamW`.""" - return self.model_dump() \ No newline at end of file + return self.model_dump() + + +class AdamWOptimizer(BaseOptimizer): + """ + Wrapper around torch.optim.AdamW. + """ + + def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams): + """ + Initializes the AdamW optimizer with given parameters. + + Args: + model_params (Iterable[Parameter]): Parameters to optimize. + optim_params (AdamWParams): Optimizer parameters. + """ + super().__init__(model_params, optim_params) + self.optim = optim.AdamW(model_params, **optim_params.asdict()) \ No newline at end of file diff --git a/core/optimizers/base.py b/core/optimizers/base.py new file mode 100644 index 0000000..32cef7c --- /dev/null +++ b/core/optimizers/base.py @@ -0,0 +1,40 @@ +import torch +import torch.optim as optim +from pydantic import BaseModel +from typing import Any, Iterable, Optional + + +class BaseOptimizer: + """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction.""" + + def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel): + super().__init__() + self.optim: Optional[optim.Optimizer] = None + + + def zero_grad(self, set_to_none: bool = True) -> None: + """ + Clears the gradients of all optimized tensors. + + Args: + set_to_none (bool): If True, sets gradients to None instead of zero. + This can reduce memory usage and improve performance. + (Introduced in PyTorch 1.7+) + """ + if self.optim is not None: + self.optim.zero_grad(set_to_none=set_to_none) + + + def step(self, closure: Optional[Any] = None) -> Any: + """ + Performs a single optimization step (parameter update). + + Args: + closure (Optional[Callable]): A closure that reevaluates the model and returns the loss. + This is required for optimizers like LBFGS that need multiple forward passes. + + Returns: + Any: The return value depends on the specific optimizer implementation. 
+ """ + if self.optim is not None: + return self.optim.step(closure=closure) \ No newline at end of file diff --git a/core/optimizers/sgd.py b/core/optimizers/sgd.py index 0de3cc3..3870b9e 100644 --- a/core/optimizers/sgd.py +++ b/core/optimizers/sgd.py @@ -1,6 +1,10 @@ -from typing import Any, Dict +import torch +from torch import optim +from typing import Any, Dict, Iterable, Optional from pydantic import BaseModel, ConfigDict +from .base import BaseOptimizer + class SGDParams(BaseModel): """Configuration for `torch.optim.SGD` optimizer.""" @@ -14,4 +18,21 @@ class SGDParams(BaseModel): def asdict(self) -> Dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.SGD`.""" - return self.model_dump() \ No newline at end of file + return self.model_dump() + + +class SGDOptimizer(BaseOptimizer): + """ + Wrapper around torch.optim.SGD. + """ + + def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams): + """ + Initializes the SGD optimizer with given parameters. + + Args: + model_params (Iterable[Parameter]): Parameters to optimize. + optim_params (SGDParams): Optimizer parameters. + """ + super().__init__(model_params, optim_params) + self.optim = optim.SGD(model_params, **optim_params.asdict()) \ No newline at end of file diff --git a/core/schedulers/__init__.py b/core/schedulers/__init__.py index 2e92f1d..f2ff6ca 100644 --- a/core/schedulers/__init__.py +++ b/core/schedulers/__init__.py @@ -2,14 +2,16 @@ import torch.optim.lr_scheduler as lr_scheduler from typing import Dict, Final, Tuple, Type, List, Any, Union from pydantic import BaseModel -from .step import StepLRParams -from .multi_step import MultiStepLRParams -from .exponential import ExponentialLRParams -from .cosine_annealing import CosineAnnealingLRParams +from .base import BaseScheduler +from .step import StepLRParams, StepLRScheduler +from .multi_step import MultiStepLRParams, MultiStepLRScheduler +from .exponential import ExponentialLRParams, ExponentialLRScheduler +from .cosine_annealing import CosineAnnealingLRParams, CosineAnnealingLRScheduler __all__ = [ - "SchedulerRegistry", - "StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams" + "SchedulerRegistry", "BaseScheduler", + "StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams", + "StepLRScheduler", "MultiStepLRScheduler", "ExponentialLRScheduler", "CosineAnnealingLRScheduler" ] class SchedulerRegistry: @@ -17,19 +19,19 @@ class SchedulerRegistry: __SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = { "Step": { - "class": lr_scheduler.StepLR, + "class": StepLRScheduler, "params": StepLRParams, }, "Exponential": { - "class": lr_scheduler.ExponentialLR, + "class": ExponentialLRScheduler, "params": ExponentialLRParams, }, "MultiStep": { - "class": lr_scheduler.MultiStepLR, + "class": MultiStepLRScheduler, "params": MultiStepLRParams, }, "CosineAnnealing": { - "class": lr_scheduler.CosineAnnealingLR, + "class": CosineAnnealingLRScheduler, "params": CosineAnnealingLRParams, }, } @@ -58,7 +60,7 @@ class SchedulerRegistry: return cls.__SCHEDULERS[original_key] @classmethod - def get_scheduler_class(cls, name: str) -> Type[lr_scheduler.LRScheduler]: + def get_scheduler_class(cls, name: str) -> Type[BaseScheduler]: """ Retrieves the scheduler class by name (case-insensitive). @@ -66,13 +68,13 @@ class SchedulerRegistry: name (str): Name of the scheduler. Returns: - Type[lr_scheduler.LRScheduler]: The scheduler class. + Type[BaseScheduler]: The scheduler class. 
""" entry = cls.__get_entry(name) return entry["class"] @classmethod - def get_scheduler_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]: + def get_scheduler_params(cls, name: str) -> Type[BaseModel]: """ Retrieves the scheduler parameter class by name (case-insensitive). @@ -80,7 +82,7 @@ class SchedulerRegistry: name (str): Name of the scheduler. Returns: - Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The scheduler parameter class or a tuple of parameter classes. + Type[BaseModel]: The scheduler parameter class. """ entry = cls.__get_entry(name) return entry["params"] diff --git a/core/schedulers/base.py b/core/schedulers/base.py new file mode 100644 index 0000000..892f7fd --- /dev/null +++ b/core/schedulers/base.py @@ -0,0 +1,27 @@ +import torch.optim as optim +from pydantic import BaseModel +from typing import List, Optional + + +class BaseScheduler: + """ + Abstract base class for learning rate schedulers. + Wraps a PyTorch LR scheduler and provides a unified interface. + """ + + def __init__(self, optimizer: optim.Optimizer, params: BaseModel): + self.scheduler: Optional[optim.lr_scheduler.LRScheduler] = None + + def step(self) -> None: + """ + Performs a single scheduler step. This typically updates the learning rate + based on the current epoch or step count. + """ + if self.scheduler is not None: + self.scheduler.step() + + def get_last_lr(self) -> List[float]: + """ + Returns the most recent learning rate(s). + """ + return self.scheduler.get_last_lr() if self.scheduler else [] diff --git a/core/schedulers/cosine_annealing.py b/core/schedulers/cosine_annealing.py index c086ffe..f0bba2b 100644 --- a/core/schedulers/cosine_annealing.py +++ b/core/schedulers/cosine_annealing.py @@ -1,5 +1,10 @@ from typing import Any, Dict from pydantic import BaseModel, ConfigDict +from torch import optim +from torch.optim.lr_scheduler import CosineAnnealingLR + +from .base import BaseScheduler + class CosineAnnealingLRParams(BaseModel): @@ -8,7 +13,24 @@ class CosineAnnealingLRParams(BaseModel): T_max: int = 100 # Maximum number of iterations eta_min: float = 0.0 # Minimum learning rate + last_epoch: int = -1 + def asdict(self) -> Dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`.""" return self.model_dump() + + +class CosineAnnealingLRScheduler(BaseScheduler): + """ + Wrapper around torch.optim.lr_scheduler.CosineAnnealingLR. + """ + + def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams): + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + params (CosineAnnealingLRParams): Scheduler parameters. 
+ """ + super().__init__(optimizer, params) + self.scheduler = CosineAnnealingLR(optimizer, **params.asdict()) \ No newline at end of file diff --git a/core/schedulers/exponential.py b/core/schedulers/exponential.py index d708a7e..807aed7 100644 --- a/core/schedulers/exponential.py +++ b/core/schedulers/exponential.py @@ -1,5 +1,9 @@ from typing import Any, Dict from pydantic import BaseModel, ConfigDict +from torch import optim +from torch.optim.lr_scheduler import ExponentialLR + +from .base import BaseScheduler class ExponentialLRParams(BaseModel): @@ -7,7 +11,23 @@ class ExponentialLRParams(BaseModel): model_config = ConfigDict(frozen=True) gamma: float = 0.95 # Multiplicative factor of learning rate decay + last_epoch: int = -1 def asdict(self) -> Dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`.""" - return self.model_dump() \ No newline at end of file + return self.model_dump() + + +class ExponentialLRScheduler(BaseScheduler): + """ + Wrapper around torch.optim.lr_scheduler.ExponentialLR. + """ + + def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams): + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + params (ExponentialLRParams): Scheduler parameters. + """ + super().__init__(optimizer, params) + self.scheduler = ExponentialLR(optimizer, **params.asdict()) \ No newline at end of file diff --git a/core/schedulers/multi_step.py b/core/schedulers/multi_step.py index 1b02788..99d3858 100644 --- a/core/schedulers/multi_step.py +++ b/core/schedulers/multi_step.py @@ -1,5 +1,9 @@ from typing import Any, Dict, Tuple from pydantic import BaseModel, ConfigDict +from torch import optim +from torch.optim.lr_scheduler import MultiStepLR + +from .base import BaseScheduler class MultiStepLRParams(BaseModel): @@ -8,7 +12,23 @@ class MultiStepLRParams(BaseModel): milestones: Tuple[int, ...] = (30, 80) # List of epoch indices for LR decay gamma: float = 0.1 # Multiplicative factor of learning rate decay + last_epoch: int = -1 def asdict(self) -> Dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`.""" - return self.model_dump() \ No newline at end of file + return self.model_dump() + + +class MultiStepLRScheduler(BaseScheduler): + """ + Wrapper around torch.optim.lr_scheduler.MultiStepLR. + """ + + def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams): + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + params (MultiStepLRParams): Scheduler parameters. 
+ """ + super().__init__(optimizer, params) + self.scheduler = MultiStepLR(optimizer, **params.asdict()) \ No newline at end of file diff --git a/core/schedulers/step.py b/core/schedulers/step.py index 1bf4c56..0b9b609 100644 --- a/core/schedulers/step.py +++ b/core/schedulers/step.py @@ -1,5 +1,9 @@ from typing import Any, Dict from pydantic import BaseModel, ConfigDict +from torch import optim +from torch.optim.lr_scheduler import StepLR + +from .base import BaseScheduler class StepLRParams(BaseModel): @@ -8,7 +12,24 @@ class StepLRParams(BaseModel): step_size: int = 30 # Period of learning rate decay gamma: float = 0.1 # Multiplicative factor of learning rate decay + last_epoch: int = -1 def asdict(self) -> Dict[str, Any]: """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`.""" - return self.model_dump() \ No newline at end of file + return self.model_dump() + + + +class StepLRScheduler(BaseScheduler): + """ + Wrapper around torch.optim.lr_scheduler.StepLR. + """ + + def __init__(self, optimizer: optim.Optimizer, params: StepLRParams): + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + params (StepLRParams): Scheduler parameters. + """ + super().__init__(optimizer, params) + self.scheduler = StepLR(optimizer, **params.asdict()) \ No newline at end of file diff --git a/generate_config.py b/generate_config.py index 9fa783d..aa5dd7a 100644 --- a/generate_config.py +++ b/generate_config.py @@ -1,8 +1,7 @@ import os -from pydantic import BaseModel -from typing import Any, Dict, Tuple, Type, Union, List +from typing import Tuple -from config.config import Config +from config.config import * from config.dataset_config import DatasetConfig from core import ( @@ -10,18 +9,6 @@ from core import ( ) -def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]: - """ - Instantiates the parameter class(es) with default values. - - If 'param' is a tuple, instantiate each class and return a list of instances. - Otherwise, instantiate the single class and return the instance. - """ - if isinstance(param, tuple): - return [cls() for cls in param] - else: - return param() - def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str: """ Prompt the user with a list of options and return the selected option. @@ -55,11 +42,11 @@ def main(): model_options = ModelRegistry.get_available_models() chosen_model = prompt_choice("\nSelect a model:", model_options) model_param_class = ModelRegistry.get_model_params(chosen_model) - model_instance = instantiate_params(model_param_class) + model_instance = model_param_class() if is_training is False: config = Config( - model={chosen_model: model_instance}, + model=ComponentConfig(name=chosen_model, params=model_instance), dataset_config=dataset_config ) @@ -71,27 +58,27 @@ def main(): criterion_options = CriterionRegistry.get_available_criterions() chosen_criterion = prompt_choice("\nSelect a criterion:", criterion_options) criterion_param_class = CriterionRegistry.get_criterion_params(chosen_criterion) - criterion_instance = instantiate_params(criterion_param_class) + criterion_instance = criterion_param_class() # Prompt the user to select an optimizer. 
optimizer_options = OptimizerRegistry.get_available_optimizers() chosen_optimizer = prompt_choice("\nSelect an optimizer:", optimizer_options) optimizer_param_class = OptimizerRegistry.get_optimizer_params(chosen_optimizer) - optimizer_instance = instantiate_params(optimizer_param_class) + optimizer_instance = optimizer_param_class() # Prompt the user to select a scheduler. scheduler_options = SchedulerRegistry.get_available_schedulers() chosen_scheduler = prompt_choice("\nSelect a scheduler:", scheduler_options) scheduler_param_class = SchedulerRegistry.get_scheduler_params(chosen_scheduler) - scheduler_instance = instantiate_params(scheduler_param_class) + scheduler_instance = scheduler_param_class() # Assemble the overall configuration using the registry names as keys. config = Config( - model={chosen_model: model_instance}, + model=ComponentConfig(name=chosen_model, params=model_instance), dataset_config=dataset_config, - criterion={chosen_criterion: criterion_instance}, - optimizer={chosen_optimizer: optimizer_instance}, - scheduler={chosen_scheduler: scheduler_instance} + criterion=ComponentConfig(name=chosen_criterion, params=criterion_instance), + optimizer=ComponentConfig(name=chosen_optimizer, params=optimizer_instance), + scheduler=ComponentConfig(name=chosen_scheduler, params=scheduler_instance) ) # Construct a base filename from the selected registry names. diff --git a/train.py b/train.py index 1cd24f2..2ce5b9c 100644 --- a/train.py +++ b/train.py @@ -7,8 +7,4 @@ pprint(config, indent=4) print('\n\n') config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV.json') -pprint(config, indent=4) - -print('\n\n') -config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV_1.json') pprint(config, indent=4) \ No newline at end of file
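
Usage note (not part of the commit): to make the shape of the new API concrete, below is a minimal wiring sketch of how a config produced by this patch could drive training setup. The ComponentConfig fields (name/params), Config.load_json, OptimizerRegistry.get_optimizer_class, SchedulerRegistry.get_scheduler_class, the loss wrappers' params keyword, and the optimizer/scheduler wrapper methods (zero_grad, step, get_last_lr, the .optim attribute) are taken from the diff above; the JSON path, ModelRegistry.get_model_class, CriterionRegistry.get_criterion_class, and the model constructor signature are assumptions for illustration only.

    # Minimal sketch, assuming the registry getters and model constructor noted above.
    from config.config import Config
    from core import (
        ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
    )

    config = Config.load_json("config/templates/train/example.json")  # hypothetical path

    # Each section is now a single ComponentConfig(name, params) instead of a
    # dict that may hold lists of parameter models.
    model_cls = ModelRegistry.get_model_class(config.model.name)      # assumed getter
    model = model_cls(config.model.params)                            # assumed constructor shape

    criterion = None
    if config.criterion is not None:
        criterion_cls = CriterionRegistry.get_criterion_class(config.criterion.name)  # assumed getter
        criterion = criterion_cls(params=config.criterion.params)

    optimizer = None
    if config.optimizer is not None:
        optimizer_cls = OptimizerRegistry.get_optimizer_class(config.optimizer.name)
        # The wrapper constructs the underlying torch.optim instance and exposes it as .optim
        optimizer = optimizer_cls(model.parameters(), config.optimizer.params)

    scheduler = None
    if config.scheduler is not None and optimizer is not None and optimizer.optim is not None:
        scheduler_cls = SchedulerRegistry.get_scheduler_class(config.scheduler.name)
        # Scheduler wrappers take the raw torch optimizer, hence optimizer.optim
        scheduler = scheduler_cls(optimizer.optim, config.scheduler.params)

    # Inside a training step the wrappers mirror the familiar torch calls:
    #   optimizer.zero_grad(); loss.backward(); optimizer.step()
    # and once per epoch (or step, depending on the scheduler):
    #   scheduler.step(); current_lr = scheduler.get_last_lr()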