modelV renamed to mediarV

master
laynholt 3 weeks ago
parent 4bdd1ee872
commit 39471b3a03

@ -2,13 +2,13 @@ from torch import nn
from typing import Final, Type, Any
from pydantic import BaseModel
from .model_v import ModelV, ModelVParams
from .mediar_v import MediarV, MediarVParams
__all__ = [
"ModelRegistry",
"ModelV",
"ModelVParams"
"MediarV",
"MediarVParams"
]
@ -17,9 +17,9 @@ class ModelRegistry:
# Single dictionary storing both model classes and parameter classes.
__MODELS: Final[dict[str, dict[str, Type[Any]]]] = {
"ModelV": {
"class": ModelV,
"params": ModelVParams,
"MediarV": {
"class": MediarV,
"params": MediarVParams,
},
}

@ -7,10 +7,10 @@ from segmentation_models_pytorch.base.modules import Activation
from pydantic import BaseModel, ConfigDict
__all__ = ["ModelV"]
__all__ = ["MediarV"]
class ModelVParams(BaseModel):
class MediarVParams(BaseModel):
model_config = ConfigDict(frozen=True)
encoder_name: str = "mit_b5" # Default encoder
@ -26,19 +26,19 @@ class ModelVParams(BaseModel):
def asdict(self) -> dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.ModelV`.
Returns a dictionary of valid parameters for `nn.MediarV`.
Returns:
dict(str, Any): Dictionary of parameters for nn.ModelV.
dict(str, Any): Dictionary of parameters for nn.MediarV.
"""
loss_kwargs = self.model_dump()
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
class ModelV(MAnet):
"""ModelV model"""
class MediarV(MAnet):
"""MediarV model"""
def __init__(self, params: ModelVParams) -> None:
def __init__(self, params: MediarVParams) -> None:
# Initialize the MAnet model with provided parameters
super().__init__(**params.asdict())
Loading…
Cancel
Save