from typing import Any, Dict, Iterable

import torch
from pydantic import BaseModel, ConfigDict
from torch import optim

from .base import BaseOptimizer


class SGDParams(BaseModel):
    """Configuration for the `torch.optim.SGD` optimizer."""

    model_config = ConfigDict(frozen=True)

    lr: float = 1e-3  # Learning rate
    momentum: float = 0.0  # Momentum factor
    dampening: float = 0.0  # Dampening for momentum
    weight_decay: float = 0.0  # L2 penalty
    nesterov: bool = False  # Enables Nesterov momentum (requires momentum > 0 and zero dampening)

    def asdict(self) -> Dict[str, Any]:
        """Return a dictionary of valid parameters for `torch.optim.SGD`."""
        return self.model_dump()


class SGDOptimizer(BaseOptimizer):
    """Wrapper around `torch.optim.SGD`."""

    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: SGDParams):
        """
        Initializes the SGD optimizer with the given parameters.

        Args:
            model_params (Iterable[Parameter]): Parameters to optimize.
            optim_params (SGDParams): Optimizer parameters.
        """
        # Materialize the iterable so it is not exhausted by the base-class
        # constructor before being passed to `torch.optim.SGD` (generators
        # such as `model.parameters()` can only be consumed once).
        model_params = list(model_params)
        super().__init__(model_params, optim_params)
        self.optim = optim.SGD(model_params, **optim_params.asdict())