You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

93 lines
2.8 KiB

import torch.optim as optim
from pydantic import BaseModel
from typing import Dict, Final, Tuple, Type, List, Any, Union
from .adam import AdamParams
from .adamw import AdamWParams
from .sgd import SGDParams
# Names exported by `from <this module> import *`: the registry itself plus
# the parameter models imported above.
__all__ = [
    "OptimizerRegistry",
    "AdamParams",
    "AdamWParams",
    "SGDParams",
]
class OptimizerRegistry:
    """Registry for optimizers and their parameter classes with case-insensitive lookup.

    Each registered name maps to a ``torch.optim`` optimizer class (key
    ``"class"``) and the parameter model class associated with it (key
    ``"params"``). Names are stored in their canonical spelling but matched
    case-insensitively on lookup.
    """

    # Single dictionary storing both optimizer classes and parameter classes.
    # Keys are the canonical display names returned by get_available_optimizers().
    # The double-underscore name is mangled to _OptimizerRegistry__OPTIMIZERS.
    __OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
        "SGD": {
            "class": optim.SGD,
            "params": SGDParams,
        },
        "Adam": {
            "class": optim.Adam,
            "params": AdamParams,
        },
        "AdamW": {
            "class": optim.AdamW,
            "params": AdamWParams,
        },
    }

    @classmethod
    def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
        """
        Private method to retrieve the optimizer entry from the registry using case-insensitive lookup.

        Args:
            name (str): The name of the optimizer.

        Returns:
            Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.

        Raises:
            ValueError: If the optimizer is not found.
        """
        name_lower = name.lower()
        # Compare against each canonical key directly rather than building a
        # throwaway lowercase->key mapping on every call.
        for key, entry in cls.__OPTIMIZERS.items():
            if key.lower() == name_lower:
                return entry
        raise ValueError(
            f"Optimizer '{name}' not found! Available options: {list(cls.__OPTIMIZERS.keys())}"
        )

    @classmethod
    def get_optimizer_class(cls, name: str) -> Type[optim.Optimizer]:
        """
        Retrieves the optimizer class by name (case-insensitive).

        Args:
            name (str): Name of the optimizer.

        Returns:
            Type[optim.Optimizer]: The optimizer class.

        Raises:
            ValueError: If no optimizer with that name is registered.
        """
        entry = cls.__get_entry(name)
        return entry["class"]

    @classmethod
    def get_optimizer_params(
        cls, name: str
    ) -> Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]:
        """
        Retrieves the optimizer parameter class by name (case-insensitive).

        Args:
            name (str): Name of the optimizer.

        Returns:
            Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]: The optimizer
            parameter class or a tuple of parameter classes.

        Raises:
            ValueError: If no optimizer with that name is registered.
        """
        entry = cls.__get_entry(name)
        return entry["params"]

    @classmethod
    def get_available_optimizers(cls) -> List[str]:
        """
        Returns a list of available optimizer names in their original case.

        Returns:
            List[str]: A new list of registered optimizer names, in registry order.
        """
        return list(cls.__OPTIMIZERS.keys())