|
|
|
|
@ -7,10 +7,10 @@ from segmentation_models_pytorch.base.modules import Activation
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
|
|
|
|
|
|
__all__ = ["ModelV"]
|
|
|
|
|
__all__ = ["MediarV"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelVParams(BaseModel):
|
|
|
|
|
class MediarVParams(BaseModel):
|
|
|
|
|
model_config = ConfigDict(frozen=True)
|
|
|
|
|
|
|
|
|
|
encoder_name: str = "mit_b5" # Default encoder
|
|
|
|
|
@ -26,19 +26,19 @@ class ModelVParams(BaseModel):
|
|
|
|
|
|
|
|
|
|
def asdict(self) -> dict[str, Any]:
|
|
|
|
|
"""
|
|
|
|
|
Returns a dictionary of valid parameters for `nn.ModelV`.
|
|
|
|
|
Returns a dictionary of valid parameters for `nn.MediarV`.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
dict(str, Any): Dictionary of parameters for nn.ModelV.
|
|
|
|
|
dict(str, Any): Dictionary of parameters for nn.MediarV.
|
|
|
|
|
"""
|
|
|
|
|
loss_kwargs = self.model_dump()
|
|
|
|
|
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelV(MAnet):
|
|
|
|
|
"""ModelV model"""
|
|
|
|
|
class MediarV(MAnet):
|
|
|
|
|
"""MediarV model"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, params: ModelVParams) -> None:
|
|
|
|
|
def __init__(self, params: MediarVParams) -> None:
|
|
|
|
|
# Initialize the MAnet model with provided parameters
|
|
|
|
|
super().__init__(**params.asdict())
|
|
|
|
|
|