commit e00331b503
.gitignore
@@ -0,0 +1,6 @@
*.pyc
__pycache__/
**/__pycache__/

.vscode/
*.json
config/__init__.py
@@ -0,0 +1,3 @@
from .config import Config

__all__ = ["Config"]
config/config.py
@@ -0,0 +1,96 @@
import json
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel

from .dataset_config import DatasetConfig

class Config(BaseModel):
    model: Dict[str, Union[BaseModel, List[BaseModel]]]
    dataset_config: DatasetConfig
    criterion: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
    optimizer: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
    scheduler: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None

    @staticmethod
    def __dump_field(value: Any) -> Any:
        """
        Recursively dumps a field if it is a BaseModel or a list/dict of BaseModels.
        """
        if isinstance(value, BaseModel):
            return value.model_dump()
        elif isinstance(value, list):
            return [Config.__dump_field(item) for item in value]
        elif isinstance(value, dict):
            return {k: Config.__dump_field(v) for k, v in value.items()}
        else:
            return value

    def save_json(self, file_path: str, indent: int = 4) -> None:
        """
        Saves the configuration to a JSON file by dumping each field individually.

        Args:
            file_path (str): Destination path for the JSON file.
            indent (int): Indentation level for the JSON file.
        """
        config_dump = {
            "model": self.__dump_field(self.model),
            "dataset_config": self.dataset_config.model_dump()
        }
        if self.criterion is not None:
            config_dump.update({"criterion": self.__dump_field(self.criterion)})
        if self.optimizer is not None:
            config_dump.update({"optimizer": self.__dump_field(self.optimizer)})
        if self.scheduler is not None:
            config_dump.update({"scheduler": self.__dump_field(self.scheduler)})

        with open(file_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(config_dump, indent=indent))

    @classmethod
    def load_json(cls, file_path: str) -> "Config":
        """
        Loads a configuration from a JSON file and re-instantiates each section using
        the registry keys to recover the original parameter class(es).

        Args:
            file_path (str): Path to the JSON file.

        Returns:
            Config: An instance of Config with the proper parameter classes.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Parse dataset_config using its Pydantic model.
        dataset_config = DatasetConfig(**data.get("dataset_config", {}))

        # Helper function to parse registry fields.
        def parse_field(field_data: Dict[str, Any], registry_getter) -> Dict[str, Union[BaseModel, List[BaseModel]]]:
            result = {}
            for key, value in field_data.items():
                expected = registry_getter(key)
                # If the registry returns a tuple, then we expect a list of dictionaries.
                if isinstance(expected, tuple):
                    result[key] = [cls_param(**item) for cls_param, item in zip(expected, value)]
                else:
                    result[key] = expected(**value)
            return result

        from core import (
            ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
        )

        parsed_model = parse_field(data.get("model", {}), lambda key: ModelRegistry.get_model_params(key))
        parsed_criterion = parse_field(data.get("criterion", {}), lambda key: CriterionRegistry.get_criterion_params(key))
        parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key))
        parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key))

        return cls(
            model=parsed_model,
            dataset_config=dataset_config,
            criterion=parsed_criterion,
            optimizer=parsed_optimizer,
            scheduler=parsed_scheduler
        )
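The interactive generator further down in this commit is one way to produce these JSON files; the same round trip can also be driven directly from code. Below is a minimal sketch, assuming the `config` and `core` packages from this commit are importable and their dependencies (pydantic, torch, monai, segmentation_models_pytorch) are installed; the output path is illustrative.

from config.config import Config
from config.dataset_config import DatasetConfig
from core import CriterionRegistry, ModelRegistry

# Parameter classes come from the registries; BCE_MSE_Loss maps to a tuple of them.
model_params = ModelRegistry.get_model_params("ModelV")()
criterion_params = [params() for params in CriterionRegistry.get_criterion_params("BCE_MSE_Loss")]

cfg = Config(
    model={"ModelV": model_params},
    dataset_config=DatasetConfig(),  # defaults validate: every checked path is "."
    criterion={"BCE_MSE_Loss": criterion_params},
)
cfg.save_json("example_config.json")  # illustrative output file

restored = Config.load_json("example_config.json")
print(type(restored.criterion["BCE_MSE_Loss"][0]).__name__)  # -> BCELossParams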
config/dataset_config.py
@@ -0,0 +1,246 @@
from pydantic import BaseModel, model_validator, field_validator
from typing import Any, Dict, Optional, Union
import os


class DatasetCommonConfig(BaseModel):
    """
    Common configuration fields shared by both training and testing.
    """
    device: str = "cuda:0"  # Device used for training/testing (e.g., 'cpu', 'cuda' or 'cuda:0')
    use_tta: bool = False  # Flag to use Test-Time Augmentation (TTA)
    predictions_dir: str = "."  # Directory to save predictions

    @model_validator(mode="after")
    def validate_common(self) -> "DatasetCommonConfig":
        """
        Validates that device is non-empty.
        """
        if not self.device:
            raise ValueError("device must be provided and non-empty")
        return self


class TrainingSplitInfo(BaseModel):
    """
    Configuration for training mode when data is NOT pre-split (is_split is False).
    Contains:
    - split_seed: Seed for splitting.
    - all_data_dir: Directory containing all data.
    """
    split_seed: int = 0  # Seed for splitting if data is not pre-split
    all_data_dir: str = "."  # Directory containing all data if not pre-split


class TrainingPreSplitInfo(BaseModel):
    """
    Configuration for training mode when data is pre-split (is_split is True).
    Contains:
    - train_dir, valid_dir, test_dir: Directories for training, validation, and testing data.
    - train_size, valid_size, test_size: Data split ratios or counts.
    - train_offset, valid_offset, test_offset: Offsets for respective splits.
    """
    train_dir: str = "."  # Directory for training data if data is pre-split
    valid_dir: str = ""  # Directory for validation data if data is pre-split
    test_dir: str = ""  # Directory for testing data if data is pre-split
    train_size: Union[int, float] = 0.7  # Training data size (int for static, float in (0,1] for dynamic)
    valid_size: Union[int, float] = 0.2  # Validation data size (int for static, float in (0,1] for dynamic)
    test_size: Union[int, float] = 0.1  # Testing data size (int for static, float in (0,1] for dynamic)
    train_offset: int = 0  # Offset for training data
    valid_offset: int = 0  # Offset for validation data
    test_offset: int = 0  # Offset for testing data

    @field_validator("train_size", "valid_size", "test_size", mode="before")
    def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
        """
        Validates size values:
        - If provided as a float, must be in the range (0, 1].
        - If provided as an int, must be non-negative.
        """
        if isinstance(v, float):
            if not (0 < v <= 1):
                raise ValueError("When provided as a float, size must be in the range (0, 1]")
        elif isinstance(v, int):
            if v < 0:
                raise ValueError("When provided as an int, size must be non-negative")
        else:
            raise ValueError("Size must be either an int or a float")
        return v


class DatasetTrainingConfig(BaseModel):
    """
    Main training configuration.
    Contains:
    - is_split: Determines whether data is pre-split.
    - pre_split: Configuration for when data is pre-split.
    - split: Configuration for when data is NOT pre-split.
    - Other training parameters: batch_size, num_epochs, val_freq, use_amp, pretrained_weights.

    Both pre_split and split objects are always created, but only the one corresponding
    to is_split is validated.
    """
    is_split: bool = False  # Whether the data is already split into train/validation sets
    pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo()
    split: TrainingSplitInfo = TrainingSplitInfo()

    batch_size: int = 1  # Batch size for training
    num_epochs: int = 100  # Number of training epochs
    val_freq: int = 1  # Frequency of validation during training
    use_amp: bool = False  # Flag to use Automatic Mixed Precision (AMP)
    pretrained_weights: str = ""  # Path to pretrained weights for training

    @model_validator(mode="after")
    def validate_split_info(self) -> "DatasetTrainingConfig":
        """
        Conditionally validates the nested split objects:
        - If is_split is True, validates pre_split (train_dir must be non-empty and exist; if provided, valid_dir and test_dir must exist).
        - If is_split is False, validates split (all_data_dir must be non-empty and exist).
        """
        if not self.is_split:
            if not self.split.all_data_dir:
                raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in split")
            if not os.path.exists(self.split.all_data_dir):
                raise ValueError(f"Path for all_data_dir does not exist: {self.split.all_data_dir}")
        else:
            if not self.pre_split.train_dir:
                raise ValueError("When is_split is True, train_dir must be provided and non-empty in pre_split")
            if not os.path.exists(self.pre_split.train_dir):
                raise ValueError(f"Path for train_dir does not exist: {self.pre_split.train_dir}")
            if self.pre_split.valid_dir and not os.path.exists(self.pre_split.valid_dir):
                raise ValueError(f"Path for valid_dir does not exist: {self.pre_split.valid_dir}")
            if self.pre_split.test_dir and not os.path.exists(self.pre_split.test_dir):
                raise ValueError(f"Path for test_dir does not exist: {self.pre_split.test_dir}")
        return self

    @model_validator(mode="after")
    def validate_numeric_fields(self) -> "DatasetTrainingConfig":
        """
        Validates numeric fields:
        - batch_size and num_epochs must be > 0.
        - val_freq must be >= 0.
        """
        if self.batch_size <= 0:
            raise ValueError("batch_size must be > 0")
        if self.num_epochs <= 0:
            raise ValueError("num_epochs must be > 0")
        if self.val_freq < 0:
            raise ValueError("val_freq must be >= 0")
        return self

    @model_validator(mode="after")
    def validate_pretrained(self) -> "DatasetTrainingConfig":
        """
        Validates that pretrained_weights, if provided, points to an existing path.
        """
        if self.pretrained_weights and not os.path.exists(self.pretrained_weights):
            raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
        return self


class DatasetTestingConfig(BaseModel):
    """
    Configuration fields used only in testing mode.
    """
    test_dir: str = "."  # Test data directory; must be non-empty
    test_size: Union[int, float] = 1.0  # Testing data size (int for static, float in (0,1] for dynamic)
    test_offset: int = 0  # Offset for testing data
    use_ensemble: bool = False  # Flag to use ensemble mode in testing
    ensemble_pretrained_weights1: str = "."
    ensemble_pretrained_weights2: str = "."
    pretrained_weights: str = "."

    @field_validator("test_size", mode="before")
    def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]:
        """
        Validates the test_size value.
        """
        if isinstance(v, float):
            if not (0 < v <= 1):
                raise ValueError("When provided as a float, test_size must be in the range (0, 1]")
        elif isinstance(v, int):
            if v < 0:
                raise ValueError("When provided as an int, test_size must be non-negative")
        else:
            raise ValueError("test_size must be either an int or a float")
        return v

    @model_validator(mode="after")
    def validate_testing(self) -> "DatasetTestingConfig":
        """
        Validates the testing configuration:
        - test_dir must be non-empty and exist.
        - If use_ensemble is True, both ensemble_pretrained_weights1 and ensemble_pretrained_weights2 must be provided and exist.
        - If use_ensemble is False, pretrained_weights must be provided and exist.
        """
        if not self.test_dir:
            raise ValueError("In testing configuration, test_dir must be provided and non-empty")
        if not os.path.exists(self.test_dir):
            raise ValueError(f"Path for test_dir does not exist: {self.test_dir}")
        if self.use_ensemble:
            for field in ["ensemble_pretrained_weights1", "ensemble_pretrained_weights2"]:
                value = getattr(self, field)
                if not value:
                    raise ValueError(f"When use_ensemble is True, {field} must be provided and non-empty")
                if not os.path.exists(value):
                    raise ValueError(f"Path for {field} does not exist: {value}")
        else:
            if not self.pretrained_weights:
                raise ValueError("When use_ensemble is False, pretrained_weights must be provided and non-empty")
            if not os.path.exists(self.pretrained_weights):
                raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
        if self.test_offset < 0:
            raise ValueError("test_offset must be >= 0")
        return self


class DatasetConfig(BaseModel):
    """
    Main dataset configuration that groups fields into nested models for a structured and readable JSON.
    """
    is_training: bool = True  # Flag indicating whether the configuration is for training (True) or testing (False)
    common: DatasetCommonConfig = DatasetCommonConfig()
    training: DatasetTrainingConfig = DatasetTrainingConfig()
    testing: DatasetTestingConfig = DatasetTestingConfig()

    @model_validator(mode="after")
    def validate_config(self) -> "DatasetConfig":
        """
        Validates the overall dataset configuration.
        """
        if self.is_training:
            if self.training is None:
                raise ValueError("Training configuration must be provided when is_training is True")
            # Check predictions_dir if training.pre_split.test_dir and test_size are set
            if self.training.pre_split.test_dir and self.training.pre_split.test_size > 0:
                if not self.common.predictions_dir:
                    raise ValueError("predictions_dir must be provided when training.pre_split.test_dir and test_size are non-zero")
                if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
                    raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
        else:
            if self.testing is None:
                raise ValueError("Testing configuration must be provided when is_training is False")
            if self.testing.test_dir and self.testing.test_size > 0:
                if not self.common.predictions_dir:
                    raise ValueError("predictions_dir must be provided when testing.test_dir and test_size are non-zero")
                if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
                    raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
        return self

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """
        Dumps only the relevant configuration depending on the is_training flag.
        Only the nested configuration (training or testing) along with common fields is returned.
        """
        if self.is_training:
            return {
                "is_training": self.is_training,
                "common": self.common.model_dump(),
                "training": self.training.model_dump() if self.training else {}
            }
        else:
            return {
                "is_training": self.is_training,
                "common": self.common.model_dump(),
                "testing": self.testing.model_dump() if self.testing else {}
            }
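For reference, a small sketch of how these validators behave at construction time (pydantic wraps the ValueError raised inside a model_validator in a ValidationError); the non-existent path is illustrative.

from pydantic import ValidationError

from config.dataset_config import DatasetConfig, DatasetTrainingConfig, TrainingSplitInfo

cfg = DatasetConfig()  # defaults pass: every checked path is "."
print(cfg.model_dump().keys())  # dict_keys(['is_training', 'common', 'training'])

try:
    DatasetTrainingConfig(split=TrainingSplitInfo(all_data_dir="/path/that/does/not/exist"))
except ValidationError as err:
    print(err)  # "Path for all_data_dir does not exist: ..."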
core/__init__.py
@@ -0,0 +1,7 @@
from .models import ModelRegistry
from .criteria import CriterionRegistry
from .optimizers import OptimizerRegistry
from .schedulers import SchedulerRegistry


__all__ = ["ModelRegistry", "CriterionRegistry", "OptimizerRegistry", "SchedulerRegistry"]
core/criteria/__init__.py
@@ -0,0 +1,97 @@
from typing import Dict, Final, Tuple, Type, List, Any, Union
from pydantic import BaseModel

from .base import BaseLoss
from .ce import CrossEntropyLoss, CrossEntropyLossParams
from .bce import BCELoss, BCELossParams
from .mse import MSELoss, MSELossParams
from .mse_with_bce import BCE_MSE_Loss

__all__ = [
    "CriterionRegistry",
    "CrossEntropyLoss", "BCELoss", "MSELoss", "BCE_MSE_Loss",
    "CrossEntropyLossParams", "BCELossParams", "MSELossParams"
]

class CriterionRegistry:
    """Registry of loss functions and their parameter classes with case-insensitive lookup."""

    __CRITERIONS: Final[Dict[str, Dict[str, Any]]] = {
        "CrossEntropyLoss": {
            "class": CrossEntropyLoss,
            "params": CrossEntropyLossParams,
        },
        "BCELoss": {
            "class": BCELoss,
            "params": BCELossParams,
        },
        "MSELoss": {
            "class": MSELoss,
            "params": MSELossParams,
        },
        "BCE_MSE_Loss": {
            "class": BCE_MSE_Loss,
            "params": (BCELossParams, MSELossParams),
        },
    }

    @classmethod
    def __get_entry(cls, name: str) -> Dict[str, Any]:
        """
        Private method to retrieve the criterion entry from the registry using case-insensitive lookup.

        Args:
            name (str): The name of the loss function.

        Returns:
            Dict[str, Any]: A dictionary containing the keys 'class' and 'params'.

        Raises:
            ValueError: If the loss function is not found.
        """
        name_lower = name.lower()
        mapping = {key.lower(): key for key in cls.__CRITERIONS}
        original_key = mapping.get(name_lower)
        if original_key is None:
            raise ValueError(
                f"Criterion '{name}' not found! Available options: {list(cls.__CRITERIONS.keys())}"
            )
        return cls.__CRITERIONS[original_key]

    @classmethod
    def get_criterion_class(cls, name: str) -> Type[BaseLoss]:
        """
        Retrieves the loss function class by name (case-insensitive).

        Args:
            name (str): Name of the loss function.

        Returns:
            Type[BaseLoss]: The loss function class.
        """
        entry = cls.__get_entry(name)
        return entry["class"]

    @classmethod
    def get_criterion_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]:
        """
        Retrieves the loss function parameter class (or classes) by name (case-insensitive).

        Args:
            name (str): Name of the loss function.

        Returns:
            Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]: The loss function parameter class or a tuple of parameter classes.
        """
        entry = cls.__get_entry(name)
        return entry["params"]

    @classmethod
    def get_available_criterions(cls) -> List[str]:
        """
        Returns a list of available loss function names in their original case.

        Returns:
            List[str]: List of available loss function names.
        """
        return list(cls.__CRITERIONS.keys())
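Typical lookup flow, assuming the `core` package from this commit is importable with its dependencies installed:

from core import CriterionRegistry

print(CriterionRegistry.get_available_criterions())
# ['CrossEntropyLoss', 'BCELoss', 'MSELoss', 'BCE_MSE_Loss']

loss_cls = CriterionRegistry.get_criterion_class("bceloss")     # case-insensitive lookup
params_cls = CriterionRegistry.get_criterion_params("bceloss")  # -> BCELossParams
criterion = loss_cls(params_cls(), with_logits=True)

# Composite criteria map to a tuple of parameter classes:
print(CriterionRegistry.get_criterion_params("bce_mse_loss"))   # (BCELossParams, MSELossParams)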
core/criteria/base.py
@@ -0,0 +1,44 @@
import abc
import torch
import torch.nn as nn
from typing import Dict, Any, Optional
from monai.metrics.cumulative_average import CumulativeAverage


class BaseLoss(abc.ABC):
    """Abstract base class for loss function wrappers with tracking of loss metrics."""

    def __init__(self):
        """Initializes the base loss."""
        super().__init__()

    @abc.abstractmethod
    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between true labels and prediction outputs.

        Args:
            outputs (torch.Tensor): Model predictions.
            target (torch.Tensor): Ground truth.

        Returns:
            torch.Tensor: The total loss value.
        """

    @abc.abstractmethod
    def get_loss_metrics(self) -> Dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Returns:
            Dict[str, float]: A dictionary containing the loss name and average loss value.
        """

    @abc.abstractmethod
    def reset_metrics(self):
        """Resets the stored loss metrics."""
core/criteria/bce.py
@@ -0,0 +1,98 @@
from .base import *
from typing import Literal
from pydantic import BaseModel, ConfigDict


class BCELossParams(BaseModel):
    """
    Class for handling parameters for both `nn.BCELoss` and `nn.BCEWithLogitsLoss`.
    """
    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

    weight: Optional[torch.Tensor] = None  # Sample weights
    reduction: Literal["none", "mean", "sum"] = "mean"  # Reduction method
    pos_weight: Optional[torch.Tensor] = None  # Used only for BCEWithLogitsLoss

    def asdict(self, with_logits: bool = False) -> Dict[str, Any]:
        """
        Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`.

        - If `with_logits=False`, `pos_weight` is **removed** to avoid errors.
        - Ensures only the valid parameters are passed based on the loss function.

        Args:
            with_logits (bool): If `True`, includes `pos_weight` (for `nn.BCEWithLogitsLoss`).
                If `False`, removes `pos_weight` (for `nn.BCELoss`).

        Returns:
            Dict[str, Any]: Filtered dictionary of parameters.
        """
        loss_kwargs = self.model_dump()
        if not with_logits:
            loss_kwargs.pop("pos_weight", None)  # Remove pos_weight if using BCELoss

        return {k: v for k, v in loss_kwargs.items() if v is not None}  # Remove None values


class BCELoss(BaseLoss):
    """
    Custom loss function wrapper for `nn.BCELoss` and `nn.BCEWithLogitsLoss` with tracking of loss metrics.
    """

    def __init__(self, bce_params: Optional[BCELossParams] = None, with_logits: bool = False):
        """
        Initializes the loss function with optional BCELoss parameters.

        Args:
            bce_params (Optional[BCELossParams]): Parameters for `nn.BCELoss` or `nn.BCEWithLogitsLoss` (default: None).
            with_logits (bool): If True, uses `nn.BCEWithLogitsLoss`; otherwise, uses `nn.BCELoss`.
        """
        super().__init__()
        _bce_params = bce_params.asdict(with_logits=with_logits) if bce_params is not None else {}

        # Initialize loss functions with user-provided parameters or PyTorch defaults
        self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params)

        # Using CumulativeAverage from MONAI to track loss metrics
        self.loss_bce_metric = CumulativeAverage()

    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between true labels and prediction outputs.

        Args:
            outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
            target (torch.Tensor): Ground truth labels in one-hot format.

        Returns:
            torch.Tensor: The total loss value.
        """
        # Ensure target is on the same device as outputs
        assert (
            target.device == outputs.device
        ), (
            "Target tensor must be moved to the same device as outputs "
            "before calling forward()."
        )

        loss = self.bce_loss(outputs, target)
        self.loss_bce_metric.append(loss.item())

        return loss

    def get_loss_metrics(self) -> Dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Returns:
            Dict[str, float]: A dictionary containing the average BCE loss.
        """
        return {
            "loss": round(self.loss_bce_metric.aggregate().item(), 4),
        }

    def reset_metrics(self):
        """Resets the stored loss metrics."""
        self.loss_bce_metric.reset()
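A short usage sketch of the parameter filtering and metric tracking above, with dummy tensors (all shapes and values are illustrative):

import torch

from core.criteria import BCELoss, BCELossParams

params = BCELossParams(pos_weight=torch.tensor([2.0]))
print(params.asdict(with_logits=True))   # keeps pos_weight
print(params.asdict(with_logits=False))  # drops pos_weight (and any None-valued fields)

criterion = BCELoss(params, with_logits=True)
logits = torch.randn(4, 3, 8, 8)
target = torch.randint(0, 2, (4, 3, 8, 8)).float()
loss = criterion.forward(logits, target)
print(criterion.get_loss_metrics())      # {'loss': ...}
criterion.reset_metrics()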
core/criteria/ce.py
@@ -0,0 +1,90 @@
from .base import *
from typing import Literal
from pydantic import BaseModel, ConfigDict


class CrossEntropyLossParams(BaseModel):
    """
    Class for handling parameters for `nn.CrossEntropyLoss`.
    """
    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

    weight: Optional[torch.Tensor] = None
    ignore_index: int = -100
    reduction: Literal["none", "mean", "sum"] = "mean"
    label_smoothing: float = 0.0

    def asdict(self):
        """
        Returns a dictionary of valid parameters for `nn.CrossEntropyLoss`.

        Returns:
            Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss.
        """
        loss_kwargs = self.model_dump()
        return {k: v for k, v in loss_kwargs.items() if v is not None}  # Remove None values


class CrossEntropyLoss(BaseLoss):
    """
    Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics.
    """

    def __init__(self, ce_params: Optional[CrossEntropyLossParams] = None):
        """
        Initializes the loss function with optional CrossEntropyLoss parameters.

        Args:
            ce_params (Optional[CrossEntropyLossParams]): Parameters for nn.CrossEntropyLoss (default: None).
        """
        super().__init__()
        _ce_params = ce_params.asdict() if ce_params is not None else {}

        # Initialize loss functions with user-provided parameters or PyTorch defaults
        self.ce_loss = nn.CrossEntropyLoss(**_ce_params)

        # Using CumulativeAverage from MONAI to track loss metrics
        self.loss_ce_metric = CumulativeAverage()

    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between true labels and prediction outputs.

        Args:
            outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
            target (torch.Tensor): Ground truth labels of shape (batch_size, H, W).

        Returns:
            torch.Tensor: The total loss value.
        """
        # Ensure target is on the same device as outputs
        assert (
            target.device == outputs.device
        ), (
            "Target tensor must be moved to the same device as outputs "
            "before calling forward()."
        )

        loss = self.ce_loss(outputs, target)
        self.loss_ce_metric.append(loss.item())

        return loss

    def get_loss_metrics(self) -> Dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Returns:
            Dict[str, float]: A dictionary containing the average CrossEntropy loss.
        """
        return {
            "loss": round(self.loss_ce_metric.aggregate().item(), 4),
        }

    def reset_metrics(self):
        """Resets the stored loss metrics."""
        self.loss_ce_metric.reset()
core/criteria/mse.py
@@ -0,0 +1,83 @@
from .base import *
from typing import Literal
from pydantic import BaseModel, ConfigDict


class MSELossParams(BaseModel):
    """
    Class for MSE loss parameters, compatible with `nn.MSELoss`.
    """
    model_config = ConfigDict(frozen=True)

    reduction: Literal["none", "mean", "sum"] = "mean"

    def asdict(self) -> Dict[str, Any]:
        """
        Returns a dictionary of valid parameters for `nn.MSELoss`.

        Returns:
            Dict[str, Any]: Dictionary of parameters for `nn.MSELoss`.
        """
        loss_kwargs = self.model_dump()
        return {k: v for k, v in loss_kwargs.items() if v is not None}  # Remove None values


class MSELoss(BaseLoss):
    """
    Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics.
    """

    def __init__(self, mse_params: Optional[MSELossParams] = None):
        """
        Initializes the loss function with optional MSELoss parameters.

        Args:
            mse_params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None).
        """
        super().__init__()
        _mse_params = mse_params.asdict() if mse_params is not None else {}

        # Initialize MSE loss with user-provided parameters or PyTorch defaults
        self.mse_loss = nn.MSELoss(**_mse_params)

        # Using CumulativeAverage from MONAI to track loss metrics
        self.loss_mse_metric = CumulativeAverage()

    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between true values and predictions.

        Args:
            outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
            target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W).

        Returns:
            torch.Tensor: The total loss value.
        """
        # Ensure target is on the same device as outputs
        assert (
            target.device == outputs.device
        ), (
            "Target tensor must be moved to the same device as outputs "
            "before calling forward()."
        )

        loss = self.mse_loss(outputs, target)
        self.loss_mse_metric.append(loss.item())

        return loss

    def get_loss_metrics(self) -> Dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Returns:
            Dict[str, float]: A dictionary containing the average MSE loss.
        """
        return {
            "loss": round(self.loss_mse_metric.aggregate().item(), 4),
        }

    def reset_metrics(self):
        """Resets the stored loss metrics."""
        self.loss_mse_metric.reset()
core/criteria/mse_with_bce.py
@@ -0,0 +1,105 @@
from .base import *
from .bce import BCELossParams
from .mse import MSELossParams


class BCE_MSE_Loss(BaseLoss):
    """
    Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction.
    """

    def __init__(
        self,
        num_classes: int,
        bce_params: Optional[BCELossParams] = None,
        mse_params: Optional[MSELossParams] = None,
        bce_with_logits: bool = False,
    ):
        """
        Initializes the loss function with optional BCE and MSE parameters.

        Args:
            num_classes (int): Number of output classes, used for target shifting.
            bce_params (Optional[BCELossParams]): Parameters for BCEWithLogitsLoss or BCELoss (default: None).
            mse_params (Optional[MSELossParams]): Parameters for MSELoss (default: None).
            bce_with_logits (bool): If True, uses BCEWithLogitsLoss; otherwise, uses BCELoss.
        """
        super().__init__()

        self.num_classes = num_classes

        # Process BCE parameters
        _bce_params = bce_params.asdict(bce_with_logits) if bce_params is not None else {}

        # Choose BCE loss function
        self.bce_loss = (
            nn.BCEWithLogitsLoss(**_bce_params) if bce_with_logits else nn.BCELoss(**_bce_params)
        )

        # Process MSE parameters
        _mse_params = mse_params.asdict() if mse_params is not None else {}

        # Initialize MSE loss
        self.mse_loss = nn.MSELoss(**_mse_params)

        # Using CumulativeAverage from MONAI to track loss metrics
        self.loss_bce_metric = CumulativeAverage()
        self.loss_mse_metric = CumulativeAverage()

    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between true labels and prediction outputs.

        Args:
            outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
            target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W).

        Returns:
            torch.Tensor: The total loss value.
        """
        # Ensure target is on the same device as outputs
        assert (
            target.device == outputs.device
        ), (
            "Target tensor must be moved to the same device as outputs "
            "before calling forward()."
        )

        # Cell Recognition Loss
        cellprob_loss = self.bce_loss(
            outputs[:, -self.num_classes:], target[:, self.num_classes:2 * self.num_classes].float()
        )

        # Cell Distinction Loss
        gradflow_loss = 0.5 * self.mse_loss(
            outputs[:, :2 * self.num_classes], 5.0 * target[:, 2 * self.num_classes:]
        )

        # Total loss
        total_loss = cellprob_loss + gradflow_loss

        # Track individual losses
        self.loss_bce_metric.append(cellprob_loss.item())
        self.loss_mse_metric.append(gradflow_loss.item())

        return total_loss

    def get_loss_metrics(self) -> Dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Returns:
            Dict[str, float]: A dictionary containing the average BCE and MSE loss.
        """
        return {
            "bce_loss": round(self.loss_bce_metric.aggregate().item(), 4),
            "mse_loss": round(self.loss_mse_metric.aggregate().item(), 4),
            "loss": round(
                self.loss_bce_metric.aggregate().item() + self.loss_mse_metric.aggregate().item(), 4
            ),
        }

    def reset_metrics(self):
        """Resets the stored loss metrics."""
        self.loss_bce_metric.reset()
        self.loss_mse_metric.reset()
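The slicing above implies a fixed channel layout: `outputs` carries 2*num_classes flow channels followed by num_classes probability channels (the order produced by ModelV later in this commit), while `target` needs 4*num_classes channels, of which only C:2C and 2C: are consumed here. A hedged sketch with dummy tensors (shapes illustrative):

import torch

from core.criteria import BCE_MSE_Loss

C = 3  # num_classes
criterion = BCE_MSE_Loss(num_classes=C, bce_with_logits=True)

outputs = torch.randn(2, 3 * C, 8, 8)  # [gradflow (2C) | cellprob (C)]
target = torch.rand(2, 4 * C, 8, 8)    # channels C:2C -> probabilities, 2C: -> flow targets
loss = criterion.forward(outputs, target)
print(criterion.get_loss_metrics())    # {'bce_loss': ..., 'mse_loss': ..., 'loss': ...}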
core/models/__init__.py
@@ -0,0 +1,85 @@
import torch.nn as nn
from typing import Dict, Final, Tuple, Type, Any, List, Union
from pydantic import BaseModel

from .model_v import ModelV, ModelVParams


__all__ = [
    "ModelRegistry",
    "ModelV",
    "ModelVParams"
]


class ModelRegistry:
    """Registry for models and their parameter classes with case-insensitive lookup."""

    # Single dictionary storing both model classes and parameter classes.
    __MODELS: Final[Dict[str, Dict[str, Type[Any]]]] = {
        "ModelV": {
            "class": ModelV,
            "params": ModelVParams,
        },
    }

    @classmethod
    def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
        """
        Private method to retrieve the model entry from the registry using case-insensitive lookup.

        Args:
            name (str): The name of the model.

        Returns:
            Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.

        Raises:
            ValueError: If the model is not found.
        """
        name_lower = name.lower()
        mapping = {key.lower(): key for key in cls.__MODELS}
        original_key = mapping.get(name_lower)
        if original_key is None:
            raise ValueError(
                f"Model '{name}' not found! Available options: {list(cls.__MODELS.keys())}"
            )
        return cls.__MODELS[original_key]

    @classmethod
    def get_model_class(cls, name: str) -> Type[nn.Module]:
        """
        Retrieves the model class by name (case-insensitive).

        Args:
            name (str): Name of the model.

        Returns:
            Type[nn.Module]: The model class.
        """
        entry = cls.__get_entry(name)
        return entry["class"]

    @classmethod
    def get_model_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]:
        """
        Retrieves the model parameter class by name (case-insensitive).

        Args:
            name (str): Name of the model.

        Returns:
            Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]: The model parameter class or a tuple of parameter classes.
        """
        entry = cls.__get_entry(name)
        return entry["params"]

    @classmethod
    def get_available_models(cls) -> List[str]:
        """
        Returns a list of available model names in their original case.

        Returns:
            List[str]: List of available model names.
        """
        return list(cls.__MODELS.keys())
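Lookup works the same way as for the other registries; note that `asdict()` also drops fields explicitly set to None:

from core import ModelRegistry

print(ModelRegistry.get_available_models())            # ['ModelV']
params_cls = ModelRegistry.get_model_params("modelv")  # case-insensitive -> ModelVParams
params = params_cls(encoder_weights=None)              # frozen model: set overrides at construction
print(params.asdict())                                 # encoder_weights is dropped because it is None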
core/models/model_v.py
@@ -0,0 +1,117 @@
from typing import List, Optional

import torch
import torch.nn as nn

from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation

from pydantic import BaseModel, ConfigDict

__all__ = ["ModelV", "ModelVParams"]


class ModelVParams(BaseModel):
    model_config = ConfigDict(frozen=True)

    encoder_name: str = "mit_b5"  # Default encoder
    encoder_weights: Optional[str] = "imagenet"  # Pre-trained weights
    decoder_channels: List[int] = [1024, 512, 256, 128, 64]  # Decoder configuration
    decoder_pab_channels: int = 256  # Decoder Pyramid Attention Block channels
    in_channels: int = 3  # Number of input channels
    out_classes: int = 3  # Number of output classes

    def asdict(self):
        """
        Returns a dictionary of the parameters forwarded to the `MAnet` base constructor.

        Returns:
            Dict[str, Any]: Dictionary of constructor parameters.
        """
        loss_kwargs = self.model_dump()
        return {k: v for k, v in loss_kwargs.items() if v is not None}  # Remove None values


class ModelV(MAnet):
    """ModelV model"""

    def __init__(self, params: ModelVParams) -> None:
        # Initialize the MAnet model with provided parameters
        super().__init__(**params.asdict())

        # Remove the default segmentation head as it's not used in this architecture
        self.segmentation_head = None

        # Modify all activation functions in the encoder and decoder from ReLU to Mish
        _convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True))
        _convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True))

        # Add custom segmentation heads for different segmentation tasks
        self.cellprob_head = DeepSegmentationHead(
            in_channels=params.decoder_channels[-1], out_channels=params.out_classes
        )
        self.gradflow_head = DeepSegmentationHead(
            in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network"""
        # Ensure the input shape is correct
        self.check_input_shape(x)

        # Encode the input and then decode it
        features = self.encoder(x)
        decoder_output = self.decoder(*features)

        # Generate masks for cell probability and gradient flows
        cellprob_mask = self.cellprob_head(decoder_output)
        gradflow_mask = self.gradflow_head(decoder_output)

        # Concatenate the masks for output
        masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)

        return masks


class DeepSegmentationHead(nn.Sequential):
    """Custom segmentation head for generating specific masks"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        activation: Optional[str] = None,
        upsampling: int = 1,
    ) -> None:
        # Define a sequence of layers for the segmentation head
        layers: List[nn.Module] = [
            nn.Conv2d(
                in_channels,
                in_channels // 2,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            nn.Mish(inplace=True),
            nn.BatchNorm2d(in_channels // 2),
            nn.Conv2d(
                in_channels // 2,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            nn.UpsamplingBilinear2d(scale_factor=upsampling)
            if upsampling > 1
            else nn.Identity(),
            Activation(activation) if activation else nn.Identity(),
        ]
        super().__init__(*layers)


def _convert_activations(module: nn.Module, from_activation: type, to_activation: nn.Module) -> None:
    """Recursively convert activation functions in a module"""
    for name, child in module.named_children():
        if isinstance(child, from_activation):
            setattr(module, name, to_activation)
        else:
            _convert_activations(child, from_activation, to_activation)
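A quick shape check of the custom head, assuming torch and segmentation_models_pytorch are installed (DeepSegmentationHead is internal to this module, so it is imported from the module path directly):

import torch

from core.models.model_v import DeepSegmentationHead

head = DeepSegmentationHead(in_channels=64, out_channels=3)  # conv -> Mish -> BatchNorm -> conv
x = torch.rand(1, 64, 32, 32)
print(head(x).shape)  # torch.Size([1, 3, 32, 32])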
core/optimizers/__init__.py
@@ -0,0 +1,92 @@
import torch.optim as optim
from pydantic import BaseModel
from typing import Dict, Final, Tuple, Type, List, Any, Union

from .adam import AdamParams
from .adamw import AdamWParams
from .sgd import SGDParams

__all__ = [
    "OptimizerRegistry",
    "AdamParams", "AdamWParams", "SGDParams"
]

class OptimizerRegistry:
    """Registry for optimizers and their parameter classes with case-insensitive lookup."""

    # Single dictionary storing both optimizer classes and parameter classes.
    __OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
        "SGD": {
            "class": optim.SGD,
            "params": SGDParams,
        },
        "Adam": {
            "class": optim.Adam,
            "params": AdamParams,
        },
        "AdamW": {
            "class": optim.AdamW,
            "params": AdamWParams,
        },
    }

    @classmethod
    def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
        """
        Private method to retrieve the optimizer entry from the registry using case-insensitive lookup.

        Args:
            name (str): The name of the optimizer.

        Returns:
            Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.

        Raises:
            ValueError: If the optimizer is not found.
        """
        name_lower = name.lower()
        mapping = {key.lower(): key for key in cls.__OPTIMIZERS}
        original_key = mapping.get(name_lower)
        if original_key is None:
            raise ValueError(
                f"Optimizer '{name}' not found! Available options: {list(cls.__OPTIMIZERS.keys())}"
            )
        return cls.__OPTIMIZERS[original_key]

    @classmethod
    def get_optimizer_class(cls, name: str) -> Type[optim.Optimizer]:
        """
        Retrieves the optimizer class by name (case-insensitive).

        Args:
            name (str): Name of the optimizer.

        Returns:
            Type[optim.Optimizer]: The optimizer class.
        """
        entry = cls.__get_entry(name)
        return entry["class"]

    @classmethod
    def get_optimizer_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]:
        """
        Retrieves the optimizer parameter class by name (case-insensitive).

        Args:
            name (str): Name of the optimizer.

        Returns:
            Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]: The optimizer parameter class or a tuple of parameter classes.
        """
        entry = cls.__get_entry(name)
        return entry["params"]

    @classmethod
    def get_available_optimizers(cls) -> List[str]:
        """
        Returns a list of available optimizer names in their original case.

        Returns:
            List[str]: List of available optimizer names.
        """
        return list(cls.__OPTIMIZERS.keys())
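The parameter classes are bridged to the actual torch optimizers via `asdict()`; a minimal sketch:

import torch.nn as nn

from core import OptimizerRegistry
from core.optimizers import AdamWParams

model = nn.Linear(4, 2)
opt_cls = OptimizerRegistry.get_optimizer_class("adamw")  # case-insensitive -> torch.optim.AdamW
optimizer = opt_cls(model.parameters(), **AdamWParams(lr=3e-4).asdict())
print(type(optimizer).__name__, optimizer.defaults["lr"])  # AdamW 0.0003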
core/optimizers/adam.py
@@ -0,0 +1,18 @@
import torch
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict


class AdamParams(BaseModel):
    """Configuration for `torch.optim.Adam` optimizer."""
    model_config = ConfigDict(frozen=True)

    lr: float = 1e-3  # Learning rate
    betas: Tuple[float, float] = (0.9, 0.999)  # Coefficients for computing running averages
    eps: float = 1e-8  # Term added to denominator for numerical stability
    weight_decay: float = 0.0  # L2 regularization
    amsgrad: bool = False  # Whether to use the AMSGrad variant

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid parameters for `torch.optim.Adam`."""
        return self.model_dump()
core/optimizers/adamw.py
@@ -0,0 +1,17 @@
import torch
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict

class AdamWParams(BaseModel):
    """Configuration for `torch.optim.AdamW` optimizer."""
    model_config = ConfigDict(frozen=True)

    lr: float = 1e-3  # Learning rate
    betas: Tuple[float, ...] = (0.9, 0.999)  # Adam coefficients
    eps: float = 1e-8  # Numerical stability
    weight_decay: float = 1e-2  # L2 penalty (AdamW uses decoupled weight decay)
    amsgrad: bool = False  # Whether to use the AMSGrad variant

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid parameters for `torch.optim.AdamW`."""
        return self.model_dump()
core/optimizers/sgd.py
@@ -0,0 +1,18 @@
import torch
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict


class SGDParams(BaseModel):
    """Configuration for `torch.optim.SGD` optimizer."""
    model_config = ConfigDict(frozen=True)

    lr: float = 1e-3  # Learning rate
    momentum: float = 0.0  # Momentum factor
    dampening: float = 0.0  # Dampening for momentum
    weight_decay: float = 0.0  # L2 penalty
    nesterov: bool = False  # Enables Nesterov momentum

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid parameters for `torch.optim.SGD`."""
        return self.model_dump()
core/schedulers/__init__.py
@@ -0,0 +1,96 @@
import torch.optim.lr_scheduler as lr_scheduler
from typing import Dict, Final, Tuple, Type, List, Any, Union
from pydantic import BaseModel

from .step import StepLRParams
from .multi_step import MultiStepLRParams
from .exponential import ExponentialLRParams
from .cosine_annealing import CosineAnnealingLRParams

__all__ = [
    "SchedulerRegistry",
    "StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams"
]

class SchedulerRegistry:
    """Registry for learning rate schedulers and their parameter classes with case-insensitive lookup."""

    __SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
        "Step": {
            "class": lr_scheduler.StepLR,
            "params": StepLRParams,
        },
        "Exponential": {
            "class": lr_scheduler.ExponentialLR,
            "params": ExponentialLRParams,
        },
        "MultiStep": {
            "class": lr_scheduler.MultiStepLR,
            "params": MultiStepLRParams,
        },
        "CosineAnnealing": {
            "class": lr_scheduler.CosineAnnealingLR,
            "params": CosineAnnealingLRParams,
        },
    }

    @classmethod
    def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
        """
        Private method to retrieve the scheduler entry from the registry using case-insensitive lookup.

        Args:
            name (str): The name of the scheduler.

        Returns:
            Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.

        Raises:
            ValueError: If the scheduler is not found.
        """
        name_lower = name.lower()
        mapping = {key.lower(): key for key in cls.__SCHEDULERS}
        original_key = mapping.get(name_lower)
        if original_key is None:
            raise ValueError(
                f"Scheduler '{name}' not found! Available options: {list(cls.__SCHEDULERS.keys())}"
            )
        return cls.__SCHEDULERS[original_key]

    @classmethod
    def get_scheduler_class(cls, name: str) -> Type[lr_scheduler.LRScheduler]:
        """
        Retrieves the scheduler class by name (case-insensitive).

        Args:
            name (str): Name of the scheduler.

        Returns:
            Type[lr_scheduler.LRScheduler]: The scheduler class.
        """
        entry = cls.__get_entry(name)
        return entry["class"]

    @classmethod
    def get_scheduler_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]:
        """
        Retrieves the scheduler parameter class by name (case-insensitive).

        Args:
            name (str): Name of the scheduler.

        Returns:
            Union[Type[BaseModel], Tuple[Type[BaseModel], ...]]: The scheduler parameter class or a tuple of parameter classes.
        """
        entry = cls.__get_entry(name)
        return entry["params"]

    @classmethod
    def get_available_schedulers(cls) -> List[str]:
        """
        Returns a list of available scheduler names in their original case.

        Returns:
            List[str]: List of available scheduler names.
        """
        return list(cls.__SCHEDULERS.keys())
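Schedulers are wired up the same way as optimizers; a minimal sketch:

import torch.nn as nn
import torch.optim as optim

from core import SchedulerRegistry
from core.schedulers import CosineAnnealingLRParams

optimizer = optim.SGD(nn.Linear(4, 2).parameters(), lr=0.1)
sched_cls = SchedulerRegistry.get_scheduler_class("cosineannealing")  # -> CosineAnnealingLR
scheduler = sched_cls(optimizer, **CosineAnnealingLRParams(T_max=50).asdict())

optimizer.step()
scheduler.step()
print(scheduler.get_last_lr())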
core/schedulers/cosine_annealing.py
@@ -0,0 +1,14 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict


class CosineAnnealingLRParams(BaseModel):
    """Configuration for `torch.optim.lr_scheduler.CosineAnnealingLR`."""
    model_config = ConfigDict(frozen=True)

    T_max: int = 100  # Maximum number of iterations
    eta_min: float = 0.0  # Minimum learning rate

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`."""
        return self.model_dump()
core/schedulers/exponential.py
@@ -0,0 +1,13 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict


class ExponentialLRParams(BaseModel):
    """Configuration for `torch.optim.lr_scheduler.ExponentialLR`."""
    model_config = ConfigDict(frozen=True)

    gamma: float = 0.95  # Multiplicative factor of learning rate decay

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`."""
        return self.model_dump()
core/schedulers/multi_step.py
@@ -0,0 +1,14 @@
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict


class MultiStepLRParams(BaseModel):
    """Configuration for `torch.optim.lr_scheduler.MultiStepLR`."""
    model_config = ConfigDict(frozen=True)

    milestones: Tuple[int, ...] = (30, 80)  # List of epoch indices for LR decay
    gamma: float = 0.1  # Multiplicative factor of learning rate decay

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`."""
        return self.model_dump()
core/schedulers/step.py
@@ -0,0 +1,14 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict


class StepLRParams(BaseModel):
    """Configuration for `torch.optim.lr_scheduler.StepLR`."""
    model_config = ConfigDict(frozen=True)

    step_size: int = 30  # Period of learning rate decay
    gamma: float = 0.1  # Multiplicative factor of learning rate decay

    def asdict(self) -> Dict[str, Any]:
        """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`."""
        return self.model_dump()
@@ -0,0 +1,120 @@
import os
from pydantic import BaseModel
from typing import Any, Dict, Type, Union, List

from config.config import Config
from config.dataset_config import DatasetConfig

from core import (
    ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
)


def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]:
    """
    Instantiates the parameter class(es) with default values.

    If 'param' is a tuple, instantiate each class and return a list of instances.
    Otherwise, instantiate the single class and return the instance.
    """
    if isinstance(param, tuple):
        return [cls() for cls in param]
    else:
        return param()

def prompt_choice(prompt_message: str, options: List[str]) -> str:
    """
    Prompt the user with a list of options and return the selected option.
    """
    print(prompt_message)
    for i, option in enumerate(options, start=1):
        print(f"{i}. {option}")
    while True:
        try:
            choice = int(input("Enter your choice (number): "))
            if 1 <= choice <= len(options):
                return options[choice - 1]
            else:
                print("Invalid choice. Please try again.")
        except ValueError:
            print("Please enter a valid number.")

def main():
    # Determine the directory of this script.
    script_path = os.path.dirname(os.path.abspath(__file__))

    # Ask the user whether this is training mode.
    training_input = input("Is this training mode? (y/n): ").strip().lower()
    is_training = training_input in ("y", "yes")

    # Create a default DatasetConfig based on the training mode.
    # Pydantic defaults fill in the remaining fields.
    dataset_config = DatasetConfig(is_training=is_training)

    # Prompt the user to select a model.
    model_options = ModelRegistry.get_available_models()
    chosen_model = prompt_choice("\nSelect a model:", model_options)
    model_param_class = ModelRegistry.get_model_params(chosen_model)
    model_instance = instantiate_params(model_param_class)

    if is_training is False:
        config = Config(
            model={chosen_model: model_instance},
            dataset_config=dataset_config
        )

        # Construct a base filename from the selected registry names.
        base_filename = f"{chosen_model}"

    else:
        # Prompt the user to select a criterion.
        criterion_options = CriterionRegistry.get_available_criterions()
        chosen_criterion = prompt_choice("\nSelect a criterion:", criterion_options)
        criterion_param_class = CriterionRegistry.get_criterion_params(chosen_criterion)
        criterion_instance = instantiate_params(criterion_param_class)

        # Prompt the user to select an optimizer.
        optimizer_options = OptimizerRegistry.get_available_optimizers()
        chosen_optimizer = prompt_choice("\nSelect an optimizer:", optimizer_options)
        optimizer_param_class = OptimizerRegistry.get_optimizer_params(chosen_optimizer)
        optimizer_instance = instantiate_params(optimizer_param_class)

        # Prompt the user to select a scheduler.
        scheduler_options = SchedulerRegistry.get_available_schedulers()
        chosen_scheduler = prompt_choice("\nSelect a scheduler:", scheduler_options)
        scheduler_param_class = SchedulerRegistry.get_scheduler_params(chosen_scheduler)
        scheduler_instance = instantiate_params(scheduler_param_class)

        # Assemble the overall configuration using the registry names as keys.
        config = Config(
            model={chosen_model: model_instance},
            dataset_config=dataset_config,
            criterion={chosen_criterion: criterion_instance},
            optimizer={chosen_optimizer: optimizer_instance},
            scheduler={chosen_scheduler: scheduler_instance}
        )

        # Construct a base filename from the selected registry names.
        base_filename = f"{chosen_model}_{chosen_criterion}_{chosen_optimizer}_{chosen_scheduler}"

    # Determine the output directory relative to this script.
    base_dir = os.path.join(script_path, "config/jsons", "train" if is_training else "predict")
    os.makedirs(base_dir, exist_ok=True)

    filename = f"{base_filename}.json"
    full_path = os.path.join(base_dir, filename)
    counter = 1

    # Append a counter if a file with the same name exists.
    while os.path.exists(full_path):
        filename = f"{base_filename}_{counter}.json"
        full_path = os.path.join(base_dir, filename)
        counter += 1

    # Save the configuration as a JSON file.
    config.save_json(full_path)

    print(f"\nConfiguration saved to: {full_path}")

if __name__ == "__main__":
    main()
@@ -0,0 +1,14 @@
from config.config import Config
from pprint import pprint


config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
pprint(config, indent=4)

print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV.json')
pprint(config, indent=4)

print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV_1.json')
pprint(config, indent=4)