You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							94 lines
						
					
					
						
							2.9 KiB
						
					
					
				
			
		
		
	
	
							94 lines
						
					
					
						
							2.9 KiB
						
					
					
				from pydantic import BaseModel
 | 
						|
from typing import Final, Type, Any
 | 
						|
 | 
						|
from .base import BaseOptimizer
 | 
						|
from .adam import AdamParams, AdamOptimizer
 | 
						|
from .adamw import AdamWParams, AdamWOptimizer
 | 
						|
from .sgd import SGDParams, SGDOptimizer
 | 
						|
 | 
						|
# Names made available by a star-import of this module.
__all__ = [
    "OptimizerRegistry",
    "BaseOptimizer",
    "AdamParams",
    "AdamWParams",
    "SGDParams",
    "AdamOptimizer",
    "AdamWOptimizer",
    "SGDOptimizer",
]
 | 
						|
 | 
						|
class OptimizerRegistry:
    """Look up optimizer classes and their parameter classes by name.

    Names are matched without regard to letter case, so ``"adamw"``,
    ``"AdamW"`` and ``"ADAMW"`` all resolve to the same registry entry.
    """

    # One entry per optimizer: "class" holds the optimizer implementation,
    # "params" holds the matching parameter class.
    __OPTIMIZERS: Final[dict[str, dict[str, Type[Any]]]] = {
        "SGD": {"class": SGDOptimizer, "params": SGDParams},
        "Adam": {"class": AdamOptimizer, "params": AdamParams},
        "AdamW": {"class": AdamWOptimizer, "params": AdamWParams},
    }

    @classmethod
    def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
        """
        Find the registry entry whose key matches ``name``, ignoring case.

        Args:
            name (str): Optimizer name in any letter case.

        Returns:
            dict[str, Type[Any]]: The entry with its 'class' and 'params' keys.

        Raises:
            ValueError: If no registered optimizer matches ``name``.
        """
        wanted = name.lower()
        # Registered keys have distinct lowercase forms, so the first match is the only match.
        for registered, entry in cls.__OPTIMIZERS.items():
            if registered.lower() == wanted:
                return entry
        raise ValueError(
            f"Optimizer '{name}' not found! Available options: {list(cls.__OPTIMIZERS.keys())}"
        )

    @classmethod
    def get_optimizer_class(cls, name: str) -> Type[BaseOptimizer]:
        """
        Return the optimizer class registered under ``name`` (case-insensitive).

        Args:
            name (str): Optimizer name in any letter case.

        Returns:
            Type[BaseOptimizer]: The registered optimizer class.

        Raises:
            ValueError: If no registered optimizer matches ``name``.
        """
        return cls.__get_entry(name)["class"]

    @classmethod
    def get_optimizer_params(cls, name: str) -> Type[BaseModel]:
        """
        Return the parameter class registered under ``name`` (case-insensitive).

        Args:
            name (str): Optimizer name in any letter case.

        Returns:
            Type[BaseModel]: The registered parameter class.

        Raises:
            ValueError: If no registered optimizer matches ``name``.
        """
        return cls.__get_entry(name)["params"]

    @classmethod
    def get_available_optimizers(cls) -> tuple[str, ...]:
        """
        List the registered optimizer names, spelled as they were registered.

        Returns:
            tuple[str, ...]: Registered names in registry order.
        """
        return tuple(cls.__OPTIMIZERS)
 |