modelV renamed to mediarV

master
laynholt 3 weeks ago
parent 4bdd1ee872
commit 39471b3a03

@ -2,13 +2,13 @@ from torch import nn
from typing import Final, Type, Any from typing import Final, Type, Any
from pydantic import BaseModel from pydantic import BaseModel
from .model_v import ModelV, ModelVParams from .mediar_v import MediarV, MediarVParams
__all__ = [ __all__ = [
"ModelRegistry", "ModelRegistry",
"ModelV", "MediarV",
"ModelVParams" "MediarVParams"
] ]
@ -17,9 +17,9 @@ class ModelRegistry:
# Single dictionary storing both model classes and parameter classes. # Single dictionary storing both model classes and parameter classes.
__MODELS: Final[dict[str, dict[str, Type[Any]]]] = { __MODELS: Final[dict[str, dict[str, Type[Any]]]] = {
"ModelV": { "MediarV": {
"class": ModelV, "class": MediarV,
"params": ModelVParams, "params": MediarVParams,
}, },
} }

@ -7,10 +7,10 @@ from segmentation_models_pytorch.base.modules import Activation
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
__all__ = ["ModelV"] __all__ = ["MediarV"]
class ModelVParams(BaseModel): class MediarVParams(BaseModel):
model_config = ConfigDict(frozen=True) model_config = ConfigDict(frozen=True)
encoder_name: str = "mit_b5" # Default encoder encoder_name: str = "mit_b5" # Default encoder
@ -26,19 +26,19 @@ class ModelVParams(BaseModel):
def asdict(self) -> dict[str, Any]: def asdict(self) -> dict[str, Any]:
""" """
Returns a dictionary of valid parameters for `nn.ModelV`. Returns a dictionary of valid parameters for `nn.MediarV`.
Returns: Returns:
dict(str, Any): Dictionary of parameters for nn.ModelV. dict(str, Any): Dictionary of parameters for nn.MediarV.
""" """
loss_kwargs = self.model_dump() loss_kwargs = self.model_dump()
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
class ModelV(MAnet): class MediarV(MAnet):
"""ModelV model""" """MediarV model"""
def __init__(self, params: ModelVParams) -> None: def __init__(self, params: MediarVParams) -> None:
# Initialize the MAnet model with provided parameters # Initialize the MAnet model with provided parameters
super().__init__(**params.asdict()) super().__init__(**params.asdict())
Loading…
Cancel
Save