You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							98 lines
						
					
					
						
							3.2 KiB
						
					
					
				
			
		
		
	
	
							98 lines
						
					
					
						
							3.2 KiB
						
					
					
				from typing import Final, Type, Any
 | 
						|
from pydantic import BaseModel
 | 
						|
 | 
						|
from .base import BaseScheduler
 | 
						|
from .step import StepLRParams, StepLRScheduler
 | 
						|
from .multi_step import MultiStepLRParams, MultiStepLRScheduler
 | 
						|
from .exponential import ExponentialLRParams, ExponentialLRScheduler
 | 
						|
from .cosine_annealing import CosineAnnealingLRParams, CosineAnnealingLRScheduler
 | 
						|
 | 
						|
__all__ = [
    # Registry entry point and the shared scheduler base class.
    "SchedulerRegistry",
    "BaseScheduler",
    # Parameter classes, one per registered scheduler.
    "StepLRParams",
    "MultiStepLRParams",
    "ExponentialLRParams",
    "CosineAnnealingLRParams",
    # Scheduler implementations, in the same order as their parameter classes.
    "StepLRScheduler",
    "MultiStepLRScheduler",
    "ExponentialLRScheduler",
    "CosineAnnealingLRScheduler",
]
 | 
						|
 | 
						|
class SchedulerRegistry:
    """Registry for learning rate schedulers and their parameter classes with case-insensitive lookup.

    Every registered scheduler name maps to an entry with two keys:

    * ``"class"``  -- the scheduler implementation class.
    * ``"params"`` -- the parameter class describing its configuration.

    :meth:`get_scheduler_class` and :meth:`get_scheduler_params` ignore the
    case of the requested name, while :meth:`get_available_schedulers`
    reports names exactly as they were registered.
    """

    __SCHEDULERS: Final[dict[str, dict[str, Type[Any]]]] = {
        "Step": {
            "class": StepLRScheduler,
            "params": StepLRParams,
        },
        "Exponential": {
            "class": ExponentialLRScheduler,
            "params": ExponentialLRParams,
        },
        "MultiStep": {
            "class": MultiStepLRScheduler,
            "params": MultiStepLRParams,
        },
        "CosineAnnealing": {
            "class": CosineAnnealingLRScheduler,
            "params": CosineAnnealingLRParams,
        },
    }

    # Lower-cased name -> registered name, used for case-insensitive lookup.
    # The registry above is a private Final constant, so the index is built
    # once at class creation instead of on every lookup. This relies on the
    # registry not being mutated afterwards. If two registered names differ
    # only by case, the later one wins, as before.
    # Only the outermost iterable of a comprehension is evaluated in class
    # scope, which is what lets this reference __SCHEDULERS directly.
    __LOOKUP: Final[dict[str, str]] = {key.lower(): key for key in __SCHEDULERS}

    @classmethod
    def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
        """
        Private method to retrieve the scheduler entry from the registry using case-insensitive lookup.

        Args:
            name (str): The name of the scheduler, in any letter case.

        Returns:
            dict[str, type]: The registry entry, with the keys 'class' and 'params'.

        Raises:
            ValueError: If no scheduler is registered under ``name``.
        """
        original_key = cls.__LOOKUP.get(name.lower())
        if original_key is None:
            raise ValueError(
                f"Scheduler '{name}' not found! Available options: {list(cls.__SCHEDULERS.keys())}"
            )
        return cls.__SCHEDULERS[original_key]

    @classmethod
    def get_scheduler_class(cls, name: str) -> Type[BaseScheduler]:
        """
        Retrieves the scheduler class by name (case-insensitive).

        Args:
            name (str): Name of the scheduler.

        Returns:
            type[BaseScheduler]: The scheduler class.

        Raises:
            ValueError: If no scheduler is registered under ``name``.
        """
        return cls.__get_entry(name)["class"]

    @classmethod
    def get_scheduler_params(cls, name: str) -> Type[BaseModel]:
        """
        Retrieves the scheduler parameter class by name (case-insensitive).

        Args:
            name (str): Name of the scheduler.

        Returns:
            type[BaseModel]: The scheduler parameter class.

        Raises:
            ValueError: If no scheduler is registered under ``name``.
        """
        return cls.__get_entry(name)["params"]

    @classmethod
    def get_available_schedulers(cls) -> tuple[str, ...]:
        """
        Returns the names of all registered schedulers in their original case.

        Returns:
            tuple[str, ...]: Registered scheduler names, in registration order.
        """
        return tuple(cls.__SCHEDULERS.keys())
 |