You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

98 lines
3.1 KiB

from typing import Dict, Final, Tuple, Type, List, Any, Union
from pydantic import BaseModel
from .base import BaseLoss
from .ce import CrossEntropyLoss, CrossEntropyLossParams
from .bce import BCELoss, BCELossParams
from .mse import MSELoss, MSELossParams
from .mse_with_bce import BCE_MSE_Loss
# Public API of this module: the registry, the loss classes, and the
# parameter classes for the losses that have their own params model.
# These are the names exported by a star-import of this module.
__all__ = [
"CriterionRegistry",
"CrossEntropyLoss", "BCELoss", "MSELoss", "BCE_MSE_Loss",
"CrossEntropyLossParams", "BCELossParams", "MSELossParams"
]
class CriterionRegistry:
    """Registry of loss functions and their parameter classes with case-insensitive lookup.

    Each registry entry is a dict with two keys: ``"class"`` (the loss class)
    and ``"params"`` (its parameter class, or a tuple of parameter classes
    when the loss combines several, as the ``BCE_MSE_Loss`` entry does).
    """

    __CRITERIONS: Final[Dict[str, Dict[str, Any]]] = {
        "CrossEntropyLoss": {
            "class": CrossEntropyLoss,
            "params": CrossEntropyLossParams,
        },
        "BCELoss": {
            "class": BCELoss,
            "params": BCELossParams,
        },
        "MSELoss": {
            "class": MSELoss,
            "params": MSELossParams,
        },
        "BCE_MSE_Loss": {
            "class": BCE_MSE_Loss,
            "params": (BCELossParams, MSELossParams),
        },
    }

    # Lower-cased name -> canonical registry key. Built once at class creation
    # because the registry is Final, instead of being rebuilt on every lookup.
    __LOOKUP: Final[Dict[str, str]] = {key.lower(): key for key in __CRITERIONS}

    @classmethod
    def __get_entry(cls, name: str) -> Dict[str, Any]:
        """
        Private method to retrieve the criterion entry from the registry using case-insensitive lookup.

        Args:
            name (str): The name of the loss function.

        Returns:
            Dict[str, Any]: A dictionary containing the keys 'class' and 'params'.

        Raises:
            ValueError: If the loss function is not found, or if `name` is not a string.
        """
        # A non-string name (e.g. None from a missing config value) cannot match
        # any key; report it through the documented ValueError rather than
        # letting `.lower()` fail with an AttributeError.
        original_key = cls.__LOOKUP.get(name.lower()) if isinstance(name, str) else None
        if original_key is None:
            raise ValueError(
                f"Criterion '{name}' not found! Available options: {list(cls.__CRITERIONS.keys())}"
            )
        return cls.__CRITERIONS[original_key]

    @classmethod
    def get_criterion_class(cls, name: str) -> Type[BaseLoss]:
        """
        Retrieves the loss function class by name (case-insensitive).

        Args:
            name (str): Name of the loss function.

        Returns:
            Type[BaseLoss]: The loss function class.

        Raises:
            ValueError: If no loss function is registered under `name`.
        """
        entry = cls.__get_entry(name)
        return entry["class"]

    @classmethod
    def get_criterion_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]:
        """
        Retrieves the loss function parameter class (or classes) by name (case-insensitive).

        Args:
            name (str): Name of the loss function.

        Returns:
            Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]: The loss function
            parameter class, or a tuple of parameter classes for a combined loss.

        Raises:
            ValueError: If no loss function is registered under `name`.
        """
        entry = cls.__get_entry(name)
        return entry["params"]

    @classmethod
    def get_available_criterions(cls) -> Tuple[str, ...]:
        """
        Returns a tuple of available loss function names in their original case.

        Returns:
            Tuple[str, ...]: Tuple of available loss function names.
        """
        return tuple(cls.__CRITERIONS.keys())