import torch
from torch import optim
from typing import Any, Dict, Iterable, Tuple
from pydantic import BaseModel, ConfigDict

from .base import BaseOptimizer


class AdamWParams(BaseModel):
    """Configuration for `torch.optim.AdamW` optimizer."""
    model_config = ConfigDict(frozen=True)

    lr: float = 1e-3  # Learning rate
    betas: Tuple[float, float] = (0.9, 0.999)  # Adam coefficients (beta1, beta2)
    eps: float = 1e-8  # Term added for numerical stability
    weight_decay: float = 1e-2  # Decoupled weight decay (AdamW-style, not classic L2)
    amsgrad: bool = False  # Whether to use the AMSGrad variant

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid parameters for `torch.optim.AdamW`."""
        return self.model_dump()
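
# Example (illustrative, added): with the defaults above, `AdamWParams().asdict()`
# is expected to return roughly
#     {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01, 'amsgrad': False}
# which is splatted directly into `torch.optim.AdamW` below.
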
class AdamWOptimizer(BaseOptimizer):
    """
    Wrapper around `torch.optim.AdamW`.
    """

    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamWParams):
        """
        Initializes the AdamW optimizer with the given parameters.

        Args:
            model_params (Iterable[Parameter]): Parameters to optimize.
            optim_params (AdamWParams): Optimizer hyperparameters.
        """
        super().__init__(model_params, optim_params)
        self.optim = optim.AdamW(model_params, **optim_params.asdict())
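

# --- Usage sketch (added for illustration; not part of the original module) ---
# Minimal example wiring AdamWParams and AdamWOptimizer together. The toy
# nn.Linear model and the single training step are assumptions for demonstration;
# stepping goes through the wrapped `optim` attribute because BaseOptimizer's own
# interface is not shown in this file. Run with `python -m` from its package so
# the relative import resolves.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 1)
    params = AdamWParams(lr=3e-4, weight_decay=0.01)
    # Materialize the parameters as a list so the iterable is not exhausted if
    # BaseOptimizer also iterates over it before torch.optim.AdamW does.
    optimizer = AdamWOptimizer(list(model.parameters()), params)

    # One dummy optimization step on random data.
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.optim.step()
    optimizer.optim.zero_grad()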