import torch
from torch import optim
from typing import Any, Dict, Iterable, Tuple
from pydantic import BaseModel, ConfigDict
from .base import BaseOptimizer


class AdamWParams(BaseModel):
    """Configuration for the `torch.optim.AdamW` optimizer."""

    model_config = ConfigDict(frozen=True)

    lr: float = 1e-3                           # Learning rate
    betas: Tuple[float, float] = (0.9, 0.999)  # Coefficients for the running averages of the gradient and its square
    eps: float = 1e-8                          # Term added to the denominator for numerical stability
    weight_decay: float = 1e-2                 # Decoupled weight decay coefficient (AdamW does not apply it as an L2 gradient penalty)
    amsgrad: bool = False                      # Whether to use the AMSGrad variant

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid keyword arguments for `torch.optim.AdamW`."""
        return self.model_dump()


class AdamWOptimizer(BaseOptimizer):
    """Wrapper around `torch.optim.AdamW`."""

    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams):
        """
        Initializes the AdamW optimizer with the given parameters.

        Args:
            model_params (Iterable[Parameter]): Parameters to optimize.
            optim_params (AdamWParams): Optimizer parameters.
        """
        # Materialize the iterable so a one-shot generator (e.g. model.parameters()) can be consumed by both the base class and AdamW.
        model_params = list(model_params)
        super().__init__(model_params, optim_params)
        self.optim = optim.AdamW(model_params, **optim_params.asdict())
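

# --- Usage sketch (illustrative only, not part of the module's public surface) ---
# Assumes BaseOptimizer takes (model_params, optim_params) as above; the model and the
# call through `.optim` are hypothetical choices, since this file does not define what
# stepping interface BaseOptimizer itself exposes.
#
#   model = torch.nn.Linear(128, 10)
#   params = AdamWParams(lr=3e-4, weight_decay=0.05)
#   optimizer = AdamWOptimizer(model.parameters(), params)
#
#   loss = model(torch.randn(4, 128)).sum()
#   loss.backward()
#   optimizer.optim.step()       # step the wrapped torch.optim.AdamW instance
#   optimizer.optim.zero_grad()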