1) Removed support for parameter lists;
2) added wrappers for optimizers and schedulers;
3) changed parameters for losses.
master
laynholt 3 weeks ago
parent 78f97a72a2
commit 5db2220917

@ -1,29 +1,41 @@
import json
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional
from pydantic import BaseModel
from .dataset_config import DatasetConfig
class Config(BaseModel):
model: Dict[str, Union[BaseModel, List[BaseModel]]]
dataset_config: DatasetConfig
criterion: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
optimizer: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
scheduler: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
@staticmethod
def __dump_field(value: Any) -> Any:
__all__ = ["Config", "ComponentConfig"]
class ComponentConfig(BaseModel):
name: str
params: BaseModel
def dump(self) -> Dict[str, Any]:
"""
Recursively dumps a field if it is a BaseModel or a list/dict of BaseModels.
Recursively serializes the component into a dictionary.
Returns:
dict: A dictionary containing the component name and its serialized parameters.
"""
if isinstance(value, BaseModel):
return value.model_dump()
elif isinstance(value, list):
return [Config.__dump_field(item) for item in value]
elif isinstance(value, dict):
return {k: Config.__dump_field(v) for k, v in value.items()}
if isinstance(self.params, BaseModel):
params_dump = self.params.model_dump()
else:
return value
params_dump = self.params
return {
"name": self.name,
"params": params_dump
}
class Config(BaseModel):
model: ComponentConfig
dataset_config: DatasetConfig
criterion: Optional[ComponentConfig] = None
optimizer: Optional[ComponentConfig] = None
scheduler: Optional[ComponentConfig] = None
def save_json(self, file_path: str, indent: int = 4) -> None:
"""
@ -34,15 +46,15 @@ class Config(BaseModel):
indent (int): Indentation level for the JSON file.
"""
config_dump = {
"model": self.__dump_field(self.model),
"model": self.model.dump(),
"dataset_config": self.dataset_config.model_dump()
}
if self.criterion is not None:
config_dump.update({"criterion": self.__dump_field(self.criterion)})
config_dump["criterion"] = self.criterion.dump()
if self.optimizer is not None:
config_dump.update({"optimizer": self.__dump_field(self.optimizer)})
config_dump["optimizer"] = self.optimizer.dump()
if self.scheduler is not None:
config_dump.update({"scheduler": self.__dump_field(self.scheduler)})
config_dump["scheduler"] = self.scheduler.dump()
with open(file_path, "w", encoding="utf-8") as f:
f.write(json.dumps(config_dump, indent=indent))
@ -67,16 +79,15 @@ class Config(BaseModel):
dataset_config = DatasetConfig(**data.get("dataset_config", {}))
# Helper function to parse registry fields.
def parse_field(field_data: Dict[str, Any], registry_getter) -> Dict[str, Union[BaseModel, List[BaseModel]]]:
result = {}
for key, value in field_data.items():
expected = registry_getter(key)
# If the registry returns a tuple, then we expect a list of dictionaries.
if isinstance(expected, tuple):
result[key] = [cls_param(**item) for cls_param, item in zip(expected, value)]
else:
result[key] = expected(**value)
return result
def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]:
name = component_data.get("name")
params_data = component_data.get("params", {})
if name is not None:
expected = registry_getter(name)
params = expected(**params_data)
return ComponentConfig(name=name, params=params)
return None
from core import (
ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
@ -87,6 +98,9 @@ class Config(BaseModel):
parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key))
parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key))
if parsed_model is None:
raise ValueError('Failed to load model information')
return cls(
model=parsed_model,
dataset_config=dataset_config,
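
As an illustration of the new wrapper, a minimal sketch of building a Config from ComponentConfig objects and saving it; the "ModelV" name, the default-constructible DatasetConfig, and the package layout are assumptions not shown in this diff.

# Minimal sketch of the new ComponentConfig-based Config (assumed package layout;
# "ModelV" and a default-constructible DatasetConfig are placeholders).
from config.config import Config, ComponentConfig
from config.dataset_config import DatasetConfig
from core import ModelRegistry, OptimizerRegistry

model_params = ModelRegistry.get_model_params("ModelV")()        # param model with defaults
optim_params = OptimizerRegistry.get_optimizer_params("Adam")()  # AdamParams with defaults

cfg = Config(
    model=ComponentConfig(name="ModelV", params=model_params),
    dataset_config=DatasetConfig(),
    optimizer=ComponentConfig(name="Adam", params=optim_params),
)
cfg.save_json("example_config.json")
# save_json now writes one {"name": ..., "params": {...}} object per component, e.g.
# {"model": {"name": "ModelV", "params": {...}},
#  "dataset_config": {...},
#  "optimizer": {"name": "Adam", "params": {...}}}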

@ -98,6 +98,10 @@ class DatasetTrainingConfig(BaseModel):
- If is_split is True, validates pre_split (train_dir must be non-empty and exist; if provided, valid_dir and test_dir must exist).
- If is_split is False, validates split (all_data_dir must be non-empty and exist).
"""
if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)):
if (self.train_size + self.valid_size + self.test_size) > 1:
raise ValueError("The total sample size with dynamically defined sizes must be <= 1")
if not self.is_split:
if not self.split.all_data_dir:
raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in pre_split")

@ -5,12 +5,12 @@ from .base import BaseLoss
from .ce import CrossEntropyLoss, CrossEntropyLossParams
from .bce import BCELoss, BCELossParams
from .mse import MSELoss, MSELossParams
from .mse_with_bce import BCE_MSE_Loss
from .mse_with_bce import BCE_MSE_Loss, BCE_MSE_LossParams
__all__ = [
"CriterionRegistry",
"CriterionRegistry", "BaseLoss",
"CrossEntropyLoss", "BCELoss", "MSELoss", "BCE_MSE_Loss",
"CrossEntropyLossParams", "BCELossParams", "MSELossParams"
"CrossEntropyLossParams", "BCELossParams", "MSELossParams", "BCE_MSE_LossParams"
]
class CriterionRegistry:
@ -31,7 +31,7 @@ class CriterionRegistry:
},
"BCE_MSE_Loss": {
"class": BCE_MSE_Loss,
"params": (BCELossParams, MSELossParams),
"params": BCE_MSE_LossParams,
},
}
@ -73,7 +73,7 @@ class CriterionRegistry:
return entry["class"]
@classmethod
def get_criterion_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
def get_criterion_params(cls, name: str) -> Type[BaseModel]:
"""
Retrieves the loss function parameter class by name (case-insensitive).
@ -81,7 +81,7 @@ class CriterionRegistry:
name (str): Name of the loss function.
Returns:
Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The loss function parameter class or a tuple of parameter classes.
Type[BaseModel]: The loss function parameter class.
"""
entry = cls.__get_entry(name)
return entry["params"]
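
With tuple support removed, a registry lookup now yields exactly one parameter class per criterion. A short sketch; the `core` import path and the `get_criterion_class` accessor name are assumed from context.

# Sketch: every criterion now maps to a single parameter class.
from core import CriterionRegistry

params_cls = CriterionRegistry.get_criterion_params("BCE_MSE_Loss")    # case-insensitive lookup
criterion_cls = CriterionRegistry.get_criterion_class("BCE_MSE_Loss")  # accessor name assumed
criterion = criterion_cls(params=params_cls())  # BCE_MSE_LossParams() with defaults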

@ -1,6 +1,7 @@
import abc
import torch
import torch.nn as nn
from pydantic import BaseModel
from typing import Dict, Any, Optional
from monai.metrics.cumulative_average import CumulativeAverage
@ -8,9 +9,7 @@ from monai.metrics.cumulative_average import CumulativeAverage
class BaseLoss(abc.ABC):
"""Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
def __init__(self):
"""
"""
def __init__(self, params: Optional[BaseModel] = None):
super().__init__()

@ -9,27 +9,26 @@ class BCELossParams(BaseModel):
"""
model_config = ConfigDict(frozen=True)
with_logits: bool = False
weight: Optional[List[Union[int, float]]] = None # Sample weights
reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method
pos_weight: Optional[List[Union[int, float]]] = None # Used only for BCEWithLogitsLoss
def asdict(self, with_logits: bool = False) -> Dict[str, Any]:
def asdict(self) -> Dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`.
- If `with_logits=False`, `pos_weight` is **removed** to avoid errors.
- Ensures only the valid parameters are passed based on the loss function.
Args:
with_logits (bool): If `True`, includes `pos_weight` (for `nn.BCEWithLogitsLoss`).
If `False`, removes `pos_weight` (for `nn.BCELoss`).
Returns:
Dict[str, Any]: Filtered dictionary of parameters.
"""
loss_kwargs = self.model_dump()
if not with_logits:
if not self.with_logits:
loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss
loss_kwargs.pop("with_logits", None)
weight = loss_kwargs.get("weight")
pos_weight = loss_kwargs.get("pos_weight")
@ -48,15 +47,16 @@ class BCELoss(BaseLoss):
Custom loss function wrapper for `nn.BCELoss` and `nn.BCEWithLogitsLoss` with tracking of loss metrics.
"""
def __init__(self, bce_params: Optional[BCELossParams] = None, with_logits: bool = False):
def __init__(self, params: Optional[BCELossParams] = None):
"""
Initializes the loss function with optional BCELoss parameters.
Args:
bce_params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None).
params (Optional[BCELossParams]): Parameters for `nn.BCELoss` or `nn.BCEWithLogitsLoss` (default: None).
"""
super().__init__()
_bce_params = bce_params.asdict(with_logits=with_logits) if bce_params is not None else {}
super().__init__(params=params)
with_logits = params.with_logits if params is not None else False
_bce_params = params.asdict() if params is not None else {}
# Initialize loss functions with user-provided parameters or PyTorch defaults
self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params)
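
Since `with_logits` now lives on BCELossParams instead of being a separate constructor flag, a sketch of the updated call site; the module path is assumed, and because the wrapper's forward/metric API is not shown in this hunk, the underlying torch loss is called directly.

import torch
from core.criterion import BCELoss, BCELossParams  # hypothetical module path

params = BCELossParams(with_logits=True, reduction="mean")
criterion = BCELoss(params=params)            # builds nn.BCEWithLogitsLoss internally
logits = torch.randn(4, 1)
target = torch.randint(0, 2, (4, 1)).float()
loss = criterion.bce_loss(logits, target)     # underlying torch loss module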

@ -36,15 +36,15 @@ class CrossEntropyLoss(BaseLoss):
Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics.
"""
def __init__(self, ce_params: Optional[CrossEntropyLossParams] = None):
def __init__(self, params: Optional[CrossEntropyLossParams] = None):
"""
Initializes the loss function with optional CrossEntropyLoss parameters.
Args:
ce_params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None).
params (Optional[CrossEntropyLossParams]): Parameters for `nn.CrossEntropyLoss` (default: None).
"""
super().__init__()
_ce_params = ce_params.asdict() if ce_params is not None else {}
super().__init__(params=params)
_ce_params = params.asdict() if params is not None else {}
# Initialize loss functions with user-provided parameters or PyTorch defaults
self.ce_loss = nn.CrossEntropyLoss(**_ce_params)

@ -27,15 +27,15 @@ class MSELoss(BaseLoss):
Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics.
"""
def __init__(self, mse_params: Optional[MSELossParams] = None):
def __init__(self, params: Optional[MSELossParams] = None):
"""
Initializes the loss function with optional MSELoss parameters.
Args:
mse_params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None).
params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None).
"""
super().__init__()
_mse_params = mse_params.asdict() if mse_params is not None else {}
super().__init__(params=params)
_mse_params = params.asdict() if params is not None else {}
# Initialize MSE loss with user-provided parameters or PyTorch defaults
self.mse_loss = nn.MSELoss(**_mse_params)

@ -2,42 +2,59 @@ from .base import *
from .bce import BCELossParams
from .mse import MSELossParams
from pydantic import BaseModel, ConfigDict
class BCE_MSE_LossParams(BaseModel):
"""
Parameters for the combined BCE and MSE loss (`BCE_MSE_Loss`).
"""
model_config = ConfigDict(frozen=True)
num_classes: int = 1
bce_params: BCELossParams = BCELossParams()
mse_params: MSELossParams = MSELossParams()
def asdict(self) -> Dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.BCELoss` and `nn.MSELoss`.
Returns:
Dict[str, Any]: Dictionary of parameters.
"""
return {
"num_classes": self.num_classes,
"bce_params": self.bce_params.asdict(),
"mse_params": self.mse_params.asdict()
}
class BCE_MSE_Loss(BaseLoss):
"""
Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction.
"""
def __init__(
self,
num_classes: int,
bce_params: Optional[BCELossParams] = None,
mse_params: Optional[MSELossParams] = None,
bce_with_logits: bool = False,
):
def __init__(self, params: Optional[BCE_MSE_LossParams] = None):
"""
Initializes the loss function with optional BCE and MSE parameters.
Args:
num_classes (int): Number of output classes, used for target shifting.
bce_params (Optional[BCELossParams]): Parameters for BCEWithLogitsLoss or BCELoss (default: None).
mse_params (Optional[MSELossParams]): Parameters for MSELoss (default: None).
bce_with_logits (bool): If True, uses BCEWithLogitsLoss; otherwise, uses BCELoss.
"""
super().__init__()
super().__init__(params=params)
_params = params if params is not None else BCE_MSE_LossParams()
self.num_classes = num_classes
self.num_classes = _params.num_classes
# Process BCE parameters
_bce_params = bce_params.asdict(bce_with_logits) if bce_params is not None else {}
_bce_params = _params.bce_params.asdict()
# Choose BCE loss function
self.bce_loss = (
nn.BCEWithLogitsLoss(**_bce_params) if bce_with_logits else nn.BCELoss(**_bce_params)
nn.BCEWithLogitsLoss(**_bce_params) if _params.bce_params.with_logits else nn.BCELoss(**_bce_params)
)
# Process MSE parameters
_mse_params = mse_params.asdict() if mse_params is not None else {}
_mse_params = _params.mse_params.asdict()
# Initialize MSE loss
self.mse_loss = nn.MSELoss(**_mse_params)
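
A sketch of the new single-object construction for the combined loss; the module path is assumed and num_classes=3 is an arbitrary example value.

# Sketch: one nested params object replaces bce_params/mse_params/bce_with_logits.
from core.criterion import (  # hypothetical module path
    BCE_MSE_Loss, BCE_MSE_LossParams, BCELossParams, MSELossParams,
)

params = BCE_MSE_LossParams(
    num_classes=3,                                # arbitrary example value
    bce_params=BCELossParams(with_logits=True),   # selects nn.BCEWithLogitsLoss
    mse_params=MSELossParams(),
)
criterion = BCE_MSE_Loss(params=params)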

@ -61,7 +61,7 @@ class ModelRegistry:
return entry["class"]
@classmethod
def get_model_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
def get_model_params(cls, name: str) -> Type[BaseModel]:
"""
Retrieves the model parameter class by name (case-insensitive).
@ -69,7 +69,7 @@ class ModelRegistry:
name (str): Name of the model.
Returns:
Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The model parameter class or a tuple of parameter classes.
Type[BaseModel]: The model parameter class.
"""
entry = cls.__get_entry(name)
return entry["params"]

@ -1,14 +1,15 @@
import torch.optim as optim
from pydantic import BaseModel
from typing import Dict, Final, Tuple, Type, List, Any, Union
from .adam import AdamParams
from .adamw import AdamWParams
from .sgd import SGDParams
from .base import BaseOptimizer
from .adam import AdamParams, AdamOptimizer
from .adamw import AdamWParams, AdamWOptimizer
from .sgd import SGDParams, SGDOptimizer
__all__ = [
"OptimizerRegistry",
"AdamParams", "AdamWParams", "SGDParams"
"OptimizerRegistry", "BaseOptimizer",
"AdamParams", "AdamWParams", "SGDParams",
"AdamOptimizer", "AdamWOptimizer", "SGDOptimizer"
]
class OptimizerRegistry:
@ -17,15 +18,15 @@ class OptimizerRegistry:
# Single dictionary storing both optimizer classes and parameter classes.
__OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
"SGD": {
"class": optim.SGD,
"class": SGDOptimizer,
"params": SGDParams,
},
"Adam": {
"class": optim.Adam,
"class": AdamOptimizer,
"params": AdamParams,
},
"AdamW": {
"class": optim.AdamW,
"class": AdamWOptimizer,
"params": AdamWParams,
},
}
@ -54,7 +55,7 @@ class OptimizerRegistry:
return cls.__OPTIMIZERS[original_key]
@classmethod
def get_optimizer_class(cls, name: str) -> Type[optim.Optimizer]:
def get_optimizer_class(cls, name: str) -> Type[BaseOptimizer]:
"""
Retrieves the optimizer class by name (case-insensitive).
@ -62,13 +63,13 @@ class OptimizerRegistry:
name (str): Name of the optimizer.
Returns:
Type[optim.Optimizer]: The optimizer class.
Type[BaseOptimizer]: The optimizer class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@classmethod
def get_optimizer_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
def get_optimizer_params(cls, name: str) -> Type[BaseModel]:
"""
Retrieves the optimizer parameter class by name (case-insensitive).
@ -76,7 +77,7 @@ class OptimizerRegistry:
name (str): Name of the optimizer.
Returns:
Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The optimizer parameter class or a tuple of parameter classes.
Type[BaseModel]: The optimizer parameter class.
"""
entry = cls.__get_entry(name)
return entry["params"]
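
The registry now resolves names to the new wrapper classes rather than raw `torch.optim` classes; a quick sketch with the `core` import path assumed.

import torch.nn as nn
from core import OptimizerRegistry  # import path assumed

model = nn.Linear(10, 2)
optim_cls = OptimizerRegistry.get_optimizer_class("AdamW")        # -> AdamWOptimizer wrapper
optim_params = OptimizerRegistry.get_optimizer_params("AdamW")()  # AdamWParams with defaults
optimizer = optim_cls(model.parameters(), optim_params)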

@ -1,6 +1,9 @@
from typing import Any, Dict, Tuple
import torch
from torch import optim
from typing import Any, Dict, Iterable, Optional, Tuple
from pydantic import BaseModel, ConfigDict
from .base import BaseOptimizer
class AdamParams(BaseModel):
"""Configuration for `torch.optim.Adam` optimizer."""
@ -14,4 +17,21 @@ class AdamParams(BaseModel):
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.Adam`."""
return self.model_dump()
return self.model_dump()
class AdamOptimizer(BaseOptimizer):
"""
Wrapper around torch.optim.Adam.
"""
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams):
"""
Initializes the Adam optimizer with given parameters.
Args:
model_params (Iterable[Parameter]): Parameters to optimize.
optim_params (AdamParams): Optimizer parameters.
"""
super().__init__(model_params, optim_params)
self.optim = optim.Adam(model_params, **optim_params.asdict())

@ -1,6 +1,10 @@
from typing import Any, Dict, Tuple
import torch
from torch import optim
from typing import Any, Dict, Iterable, Optional, Tuple
from pydantic import BaseModel, ConfigDict
from .base import BaseOptimizer
class AdamWParams(BaseModel):
"""Configuration for `torch.optim.AdamW` optimizer."""
model_config = ConfigDict(frozen=True)
@ -13,4 +17,21 @@ class AdamWParams(BaseModel):
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.AdamW`."""
return self.model_dump()
return self.model_dump()
class AdamWOptimizer(BaseOptimizer):
"""
Wrapper around torch.optim.AdamW.
"""
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams):
"""
Initializes the AdamW optimizer with given parameters.
Args:
model_params (Iterable[Parameter]): Parameters to optimize.
optim_params (AdamWParams): Optimizer parameters.
"""
super().__init__(model_params, optim_params)
self.optim = optim.AdamW(model_params, **optim_params.asdict())

@ -0,0 +1,40 @@
import torch
import torch.optim as optim
from pydantic import BaseModel
from typing import Any, Iterable, Optional
class BaseOptimizer:
"""Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel):
super().__init__()
self.optim: Optional[optim.Optimizer] = None
def zero_grad(self, set_to_none: bool = True) -> None:
"""
Clears the gradients of all optimized tensors.
Args:
set_to_none (bool): If True, sets gradients to None instead of zero.
This can reduce memory usage and improve performance.
(Introduced in PyTorch 1.7+)
"""
if self.optim is not None:
self.optim.zero_grad(set_to_none=set_to_none)
def step(self, closure: Optional[Any] = None) -> Any:
"""
Performs a single optimization step (parameter update).
Args:
closure (Optional[Callable]): A closure that reevaluates the model and returns the loss.
This is required for optimizers like LBFGS that need multiple forward passes.
Returns:
Any: The return value depends on the specific optimizer implementation.
"""
if self.optim is not None:
return self.optim.step(closure=closure)
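
A minimal training-step sketch written against the BaseOptimizer interface only; SGDOptimizer and SGDParams come from this commit, while the module path is assumed.

import torch
import torch.nn as nn
from core.optimizers import SGDOptimizer, SGDParams  # hypothetical module path

model = nn.Linear(4, 1)
optimizer = SGDOptimizer(model.parameters(), SGDParams())
criterion = nn.MSELoss()

x, y = torch.randn(8, 4), torch.randn(8, 1)
optimizer.zero_grad()                 # delegates to the wrapped torch optimizer
loss = criterion(model(x), y)
loss.backward()
optimizer.step()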

@ -1,6 +1,10 @@
from typing import Any, Dict
import torch
from torch import optim
from typing import Any, Dict, Iterable, Optional
from pydantic import BaseModel, ConfigDict
from .base import BaseOptimizer
class SGDParams(BaseModel):
"""Configuration for `torch.optim.SGD` optimizer."""
@ -14,4 +18,21 @@ class SGDParams(BaseModel):
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.SGD`."""
return self.model_dump()
return self.model_dump()
class SGDOptimizer(BaseOptimizer):
"""
Wrapper around torch.optim.SGD.
"""
def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams):
"""
Initializes the SGD optimizer with given parameters.
Args:
model_params (Iterable[Parameter]): Parameters to optimize.
optim_params (SGDParams): Optimizer parameters.
"""
super().__init__(model_params, optim_params)
self.optim = optim.SGD(model_params, **optim_params.asdict())

@ -2,14 +2,16 @@ import torch.optim.lr_scheduler as lr_scheduler
from typing import Dict, Final, Tuple, Type, List, Any, Union
from pydantic import BaseModel
from .step import StepLRParams
from .multi_step import MultiStepLRParams
from .exponential import ExponentialLRParams
from .cosine_annealing import CosineAnnealingLRParams
from .base import BaseScheduler
from .step import StepLRParams, StepLRScheduler
from .multi_step import MultiStepLRParams, MultiStepLRScheduler
from .exponential import ExponentialLRParams, ExponentialLRScheduler
from .cosine_annealing import CosineAnnealingLRParams, CosineAnnealingLRScheduler
__all__ = [
"SchedulerRegistry",
"StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams"
"SchedulerRegistry", "BaseScheduler",
"StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams",
"StepLRScheduler", "MultiStepLRScheduler", "ExponentialLRScheduler", "CosineAnnealingLRScheduler"
]
class SchedulerRegistry:
@ -17,19 +19,19 @@ class SchedulerRegistry:
__SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
"Step": {
"class": lr_scheduler.StepLR,
"class": StepLRScheduler,
"params": StepLRParams,
},
"Exponential": {
"class": lr_scheduler.ExponentialLR,
"class": ExponentialLRScheduler,
"params": ExponentialLRParams,
},
"MultiStep": {
"class": lr_scheduler.MultiStepLR,
"class": MultiStepLRScheduler,
"params": MultiStepLRParams,
},
"CosineAnnealing": {
"class": lr_scheduler.CosineAnnealingLR,
"class": CosineAnnealingLRScheduler,
"params": CosineAnnealingLRParams,
},
}
@ -58,7 +60,7 @@ class SchedulerRegistry:
return cls.__SCHEDULERS[original_key]
@classmethod
def get_scheduler_class(cls, name: str) -> Type[lr_scheduler.LRScheduler]:
def get_scheduler_class(cls, name: str) -> Type[BaseScheduler]:
"""
Retrieves the scheduler class by name (case-insensitive).
@ -66,13 +68,13 @@ class SchedulerRegistry:
name (str): Name of the scheduler.
Returns:
Type[lr_scheduler.LRScheduler]: The scheduler class.
Type[BaseScheduler]: The scheduler class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@classmethod
def get_scheduler_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
def get_scheduler_params(cls, name: str) -> Type[BaseModel]:
"""
Retrieves the scheduler parameter class by name (case-insensitive).
@ -80,7 +82,7 @@ class SchedulerRegistry:
name (str): Name of the scheduler.
Returns:
Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The scheduler parameter class or a tuple of parameter classes.
Type[BaseModel]: The scheduler parameter class.
"""
entry = cls.__get_entry(name)
return entry["params"]

@ -0,0 +1,27 @@
import torch.optim as optim
from pydantic import BaseModel
from typing import List, Optional
class BaseScheduler:
"""
Abstract base class for learning rate schedulers.
Wraps a PyTorch LR scheduler and provides a unified interface.
"""
def __init__(self, optimizer: optim.Optimizer, params: BaseModel):
self.scheduler: Optional[optim.lr_scheduler.LRScheduler] = None
def step(self) -> None:
"""
Performs a single scheduler step. This typically updates the learning rate
based on the current epoch or step count.
"""
if self.scheduler is not None:
self.scheduler.step()
def get_last_lr(self) -> List[float]:
"""
Returns the most recent learning rate(s).
"""
return self.scheduler.get_last_lr() if self.scheduler else []
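
Combining the new wrappers: BaseScheduler is typed against a raw `optim.Optimizer`, so with the new optimizer wrappers the inner `.optim` attribute is presumably what gets passed. A sketch, with the registry import path assumed.

import torch.nn as nn
from core import OptimizerRegistry, SchedulerRegistry  # import path assumed

model = nn.Linear(4, 1)
optimizer = OptimizerRegistry.get_optimizer_class("Adam")(
    model.parameters(), OptimizerRegistry.get_optimizer_params("Adam")()
)
scheduler = SchedulerRegistry.get_scheduler_class("CosineAnnealing")(
    optimizer.optim, SchedulerRegistry.get_scheduler_params("CosineAnnealing")()
)

for epoch in range(3):
    # ... one epoch of training with `optimizer` ...
    scheduler.step()
    print(scheduler.get_last_lr())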

@ -1,5 +1,10 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from .base import BaseScheduler
class CosineAnnealingLRParams(BaseModel):
@ -8,7 +13,24 @@ class CosineAnnealingLRParams(BaseModel):
T_max: int = 100 # Maximum number of iterations
eta_min: float = 0.0 # Minimum learning rate
last_epoch: int = -1
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`."""
return self.model_dump()
class CosineAnnealingLRScheduler(BaseScheduler):
"""
Wrapper around torch.optim.lr_scheduler.CosineAnnealingLR.
"""
def __init__(self, optimizer: optim.Optimizer, params: CosineAnnealingLRParams):
"""
Args:
optimizer (Optimizer): Wrapped optimizer.
params (CosineAnnealingLRParams): Scheduler parameters.
"""
super().__init__(optimizer, params)
self.scheduler = CosineAnnealingLR(optimizer, **params.asdict())

@ -1,5 +1,9 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR
from .base import BaseScheduler
class ExponentialLRParams(BaseModel):
@ -7,7 +11,23 @@ class ExponentialLRParams(BaseModel):
model_config = ConfigDict(frozen=True)
gamma: float = 0.95 # Multiplicative factor of learning rate decay
last_epoch: int = -1
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`."""
return self.model_dump()
return self.model_dump()
class ExponentialLRScheduler(BaseScheduler):
"""
Wrapper around torch.optim.lr_scheduler.ExponentialLR.
"""
def __init__(self, optimizer: optim.Optimizer, params: ExponentialLRParams):
"""
Args:
optimizer (Optimizer): Wrapped optimizer.
params (ExponentialLRParams): Scheduler parameters.
"""
super().__init__(optimizer, params)
self.scheduler = ExponentialLR(optimizer, **params.asdict())

@ -1,5 +1,9 @@
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from .base import BaseScheduler
class MultiStepLRParams(BaseModel):
@ -8,7 +12,23 @@ class MultiStepLRParams(BaseModel):
milestones: Tuple[int, ...] = (30, 80) # List of epoch indices for LR decay
gamma: float = 0.1 # Multiplicative factor of learning rate decay
last_epoch: int = -1
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`."""
return self.model_dump()
return self.model_dump()
class MultiStepLRScheduler(BaseScheduler):
"""
Wrapper around torch.optim.lr_scheduler.MultiStepLR.
"""
def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams):
"""
Args:
optimizer (Optimizer): Wrapped optimizer.
params (MultiStepLRParams): Scheduler parameters.
"""
super().__init__(optimizer, params)
self.scheduler = MultiStepLR(optimizer, **params.asdict())

@ -1,5 +1,9 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
from torch import optim
from torch.optim.lr_scheduler import StepLR
from .base import BaseScheduler
class StepLRParams(BaseModel):
@ -8,7 +12,24 @@ class StepLRParams(BaseModel):
step_size: int = 30 # Period of learning rate decay
gamma: float = 0.1 # Multiplicative factor of learning rate decay
last_epoch: int = -1
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`."""
return self.model_dump()
return self.model_dump()
class StepLRScheduler(BaseScheduler):
"""
Wrapper around torch.optim.lr_scheduler.StepLR.
"""
def __init__(self, optimizer: optim.Optimizer, params: StepLRParams):
"""
Args:
optimizer (Optimizer): Wrapped optimizer.
params (StepLRParams): Scheduler parameters.
"""
super().__init__(optimizer, params)
self.scheduler = StepLR(optimizer, **params.asdict())

@ -1,8 +1,7 @@
import os
from pydantic import BaseModel
from typing import Any, Dict, Tuple, Type, Union, List
from typing import Tuple
from config.config import Config
from config.config import *
from config.dataset_config import DatasetConfig
from core import (
@ -10,18 +9,6 @@ from core import (
)
def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]:
"""
Instantiates the parameter class(es) with default values.
If 'param' is a tuple, instantiate each class and return a list of instances.
Otherwise, instantiate the single class and return the instance.
"""
if isinstance(param, tuple):
return [cls() for cls in param]
else:
return param()
def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str:
"""
Prompt the user with a list of options and return the selected option.
@ -55,11 +42,11 @@ def main():
model_options = ModelRegistry.get_available_models()
chosen_model = prompt_choice("\nSelect a model:", model_options)
model_param_class = ModelRegistry.get_model_params(chosen_model)
model_instance = instantiate_params(model_param_class)
model_instance = model_param_class()
if is_training is False:
config = Config(
model={chosen_model: model_instance},
model=ComponentConfig(name=chosen_model, params=model_instance),
dataset_config=dataset_config
)
@ -71,27 +58,27 @@ def main():
criterion_options = CriterionRegistry.get_available_criterions()
chosen_criterion = prompt_choice("\nSelect a criterion:", criterion_options)
criterion_param_class = CriterionRegistry.get_criterion_params(chosen_criterion)
criterion_instance = instantiate_params(criterion_param_class)
criterion_instance = criterion_param_class()
# Prompt the user to select an optimizer.
optimizer_options = OptimizerRegistry.get_available_optimizers()
chosen_optimizer = prompt_choice("\nSelect an optimizer:", optimizer_options)
optimizer_param_class = OptimizerRegistry.get_optimizer_params(chosen_optimizer)
optimizer_instance = instantiate_params(optimizer_param_class)
optimizer_instance = optimizer_param_class()
# Prompt the user to select a scheduler.
scheduler_options = SchedulerRegistry.get_available_schedulers()
chosen_scheduler = prompt_choice("\nSelect a scheduler:", scheduler_options)
scheduler_param_class = SchedulerRegistry.get_scheduler_params(chosen_scheduler)
scheduler_instance = instantiate_params(scheduler_param_class)
scheduler_instance = scheduler_param_class()
# Assemble the overall configuration using the registry names as keys.
config = Config(
model={chosen_model: model_instance},
model=ComponentConfig(name=chosen_model, params=model_instance),
dataset_config=dataset_config,
criterion={chosen_criterion: criterion_instance},
optimizer={chosen_optimizer: optimizer_instance},
scheduler={chosen_scheduler: scheduler_instance}
criterion=ComponentConfig(name=chosen_criterion, params=criterion_instance),
optimizer=ComponentConfig(name=chosen_optimizer, params=optimizer_instance),
scheduler=ComponentConfig(name=chosen_scheduler, params=scheduler_instance)
)
# Construct a base filename from the selected registry names.

@ -7,8 +7,4 @@ pprint(config, indent=4)
print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV.json')
pprint(config, indent=4)
print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV_1.json')
pprint(config, indent=4)
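
Round-tripping a saved template goes through Config.load_json, which re-resolves each component's parameter class via the registries (see parse_field earlier in this diff); a sketch with a placeholder path.

from config.config import Config

cfg = Config.load_json("config/templates/train/example.json")  # placeholder path
print(cfg.model.name)        # registry key, e.g. "ModelV"
print(cfg.optimizer.params)  # a pydantic params model, e.g. AdamWParams(...)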