You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

86 lines
2.4 KiB

from torch import nn
from typing import Final, Type, Any
from pydantic import BaseModel
from .model_v import ModelV, ModelVParams
# Names exported by a star-import of this module.
__all__ = ["ModelRegistry", "ModelV", "ModelVParams"]
class ModelRegistry:
    """Case-insensitive lookup table from model names to their classes.

    Every registered name maps to an entry holding the model class under
    the key 'class' and its parameter class under the key 'params'.
    """

    # One table keeps the model class and its parameter class side by side.
    __MODELS: Final[dict[str, dict[str, Type[Any]]]] = {
        "ModelV": {
            "class": ModelV,
            "params": ModelVParams,
        },
    }

    @classmethod
    def __get_entry(cls, name: str) -> dict[str, Type[Any]]:
        """
        Look up the registry entry for a model name, ignoring letter case.

        Args:
            name (str): The name of the model.

        Returns:
            dict(str, Type[Any]): The entry with the keys 'class' and 'params'.

        Raises:
            ValueError: If no registered name matches.
        """
        wanted = name.lower()

        # Scan all registered names without stopping early: if two names ever
        # lowercase to the same string, the one registered later is selected.
        matched_key = None
        for registered in cls.__MODELS:
            if registered.lower() == wanted:
                matched_key = registered

        if matched_key is not None:
            return cls.__MODELS[matched_key]

        raise ValueError(
            f"Model '{name}' not found! Available options: {list(cls.__MODELS.keys())}"
        )

    @classmethod
    def get_model_class(cls, name: str) -> Type[nn.Module]:
        """
        Return the model class registered under the given name (case-insensitive).

        Args:
            name (str): Name of the model.

        Returns:
            Type(torch.nn.Module): The model class.

        Raises:
            ValueError: If no registered name matches.
        """
        return cls.__get_entry(name)["class"]

    @classmethod
    def get_model_params(cls, name: str) -> Type[BaseModel]:
        """
        Return the parameter class registered under the given name (case-insensitive).

        Args:
            name (str): Name of the model.

        Returns:
            Type(BaseModel): The model parameter class.

        Raises:
            ValueError: If no registered name matches.
        """
        return cls.__get_entry(name)["params"]

    @classmethod
    def get_available_models(cls) -> tuple[str, ...]:
        """
        List the registered model names, spelled exactly as they were registered.

        Returns:
            Tuple(str): Tuple of available model names.
        """
        # Iterating the dict yields its keys in insertion order.
        return tuple(cls.__MODELS)