from typing import Any

from pydantic import BaseModel, ConfigDict
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

from .base import BaseScheduler

class MultiStepLRParams(BaseModel):
    """Configuration for `torch.optim.lr_scheduler.MultiStepLR`."""

    model_config = ConfigDict(frozen=True)

    milestones: tuple[int, ...] = (30, 80)  # Epoch indices at which the LR decays
    gamma: float = 0.1  # Multiplicative factor of learning rate decay
    last_epoch: int = -1  # Index of the last epoch; -1 starts training from scratch

    def asdict(self) -> dict[str, Any]:
        """Return a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`."""
        return self.model_dump()
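
# Illustrative example (not part of the original file): because the model is
# frozen, instances are immutable, and `asdict()` yields keyword arguments
# ready to be splatted into `MultiStepLR`:
#   >>> MultiStepLRParams(milestones=(10, 20), gamma=0.5).asdict()
#   {'milestones': (10, 20), 'gamma': 0.5, 'last_epoch': -1}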

class MultiStepLRScheduler(BaseScheduler):
    """Wrapper around `torch.optim.lr_scheduler.MultiStepLR`."""

    def __init__(self, optimizer: optim.Optimizer, params: MultiStepLRParams) -> None:
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            params (MultiStepLRParams): Scheduler parameters.
        """
        super().__init__(optimizer, params)
        self.scheduler = MultiStepLR(optimizer, **params.asdict())
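
# Usage sketch (assumptions: `BaseScheduler` takes (optimizer, params) as
# shown above; its stepping interface lives in `.base` and is not shown here,
# so this sketch steps the wrapped torch scheduler directly):
#
#   import torch
#   model = torch.nn.Linear(4, 2)
#   optimizer = optim.SGD(model.parameters(), lr=0.1)
#   scheduler = MultiStepLRScheduler(optimizer, MultiStepLRParams())
#   for epoch in range(100):
#       ...  # one training epoch
#       scheduler.scheduler.step()  # LR multiplied by gamma=0.1 at epochs 30 and 80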