"""Optimizer registry exposing the available optimizers and their parameter classes."""

from pydantic import BaseModel
from typing import Dict, Final, Tuple, Type, List, Any, Union

from .base import BaseOptimizer
from .adam import AdamParams, AdamOptimizer
from .adamw import AdamWParams, AdamWOptimizer
from .sgd import SGDParams, SGDOptimizer

__all__ = [
    "OptimizerRegistry",
    "BaseOptimizer",
    "AdamParams",
    "AdamWParams",
    "SGDParams",
    "AdamOptimizer",
    "AdamWOptimizer",
    "SGDOptimizer",
]


class OptimizerRegistry:
    """Case-insensitive lookup table for optimizer classes and their parameter classes."""

    # Every canonical optimizer name maps to a two-key record: "class" holds the
    # optimizer implementation and "params" holds its parameter class.
    __OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
        "SGD": {"class": SGDOptimizer, "params": SGDParams},
        "Adam": {"class": AdamOptimizer, "params": AdamParams},
        "AdamW": {"class": AdamWOptimizer, "params": AdamWParams},
    }

    @classmethod
    def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
        """
        Look up the registry record for an optimizer, ignoring letter case.

        Args:
            name (str): Optimizer name in any casing (e.g. "adam", "ADAM").

        Returns:
            Dict[str, Type[Any]]: The record holding the 'class' and 'params' entries.

        Raises:
            ValueError: If no registered optimizer matches the given name.
        """
        wanted = name.lower()
        # Compare against each registered name in lowercase so the caller's
        # casing does not matter; the registered names are unique when lowered.
        for registered, entry in cls.__OPTIMIZERS.items():
            if registered.lower() == wanted:
                return entry
        raise ValueError(
            f"Optimizer '{name}' not found! Available options: {list(cls.__OPTIMIZERS.keys())}"
        )

    @classmethod
    def get_optimizer_class(cls, name: str) -> Type[BaseOptimizer]:
        """
        Return the optimizer class registered under the given name (case-insensitive).

        Args:
            name (str): Name of the optimizer.

        Returns:
            Type[BaseOptimizer]: The optimizer class.

        Raises:
            ValueError: If the optimizer is not registered.
        """
        return cls.__get_entry(name)["class"]

    @classmethod
    def get_optimizer_params(cls, name: str) -> Type[BaseModel]:
        """
        Return the parameter class registered under the given name (case-insensitive).

        Args:
            name (str): Name of the optimizer.

        Returns:
            Type[BaseModel]: The optimizer parameter class.

        Raises:
            ValueError: If the optimizer is not registered.
        """
        return cls.__get_entry(name)["params"]

    @classmethod
    def get_available_optimizers(cls) -> Tuple[str, ...]:
        """
        List the registered optimizer names in their original casing.

        Returns:
            Tuple[str, ...]: The registered names, in registration order.
        """
        return tuple(cls.__OPTIMIZERS)