definitions of basic classes, as well as the ability to create config files

master
laynholt 1 month ago
commit e00331b503

.gitignore (vendored): +6

@@ -0,0 +1,6 @@
*.pyc
__pycache__/
**/__pycache__/
.vscode/
*.json

@@ -0,0 +1,3 @@
from .config import Config
__all__ = ["Config"]

@@ -0,0 +1,96 @@
import json
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
from .dataset_config import DatasetConfig
class Config(BaseModel):
model: Dict[str, Union[BaseModel, List[BaseModel]]]
dataset_config: DatasetConfig
criterion: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
optimizer: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
scheduler: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
@staticmethod
def __dump_field(value: Any) -> Any:
"""
Recursively dumps a field if it is a BaseModel or a list/dict of BaseModels.
"""
if isinstance(value, BaseModel):
return value.model_dump()
elif isinstance(value, list):
return [Config.__dump_field(item) for item in value]
elif isinstance(value, dict):
return {k: Config.__dump_field(v) for k, v in value.items()}
else:
return value
def save_json(self, file_path: str, indent: int = 4) -> None:
"""
Saves the configuration to a JSON file, dumping each field individually.
Args:
file_path (str): Destination path for the JSON file.
indent (int): Indentation level for the JSON file.
"""
config_dump = {
"model": self.__dump_field(self.model),
"dataset_config": self.dataset_config.model_dump()
}
if self.criterion is not None:
config_dump.update({"criterion": self.__dump_field(self.criterion)})
if self.optimizer is not None:
config_dump.update({"optimizer": self.__dump_field(self.optimizer)})
if self.scheduler is not None:
config_dump.update({"scheduler": self.__dump_field(self.scheduler)})
with open(file_path, "w", encoding="utf-8") as f:
f.write(json.dumps(config_dump, indent=indent))
@classmethod
def load_json(cls, file_path: str) -> "Config":
"""
Loads a configuration from a JSON file and re-instantiates each section using
the registry keys to recover the original parameter class(es).
Args:
file_path (str): Path to the JSON file.
Returns:
Config: An instance of Config with the proper parameter classes.
"""
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Parse dataset_config using its Pydantic model.
dataset_config = DatasetConfig(**data.get("dataset_config", {}))
# Helper function to parse registry fields.
def parse_field(field_data: Dict[str, Any], registry_getter) -> Dict[str, Union[BaseModel, List[BaseModel]]]:
result = {}
for key, value in field_data.items():
expected = registry_getter(key)
# If the registry returns a tuple, then we expect a list of dictionaries.
if isinstance(expected, tuple):
result[key] = [cls_param(**item) for cls_param, item in zip(expected, value)]
else:
result[key] = expected(**value)
return result
from core import (
ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
)
parsed_model = parse_field(data.get("model", {}), lambda key: ModelRegistry.get_model_params(key))
parsed_criterion = parse_field(data.get("criterion", {}), lambda key: CriterionRegistry.get_criterion_params(key))
parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key))
parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key))
return cls(
model=parsed_model,
dataset_config=dataset_config,
criterion=parsed_criterion,
optimizer=parsed_optimizer,
scheduler=parsed_scheduler
)
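
For orientation, a minimal round-trip sketch of the save/load API above; the import paths and the "ModelV" registry key mirror those used elsewhere in this commit, and the output filename is a placeholder.

# Usage sketch (not part of the commit): build a Config, save it, load it back.
from config.config import Config
from config.dataset_config import DatasetConfig
from core import ModelRegistry

params_cls = ModelRegistry.get_model_params("ModelV")   # single params class here
config = Config(
    model={"ModelV": params_cls()},                      # registry key -> params instance
    dataset_config=DatasetConfig(is_training=False),
)
config.save_json("example_config.json")                  # placeholder path
restored = Config.load_json("example_config.json")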

@@ -0,0 +1,246 @@
from pydantic import BaseModel, model_validator, field_validator
from typing import Any, Dict, Optional, Union
import os
class DatasetCommonConfig(BaseModel):
"""
Common configuration fields shared by both training and testing.
"""
device: str = "cuda0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_tta: bool = False # Flag to use Test-Time Augmentation (TTA)
predictions_dir: str = "." # Directory to save predictions
@model_validator(mode="after")
def validate_common(self) -> "DatasetCommonConfig":
"""
Validates that device is non-empty.
"""
if not self.device:
raise ValueError("device must be provided and non-empty")
return self
class TrainingSplitInfo(BaseModel):
"""
Configuration for training mode when data is NOT pre-split (is_split is False).
Contains:
- split_seed: Seed for splitting.
- all_data_dir: Directory containing all data.
"""
split_seed: int = 0 # Seed for splitting if data is not pre-split
all_data_dir: str = "." # Directory containing all data if not pre-split
class TrainingPreSplitInfo(BaseModel):
"""
Configuration for training mode when data is pre-split (is_split is True).
Contains:
- train_dir, valid_dir, test_dir: Directories for training, validation, and testing data.
- train_size, valid_size, test_size: Data split ratios or counts.
- train_offset, valid_offset, test_offset: Offsets for respective splits.
"""
train_dir: str = "." # Directory for training data if data is pre-split
valid_dir: str = "" # Directory for validation data if data is pre-split
test_dir: str = "" # Directory for testing data if data is pre-split
train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic)
train_offset: int = 0 # Offset for training data
valid_offset: int = 0 # Offset for validation data
test_offset: int = 0 # Offset for testing data
@field_validator("train_size", "valid_size", "test_size", mode="before")
def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
"""
Validates size values:
- If provided as a float, must be in the range (0, 1].
- If provided as an int, must be non-negative.
"""
if isinstance(v, float):
if not (0 < v <= 1):
raise ValueError("When provided as a float, size must be in the range (0, 1]")
elif isinstance(v, int):
if v < 0:
raise ValueError("When provided as an int, size must be non-negative")
else:
raise ValueError("Size must be either an int or a float")
return v
class DatasetTrainingConfig(BaseModel):
"""
Main training configuration.
Contains:
- is_split: Determines whether data is pre-split.
- pre_split: Configuration for when data is NOT pre-split.
- split: Configuration for when data is pre-split.
- Other training parameters: batch_size, num_epochs, val_freq, use_amp, pretrained_weights.
Both pre_split and split objects are always created, but only the one corresponding
to is_split is validated.
"""
is_split: bool = False # Whether the data is already split into train/validation sets
pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo()
split: TrainingSplitInfo = TrainingSplitInfo()
batch_size: int = 1 # Batch size for training
num_epochs: int = 100 # Number of training epochs
val_freq: int = 1 # Frequency of validation during training
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
pretrained_weights: str = "" # Path to pretrained weights for training
@model_validator(mode="after")
def validate_split_info(self) -> "DatasetTrainingConfig":
"""
Conditionally validates the nested split objects:
- If is_split is True, validates pre_split (train_dir must be non-empty and exist; if provided, valid_dir and test_dir must exist).
- If is_split is False, validates split (all_data_dir must be non-empty and exist).
"""
if not self.is_split:
if not self.split.all_data_dir:
raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in pre_split")
if not os.path.exists(self.split.all_data_dir):
raise ValueError(f"Path for all_data_dir does not exist: {self.split.all_data_dir}")
else:
if not self.pre_split.train_dir:
raise ValueError("When is_split is True, train_dir must be provided and non-empty in split")
if not os.path.exists(self.pre_split.train_dir):
raise ValueError(f"Path for train_dir does not exist: {self.pre_split.train_dir}")
if self.pre_split.valid_dir and not os.path.exists(self.pre_split.valid_dir):
raise ValueError(f"Path for valid_dir does not exist: {self.pre_split.valid_dir}")
if self.pre_split.test_dir and not os.path.exists(self.pre_split.test_dir):
raise ValueError(f"Path for test_dir does not exist: {self.pre_split.test_dir}")
return self
@model_validator(mode="after")
def validate_numeric_fields(self) -> "DatasetTrainingConfig":
"""
Validates numeric fields:
- batch_size and num_epochs must be > 0.
- val_freq must be >= 0.
"""
if self.batch_size <= 0:
raise ValueError("batch_size must be > 0")
if self.num_epochs <= 0:
raise ValueError("num_epochs must be > 0")
if self.val_freq < 0:
raise ValueError("val_freq must be >= 0")
return self
@model_validator(mode="after")
def validate_pretrained(self) -> "DatasetTrainingConfig":
"""
Validates that pretrained_weights, if provided, points to an existing path.
"""
if self.pretrained_weights and not os.path.exists(self.pretrained_weights):
raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
return self
class DatasetTestingConfig(BaseModel):
"""
Configuration fields used only in testing mode.
"""
test_dir: str = "." # Test data directory; must be non-empty
test_size: Union[int, float] = 1.0 # Testing data size (int for static, float in (0,1] for dynamic)
test_offset: int = 0 # Offset for testing data
use_ensemble: bool = False # Flag to use ensemble mode in testing
ensemble_pretrained_weights1: str = "."
ensemble_pretrained_weights2: str = "."
pretrained_weights: str = "."
@field_validator("test_size", mode="before")
def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]:
"""
Validates the test_size value.
"""
if isinstance(v, float):
if not (0 < v <= 1):
raise ValueError("When provided as a float, test_size must be in the range (0, 1]")
elif isinstance(v, int):
if v < 0:
raise ValueError("When provided as an int, test_size must be non-negative")
else:
raise ValueError("test_size must be either an int or a float")
return v
@model_validator(mode="after")
def validate_testing(self) -> "DatasetTestingConfig":
"""
Validates the testing configuration:
- test_dir must be non-empty and exist.
- If use_ensemble is True, both ensemble_pretrained_weights1 and ensemble_pretrained_weights2 must be provided and exist.
- If use_ensemble is False, pretrained_weights must be provided and exist.
"""
if not self.test_dir:
raise ValueError("In testing configuration, test_dir must be provided and non-empty")
if not os.path.exists(self.test_dir):
raise ValueError(f"Path for test_dir does not exist: {self.test_dir}")
if self.use_ensemble:
for field in ["ensemble_pretrained_weights1", "ensemble_pretrained_weights2"]:
value = getattr(self, field)
if not value:
raise ValueError(f"When use_ensemble is True, {field} must be provided and non-empty")
if not os.path.exists(value):
raise ValueError(f"Path for {field} does not exist: {value}")
else:
if not self.pretrained_weights:
raise ValueError("When use_ensemble is False, pretrained_weights must be provided and non-empty")
if not os.path.exists(self.pretrained_weights):
raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
if self.test_offset < 0:
raise ValueError("test_offset must be >= 0")
return self
class DatasetConfig(BaseModel):
"""
Main dataset configuration that groups fields into nested models for a structured and readable JSON.
"""
is_training: bool = True # Flag indicating whether the configuration is for training (True) or testing (False)
common: DatasetCommonConfig = DatasetCommonConfig()
training: DatasetTrainingConfig = DatasetTrainingConfig()
testing: DatasetTestingConfig = DatasetTestingConfig()
@model_validator(mode="after")
def validate_config(self) -> "DatasetConfig":
"""
Validates the overall dataset configuration:
"""
if self.is_training:
if self.training is None:
raise ValueError("Training configuration must be provided when is_training is True")
# Check predictions_dir if training.pre_split.test_dir and test_size are set
if self.training.pre_split.test_dir and self.training.pre_split.test_size > 0:
if not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when training.pre_split.test_dir and test_size are non-zero")
if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
else:
if self.testing is None:
raise ValueError("Testing configuration must be provided when is_training is False")
if self.testing.test_dir and self.testing.test_size > 0:
if not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when testing.test_dir and test_size are non-zero")
if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
return self
def model_dump(self, **kwargs) -> Dict[str, Any]:
"""
Dumps only the relevant configuration depending on the is_training flag.
Only the nested configuration (training or testing) along with common fields is returned.
"""
if self.is_training:
return {
"is_training": self.is_training,
"common": self.common.model_dump(),
"training": self.training.model_dump() if self.training else {}
}
else:
return {
"is_training": self.is_training,
"common": self.common.model_dump(),
"testing": self.testing.model_dump() if self.testing else {}
}
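
A small construction sketch showing how the nested sections compose; the data directory is a placeholder and must exist, since the validators above call os.path.exists.

# Usage sketch (not part of the commit), assuming the listed directories exist.
from config.dataset_config import (
    DatasetConfig, DatasetCommonConfig, DatasetTrainingConfig, TrainingSplitInfo
)

cfg = DatasetConfig(
    is_training=True,
    common=DatasetCommonConfig(device="cuda:0", predictions_dir="."),
    training=DatasetTrainingConfig(
        is_split=False,                                                    # data not pre-split
        split=TrainingSplitInfo(split_seed=42, all_data_dir="data/all"),   # placeholder dir
        batch_size=4,
        num_epochs=50,
    ),
)
print(cfg.model_dump())   # dumps only "is_training", "common" and "training"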

@@ -0,0 +1,7 @@
from .models import ModelRegistry
from .criteria import CriterionRegistry
from .optimizers import OptimizerRegistry
from .schedulers import SchedulerRegistry
__all__ = ["ModelRegistry", "CriterionRegistry", "OptimizerRegistry", "SchedulerRegistry"]

@@ -0,0 +1,97 @@
from typing import Dict, Final, Tuple, Type, List, Any, Union
from pydantic import BaseModel
from .base import BaseLoss
from .ce import CrossEntropyLoss, CrossEntropyLossParams
from .bce import BCELoss, BCELossParams
from .mse import MSELoss, MSELossParams
from .mse_with_bce import BCE_MSE_Loss
__all__ = [
"CriterionRegistry",
"CrossEntropyLoss", "BCELoss", "MSELoss", "BCE_MSE_Loss",
"CrossEntropyLossParams", "BCELossParams", "MSELossParams"
]
class CriterionRegistry:
"""Registry of loss functions and their parameter classes with case-insensitive lookup."""
__CRITERIONS: Final[Dict[str, Dict[str, Any]]] = {
"CrossEntropyLoss": {
"class": CrossEntropyLoss,
"params": CrossEntropyLossParams,
},
"BCELoss": {
"class": BCELoss,
"params": BCELossParams,
},
"MSELoss": {
"class": MSELoss,
"params": MSELossParams,
},
"BCE_MSE_Loss": {
"class": BCE_MSE_Loss,
"params": (BCELossParams, MSELossParams),
},
}
@classmethod
def __get_entry(cls, name: str) -> Dict[str, Any]:
"""
Private method to retrieve the criterion entry from the registry using case-insensitive lookup.
Args:
name (str): The name of the loss function.
Returns:
Dict[str, Any]: A dictionary containing the keys 'class' and 'params'.
Raises:
ValueError: If the loss function is not found.
"""
name_lower = name.lower()
mapping = {key.lower(): key for key in cls.__CRITERIONS}
original_key = mapping.get(name_lower)
if original_key is None:
raise ValueError(
f"Criterion '{name}' not found! Available options: {list(cls.__CRITERIONS.keys())}"
)
return cls.__CRITERIONS[original_key]
@classmethod
def get_criterion_class(cls, name: str) -> Type[BaseLoss]:
"""
Retrieves the loss function class by name (case-insensitive).
Args:
name (str): Name of the loss function.
Returns:
Type[BaseLoss]: The loss function class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@classmethod
def get_criterion_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
"""
Retrieves the loss function parameter class (or classes) by name (case-insensitive).
Args:
name (str): Name of the loss function.
Returns:
Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The loss function parameter class or a tuple of parameter classes.
"""
entry = cls.__get_entry(name)
return entry["params"]
@classmethod
def get_available_criterions(cls) -> List[str]:
"""
Returns a list of available loss function names in their original case.
Returns:
List[str]: List of available loss function names.
"""
return list(cls.__CRITERIONS.keys())
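
A short sketch of the lookup contract; the tuple branch is exactly what Config.load_json relies on when re-instantiating parameter classes.

# Usage sketch (not part of the commit).
from core import CriterionRegistry

loss_cls = CriterionRegistry.get_criterion_class("bceloss")        # case-insensitive lookup
params = CriterionRegistry.get_criterion_params("BCE_MSE_Loss")    # -> (BCELossParams, MSELossParams)
if isinstance(params, tuple):
    param_instances = [p() for p in params]   # one default instance per parameter class
else:
    param_instances = params()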

@@ -0,0 +1,44 @@
import abc
import torch
import torch.nn as nn
from typing import Dict, Any, Optional
from monai.metrics.cumulative_average import CumulativeAverage
class BaseLoss(abc.ABC):
"""Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
def __init__(self):
"""
"""
super().__init__()
@abc.abstractmethod
def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes the loss between true labels and prediction outputs.
Args:
outputs (torch.Tensor): Model predictions.
target (torch.Tensor): Ground truth.
Returns:
torch.Tensor: The total loss value.
"""
@abc.abstractmethod
def get_loss_metrics(self) -> Dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the loss name and average loss value.
"""
@abc.abstractmethod
def reset_metrics(self):
"""Resets the stored loss metrics."""

@@ -0,0 +1,98 @@
from .base import *
from typing import Literal
from pydantic import BaseModel, ConfigDict
class BCELossParams(BaseModel):
"""
Class for handling parameters for both `nn.BCELoss` and `nn.BCEWithLogitsLoss`.
"""
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
weight: Optional[torch.Tensor] = None # Sample weights
reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method
pos_weight: Optional[torch.Tensor] = None # Used only for BCEWithLogitsLoss
def asdict(self, with_logits: bool = False) -> Dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`.
- If `with_logits=False`, `pos_weight` is **removed** to avoid errors.
- Ensures only the valid parameters are passed based on the loss function.
Args:
with_logits (bool): If `True`, includes `pos_weight` (for `nn.BCEWithLogitsLoss`).
If `False`, removes `pos_weight` (for `nn.BCELoss`).
Returns:
Dict[str, Any]: Filtered dictionary of parameters.
"""
loss_kwargs = self.model_dump()
if not with_logits:
loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
class BCELoss(BaseLoss):
"""
Custom loss function wrapper for `nn.BCELoss` and `nn.BCEWithLogitsLoss` with tracking of loss metrics.
"""
def __init__(self, bce_params: Optional[BCELossParams] = None, with_logits: bool = False):
"""
Initializes the loss function with optional BCELoss parameters.
Args:
bce_params (Optional[BCELossParams]): Parameters for the BCE loss (default: None).
with_logits (bool): If True, wraps nn.BCEWithLogitsLoss; otherwise nn.BCELoss (default: False).
"""
super().__init__()
_bce_params = bce_params.asdict(with_logits=with_logits) if bce_params is not None else {}
# Initialize loss functions with user-provided parameters or PyTorch defaults
self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params)
# Using CumulativeAverage from MONAI to track loss metrics
self.loss_bce_metric = CumulativeAverage()
def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes the loss between true labels and prediction outputs.
Args:
outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
target (torch.Tensor): Ground truth labels in one-hot format.
Returns:
torch.Tensor: The total loss value.
"""
# Ensure target is on the same device as outputs
assert (
target.device == outputs.device
), (
"Target tensor must be moved to the same device as outputs "
"before calling forward()."
)
loss = self.bce_loss(outputs, target)
self.loss_bce_metric.append(loss.item())
return loss
def get_loss_metrics(self) -> Dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the average BCE loss.
"""
return {
"loss": round(self.loss_bce_metric.aggregate().item(), 4),
}
def reset_metrics(self):
"""Resets the stored loss metrics."""
self.loss_bce_metric.reset()
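
A hedged smoke test of the wrapper with random tensors (shapes are arbitrary); note that BaseLoss is not an nn.Module, so forward() is called directly.

# Usage sketch (not part of the commit).
import torch
from core.criteria import BCELoss, BCELossParams

criterion = BCELoss(bce_params=BCELossParams(reduction="mean"), with_logits=True)
outputs = torch.randn(2, 3, 8, 8)                    # raw logits
target = torch.randint(0, 2, (2, 3, 8, 8)).float()   # one-hot style targets
loss = criterion.forward(outputs, target)
print(criterion.get_loss_metrics())                  # {"loss": ...}
criterion.reset_metrics()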

@@ -0,0 +1,90 @@
from .base import *
from typing import Literal
from pydantic import BaseModel, ConfigDict
class CrossEntropyLossParams(BaseModel):
"""
Class for handling parameters for `nn.CrossEntropyLoss`.
"""
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
weight: Optional[torch.Tensor] = None
ignore_index: int = -100
reduction: Literal["none", "mean", "sum"] = "mean"
label_smoothing: float = 0.0
def asdict(self):
"""
Returns a dictionary of valid parameters for `nn.CrossEntropyLoss`.
Returns:
Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss.
"""
loss_kwargs = self.model_dump()
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
class CrossEntropyLoss(BaseLoss):
"""
Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics.
"""
def __init__(self, ce_params: Optional[CrossEntropyLossParams] = None):
"""
Initializes the loss function with optional CrossEntropyLoss parameters.
Args:
ce_params (Optional[CrossEntropyLossParams]): Parameters for nn.CrossEntropyLoss (default: None).
"""
super().__init__()
_ce_params = ce_params.asdict() if ce_params is not None else {}
# Initialize loss functions with user-provided parameters or PyTorch defaults
self.ce_loss = nn.CrossEntropyLoss(**_ce_params)
# Using CumulativeAverage from MONAI to track loss metrics
self.loss_ce_metric = CumulativeAverage()
def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes the loss between true labels and prediction outputs.
Args:
outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
target (torch.Tensor): Ground truth labels of shape (batch_size, H, W).
Returns:
torch.Tensor: The total loss value.
"""
# Ensure target is on the same device as outputs
assert (
target.device == outputs.device
), (
"Target tensor must be moved to the same device as outputs "
"before calling forward()."
)
loss = self.ce_loss(outputs, target)
self.loss_ce_metric.append(loss.item())
return loss
def get_loss_metrics(self) -> Dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the average CrossEntropy loss.
"""
return {
"loss": round(self.loss_ce_metric.aggregate().item(), 4),
}
def reset_metrics(self):
"""Resets the stored loss metrics."""
self.loss_ce_metric.reset()

@@ -0,0 +1,83 @@
from .base import *
from typing import Literal
from pydantic import BaseModel, ConfigDict
class MSELossParams(BaseModel):
"""
Class for MSE loss parameters, compatible with `nn.MSELoss`.
"""
model_config = ConfigDict(frozen=True)
reduction: Literal["none", "mean", "sum"] = "mean"
def asdict(self) -> Dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.MSELoss`.
Returns:
Dict[str, Any]: Dictionary of parameters for `nn.MSELoss`.
"""
loss_kwargs = self.model_dump()
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
class MSELoss(BaseLoss):
"""
Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics.
"""
def __init__(self, mse_params: Optional[MSELossParams] = None):
"""
Initializes the loss function with optional MSELoss parameters.
Args:
mse_params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None).
"""
super().__init__()
_mse_params = mse_params.asdict() if mse_params is not None else {}
# Initialize MSE loss with user-provided parameters or PyTorch defaults
self.mse_loss = nn.MSELoss(**_mse_params)
# Using CumulativeAverage from MONAI to track loss metrics
self.loss_mse_metric = CumulativeAverage()
def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes the loss between true values and predictions.
Args:
outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W).
Returns:
torch.Tensor: The total loss value.
"""
# Ensure target is on the same device as outputs
assert (
target.device == outputs.device
), (
"Target tensor must be moved to the same device as outputs "
"before calling forward()."
)
loss = self.mse_loss(outputs, target)
self.loss_mse_metric.append(loss.item())
return loss
def get_loss_metrics(self) -> Dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the average MSE loss.
"""
return {
"loss": round(self.loss_mse_metric.aggregate().item(), 4),
}
def reset_metrics(self):
"""Resets the stored loss metrics."""
self.loss_mse_metric.reset()

@@ -0,0 +1,105 @@
from .base import *
from .bce import BCELossParams
from .mse import MSELossParams
class BCE_MSE_Loss(BaseLoss):
"""
Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction.
"""
def __init__(
self,
num_classes: int,
bce_params: Optional[BCELossParams] = None,
mse_params: Optional[MSELossParams] = None,
bce_with_logits: bool = False,
):
"""
Initializes the loss function with optional BCE and MSE parameters.
Args:
num_classes (int): Number of output classes, used for target shifting.
bce_params (Optional[BCELossParams]): Parameters for BCEWithLogitsLoss or BCELoss (default: None).
mse_params (Optional[MSELossParams]): Parameters for MSELoss (default: None).
bce_with_logits (bool): If True, uses BCEWithLogitsLoss; otherwise, uses BCELoss.
"""
super().__init__()
self.num_classes = num_classes
# Process BCE parameters
_bce_params = bce_params.asdict(bce_with_logits) if bce_params is not None else {}
# Choose BCE loss function
self.bce_loss = (
nn.BCEWithLogitsLoss(**_bce_params) if bce_with_logits else nn.BCELoss(**_bce_params)
)
# Process MSE parameters
_mse_params = mse_params.asdict() if mse_params is not None else {}
# Initialize MSE loss
self.mse_loss = nn.MSELoss(**_mse_params)
# Using CumulativeAverage from MONAI to track loss metrics
self.loss_bce_metric = CumulativeAverage()
self.loss_mse_metric = CumulativeAverage()
def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes the loss between true labels and prediction outputs.
Args:
outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W).
Returns:
torch.Tensor: The total loss value.
"""
# Ensure target is on the same device as outputs
assert (
target.device == outputs.device
), (
"Target tensor must be moved to the same device as outputs "
"before calling forward()."
)
# Cell Recognition Loss
cellprob_loss = self.bce_loss(
outputs[:, -self.num_classes:], target[:, self.num_classes:2 * self.num_classes].float()
)
# Cell Distinction Loss
gradflow_loss = 0.5 * self.mse_loss(
outputs[:, :2 * self.num_classes], 5.0 * target[:, 2 * self.num_classes:]
)
# Total loss
total_loss = cellprob_loss + gradflow_loss
# Track individual losses
self.loss_bce_metric.append(cellprob_loss.item())
self.loss_mse_metric.append(gradflow_loss.item())
return total_loss
def get_loss_metrics(self) -> Dict[str, float]:
"""
Retrieves the tracked loss metrics.
Returns:
Dict[str, float]: A dictionary containing the average BCE and MSE loss.
"""
return {
"bce_loss": round(self.loss_bce_metric.aggregate().item(), 4),
"mse_loss": round(self.loss_mse_metric.aggregate().item(), 4),
"loss": round(
self.loss_bce_metric.aggregate().item() + self.loss_mse_metric.aggregate().item(), 4
),
}
def reset_metrics(self):
"""Resets the stored loss metrics."""
self.loss_bce_metric.reset()
self.loss_mse_metric.reset()
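
A shape sketch of the combined loss. The channel layout below (3 * num_classes model channels, 4 * num_classes target channels) is inferred from the slicing in forward() above and should be checked against the actual data pipeline.

# Usage sketch (not part of the commit); channel counts are an inference, see note above.
import torch
from core.criteria import BCE_MSE_Loss

C = 3
criterion = BCE_MSE_Loss(num_classes=C, bce_with_logits=True)
outputs = torch.randn(2, 3 * C, 8, 8)   # [gradient flows (2C) | cell probabilities (C)]
target = torch.rand(2, 4 * C, 8, 8)     # channels C:2C -> cellprob targets, 2C: -> flow targets
loss = criterion.forward(outputs, target)
print(criterion.get_loss_metrics())     # bce_loss, mse_loss and their sum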

@@ -0,0 +1,85 @@
import torch.nn as nn
from typing import Dict, Final, Tuple, Type, Any, List, Union
from pydantic import BaseModel
from .model_v import ModelV, ModelVParams
__all__ = [
"ModelRegistry",
"ModelV",
"ModelVParams"
]
class ModelRegistry:
"""Registry for models and their parameter classes with case-insensitive lookup."""
# Single dictionary storing both model classes and parameter classes.
__MODELS: Final[Dict[str, Dict[str, Type[Any]]]] = {
"ModelV": {
"class": ModelV,
"params": ModelVParams,
},
}
@classmethod
def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
"""
Private method to retrieve the model entry from the registry using case-insensitive lookup.
Args:
name (str): The name of the model.
Returns:
Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
Raises:
ValueError: If the model is not found.
"""
name_lower = name.lower()
mapping = {key.lower(): key for key in cls.__MODELS}
original_key = mapping.get(name_lower)
if original_key is None:
raise ValueError(
f"Model '{name}' not found! Available options: {list(cls.__MODELS.keys())}"
)
return cls.__MODELS[original_key]
@classmethod
def get_model_class(cls, name: str) -> Type[nn.Module]:
"""
Retrieves the model class by name (case-insensitive).
Args:
name (str): Name of the model.
Returns:
Type[nn.Module]: The model class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@classmethod
def get_model_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
"""
Retrieves the model parameter class by name (case-insensitive).
Args:
name (str): Name of the model.
Returns:
Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The model parameter class or a tuple of parameter classes.
"""
entry = cls.__get_entry(name)
return entry["params"]
@classmethod
def get_available_models(cls) -> List[str]:
"""
Returns a list of available model names in their original case.
Returns:
List[str]: List of available model names.
"""
return list(cls.__MODELS.keys())

@@ -0,0 +1,117 @@
from typing import List, Optional
import torch
import torch.nn as nn
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation
from pydantic import BaseModel, ConfigDict
__all__ = ["ModelV"]
class ModelVParams(BaseModel):
model_config = ConfigDict(frozen=True)
encoder_name: str = "mit_b5" # Default encoder
encoder_weights: Optional[str] = "imagenet" # Pre-trained weights
decoder_channels: List[int] = [1024, 512, 256, 128, 64] # Decoder configuration
decoder_pab_channels: int = 256 # Decoder Pyramid Attention Block channels
in_channels: int = 3 # Number of input channels
out_classes: int = 3 # Number of output classes
def asdict(self):
"""
Returns a dictionary of valid constructor parameters for the `MAnet` base class.
Returns:
Dict[str, Any]: Dictionary of model constructor parameters.
"""
model_kwargs = self.model_dump()
return {k: v for k, v in model_kwargs.items() if v is not None} # Remove None values
class ModelV(MAnet):
"""ModelV model"""
def __init__(self, params: ModelVParams) -> None:
# Initialize the MAnet model with provided parameters
super().__init__(**params.asdict())
# Remove the default segmentation head as it's not used in this architecture
self.segmentation_head = None
# Modify all activation functions in the encoder and decoder from ReLU to Mish
_convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True))
_convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True))
# Add custom segmentation heads for different segmentation tasks
self.cellprob_head = DeepSegmentationHead(
in_channels=params.decoder_channels[-1], out_channels=params.out_classes
)
self.gradflow_head = DeepSegmentationHead(
in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the network"""
# Ensure the input shape is correct
self.check_input_shape(x)
# Encode the input and then decode it
features = self.encoder(x)
decoder_output = self.decoder(*features)
# Generate masks for cell probability and gradient flows
cellprob_mask = self.cellprob_head(decoder_output)
gradflow_mask = self.gradflow_head(decoder_output)
# Concatenate the masks for output
masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)
return masks
class DeepSegmentationHead(nn.Sequential):
"""Custom segmentation head for generating specific masks"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
activation: Optional[str] = None,
upsampling: int = 1,
) -> None:
# Define a sequence of layers for the segmentation head
layers: List[nn.Module] = [
nn.Conv2d(
in_channels,
in_channels // 2,
kernel_size=kernel_size,
padding=kernel_size // 2,
),
nn.Mish(inplace=True),
nn.BatchNorm2d(in_channels // 2),
nn.Conv2d(
in_channels // 2,
out_channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
),
nn.UpsamplingBilinear2d(scale_factor=upsampling)
if upsampling > 1
else nn.Identity(),
Activation(activation) if activation else nn.Identity(),
]
super().__init__(*layers)
def _convert_activations(module: nn.Module, from_activation: type, to_activation: nn.Module) -> None:
"""Recursively convert activation functions in a module"""
for name, child in module.named_children():
if isinstance(child, from_activation):
setattr(module, name, to_activation)
else:
_convert_activations(child, from_activation, to_activation)
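
A standalone smoke test of the two helpers above (the module path core.models.model_v is inferred from the imports in this commit); instantiating ModelV itself additionally pulls in the segmentation_models_pytorch encoder weights.

# Usage sketch (not part of the commit).
import torch
import torch.nn as nn
from core.models.model_v import DeepSegmentationHead, _convert_activations

head = DeepSegmentationHead(in_channels=64, out_channels=3)
print(head(torch.randn(1, 64, 32, 32)).shape)   # torch.Size([1, 3, 32, 32])

block = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
_convert_activations(block, nn.ReLU, nn.Mish(inplace=True))
print(block)                                     # ReLU replaced by Mish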

@@ -0,0 +1,92 @@
import torch.optim as optim
from pydantic import BaseModel
from typing import Dict, Final, Tuple, Type, List, Any, Union
from .adam import AdamParams
from .adamw import AdamWParams
from .sgd import SGDParams
__all__ = [
"OptimizerRegistry",
"AdamParams", "AdamWParams", "SGDParams"
]
class OptimizerRegistry:
"""Registry for optimizers and their parameter classes with case-insensitive lookup."""
# Single dictionary storing both optimizer classes and parameter classes.
__OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
"SGD": {
"class": optim.SGD,
"params": SGDParams,
},
"Adam": {
"class": optim.Adam,
"params": AdamParams,
},
"AdamW": {
"class": optim.AdamW,
"params": AdamWParams,
},
}
@classmethod
def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
"""
Private method to retrieve the optimizer entry from the registry using case-insensitive lookup.
Args:
name (str): The name of the optimizer.
Returns:
Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
Raises:
ValueError: If the optimizer is not found.
"""
name_lower = name.lower()
mapping = {key.lower(): key for key in cls.__OPTIMIZERS}
original_key = mapping.get(name_lower)
if original_key is None:
raise ValueError(
f"Optimizer '{name}' not found! Available options: {list(cls.__OPTIMIZERS.keys())}"
)
return cls.__OPTIMIZERS[original_key]
@classmethod
def get_optimizer_class(cls, name: str) -> Type[optim.Optimizer]:
"""
Retrieves the optimizer class by name (case-insensitive).
Args:
name (str): Name of the optimizer.
Returns:
Type[optim.Optimizer]: The optimizer class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@classmethod
def get_optimizer_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
"""
Retrieves the optimizer parameter class by name (case-insensitive).
Args:
name (str): Name of the optimizer.
Returns:
Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The optimizer parameter class or a tuple of parameter classes.
"""
entry = cls.__get_entry(name)
return entry["params"]
@classmethod
def get_available_optimizers(cls) -> List[str]:
"""
Returns a list of available optimizer names in their original case.
Returns:
List[str]: List of available optimizer names.
"""
return list(cls.__OPTIMIZERS.keys())
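
A sketch of how the class/params pair is meant to be combined; the tiny nn.Linear model is only a stand-in.

# Usage sketch (not part of the commit).
import torch.nn as nn
from core import OptimizerRegistry

model = nn.Linear(4, 2)                                    # stand-in module
opt_cls = OptimizerRegistry.get_optimizer_class("adamw")   # case-insensitive
opt_params = OptimizerRegistry.get_optimizer_params("adamw")()
optimizer = opt_cls(model.parameters(), **opt_params.asdict())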

@@ -0,0 +1,18 @@
import torch
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict
class AdamParams(BaseModel):
"""Configuration for `torch.optim.Adam` optimizer."""
model_config = ConfigDict(frozen=True)
lr: float = 1e-3 # Learning rate
betas: Tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages
eps: float = 1e-8 # Term added to denominator for numerical stability
weight_decay: float = 0.0 # L2 regularization
amsgrad: bool = False # Whether to use the AMSGrad variant
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.Adam`."""
return self.model_dump()

@@ -0,0 +1,17 @@
import torch
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict
class AdamWParams(BaseModel):
"""Configuration for `torch.optim.AdamW` optimizer."""
model_config = ConfigDict(frozen=True)
lr: float = 1e-3 # Learning rate
betas: Tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages
eps: float = 1e-8 # Numerical stability
weight_decay: float = 1e-2 # L2 penalty (AdamW uses decoupled weight decay)
amsgrad: bool = False # Whether to use the AMSGrad variant
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.AdamW`."""
return self.model_dump()

@@ -0,0 +1,18 @@
import torch
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
class SGDParams(BaseModel):
"""Configuration for `torch.optim.SGD` optimizer."""
model_config = ConfigDict(frozen=True)
lr: float = 1e-3 # Learning rate
momentum: float = 0.0 # Momentum factor
dampening: float = 0.0 # Dampening for momentum
weight_decay: float = 0.0 # L2 penalty
nesterov: bool = False # Enables Nesterov momentum
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.SGD`."""
return self.model_dump()

@@ -0,0 +1,96 @@
import torch.optim.lr_scheduler as lr_scheduler
from typing import Dict, Final, Tuple, Type, List, Any, Union
from pydantic import BaseModel
from .step import StepLRParams
from .multi_step import MultiStepLRParams
from .exponential import ExponentialLRParams
from .cosine_annealing import CosineAnnealingLRParams
__all__ = [
"SchedulerRegistry",
"StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams"
]
class SchedulerRegistry:
"""Registry for learning rate schedulers and their parameter classes with case-insensitive lookup."""
__SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = {
"Step": {
"class": lr_scheduler.StepLR,
"params": StepLRParams,
},
"Exponential": {
"class": lr_scheduler.ExponentialLR,
"params": ExponentialLRParams,
},
"MultiStep": {
"class": lr_scheduler.MultiStepLR,
"params": MultiStepLRParams,
},
"CosineAnnealing": {
"class": lr_scheduler.CosineAnnealingLR,
"params": CosineAnnealingLRParams,
},
}
@classmethod
def __get_entry(cls, name: str) -> Dict[str, Type[Any]]:
"""
Private method to retrieve the scheduler entry from the registry using case-insensitive lookup.
Args:
name (str): The name of the scheduler.
Returns:
Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'.
Raises:
ValueError: If the scheduler is not found.
"""
name_lower = name.lower()
mapping = {key.lower(): key for key in cls.__SCHEDULERS}
original_key = mapping.get(name_lower)
if original_key is None:
raise ValueError(
f"Scheduler '{name}' not found! Available options: {list(cls.__SCHEDULERS.keys())}"
)
return cls.__SCHEDULERS[original_key]
@classmethod
def get_scheduler_class(cls, name: str) -> Type[lr_scheduler.LRScheduler]:
"""
Retrieves the scheduler class by name (case-insensitive).
Args:
name (str): Name of the scheduler.
Returns:
Type[lr_scheduler.LRScheduler]: The scheduler class.
"""
entry = cls.__get_entry(name)
return entry["class"]
@classmethod
def get_scheduler_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]:
"""
Retrieves the scheduler parameter class by name (case-insensitive).
Args:
name (str): Name of the scheduler.
Returns:
Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The scheduler parameter class or a tuple of parameter classes.
"""
entry = cls.__get_entry(name)
return entry["params"]
@classmethod
def get_available_schedulers(cls) -> List[str]:
"""
Returns a list of available scheduler names in their original case.
Returns:
List[str]: List of available scheduler names.
"""
return list(cls.__SCHEDULERS.keys())
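
And the matching sketch for schedulers, pairing a registry entry with an existing optimizer.

# Usage sketch (not part of the commit).
import torch.nn as nn
import torch.optim as optim
from core import SchedulerRegistry

optimizer = optim.SGD(nn.Linear(4, 2).parameters(), lr=0.1)
sched_cls = SchedulerRegistry.get_scheduler_class("cosineannealing")   # case-insensitive
sched_params = SchedulerRegistry.get_scheduler_params("cosineannealing")()
scheduler = sched_cls(optimizer, **sched_params.asdict())              # CosineAnnealingLR(T_max=100, eta_min=0.0)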

@@ -0,0 +1,14 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
class CosineAnnealingLRParams(BaseModel):
"""Configuration for `torch.optim.lr_scheduler.CosineAnnealingLR`."""
model_config = ConfigDict(frozen=True)
T_max: int = 100 # Maximum number of iterations
eta_min: float = 0.0 # Minimum learning rate
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`."""
return self.model_dump()

@@ -0,0 +1,13 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
class ExponentialLRParams(BaseModel):
"""Configuration for `torch.optim.lr_scheduler.ExponentialLR`."""
model_config = ConfigDict(frozen=True)
gamma: float = 0.95 # Multiplicative factor of learning rate decay
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`."""
return self.model_dump()

@@ -0,0 +1,14 @@
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict
class MultiStepLRParams(BaseModel):
"""Configuration for `torch.optim.lr_scheduler.MultiStepLR`."""
model_config = ConfigDict(frozen=True)
milestones: Tuple[int, ...] = (30, 80) # Epoch indices at which the learning rate decays
gamma: float = 0.1 # Multiplicative factor of learning rate decay
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`."""
return self.model_dump()

@@ -0,0 +1,14 @@
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict
class StepLRParams(BaseModel):
"""Configuration for `torch.optim.lr_scheduler.StepLR`."""
model_config = ConfigDict(frozen=True)
step_size: int = 30 # Period of learning rate decay
gamma: float = 0.1 # Multiplicative factor of learning rate decay
def asdict(self) -> Dict[str, Any]:
"""Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`."""
return self.model_dump()

@@ -0,0 +1,120 @@
import os
from pydantic import BaseModel
from typing import Any, Dict, Type, Union, List
from config.config import Config
from config.dataset_config import DatasetConfig
from core import (
ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
)
def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]:
"""
Instantiates the parameter class(es) with default values.
If 'param' is a tuple, instantiate each class and return a list of instances.
Otherwise, instantiate the single class and return the instance.
"""
if isinstance(param, tuple):
return [cls() for cls in param]
else:
return param()
def prompt_choice(prompt_message: str, options: List[str]) -> str:
"""
Prompt the user with a list of options and return the selected option.
"""
print(prompt_message)
for i, option in enumerate(options, start=1):
print(f"{i}. {option}")
while True:
try:
choice = int(input("Enter your choice (number): "))
if 1 <= choice <= len(options):
return options[choice - 1]
else:
print("Invalid choice. Please try again.")
except ValueError:
print("Please enter a valid number.")
def main():
# Determine the directory of this script.
script_path = os.path.dirname(os.path.abspath(__file__))
# Ask the user whether this is training mode.
training_input = input("Is this training mode? (y/n): ").strip().lower()
is_training = training_input in ("y", "yes")
# Create a default DatasetConfig based on the training mode.
# Nested sections (common/training/testing) are instantiated with their default values.
dataset_config = DatasetConfig(is_training=is_training)
# Prompt the user to select a model.
model_options = ModelRegistry.get_available_models()
chosen_model = prompt_choice("\nSelect a model:", model_options)
model_param_class = ModelRegistry.get_model_params(chosen_model)
model_instance = instantiate_params(model_param_class)
if is_training is False:
config = Config(
model={chosen_model: model_instance},
dataset_config=dataset_config
)
# Construct a base filename from the selected registry names.
base_filename = f"{chosen_model}"
else:
# Prompt the user to select a criterion.
criterion_options = CriterionRegistry.get_available_criterions()
chosen_criterion = prompt_choice("\nSelect a criterion:", criterion_options)
criterion_param_class = CriterionRegistry.get_criterion_params(chosen_criterion)
criterion_instance = instantiate_params(criterion_param_class)
# Prompt the user to select an optimizer.
optimizer_options = OptimizerRegistry.get_available_optimizers()
chosen_optimizer = prompt_choice("\nSelect an optimizer:", optimizer_options)
optimizer_param_class = OptimizerRegistry.get_optimizer_params(chosen_optimizer)
optimizer_instance = instantiate_params(optimizer_param_class)
# Prompt the user to select a scheduler.
scheduler_options = SchedulerRegistry.get_available_schedulers()
chosen_scheduler = prompt_choice("\nSelect a scheduler:", scheduler_options)
scheduler_param_class = SchedulerRegistry.get_scheduler_params(chosen_scheduler)
scheduler_instance = instantiate_params(scheduler_param_class)
# Assemble the overall configuration using the registry names as keys.
config = Config(
model={chosen_model: model_instance},
dataset_config=dataset_config,
criterion={chosen_criterion: criterion_instance},
optimizer={chosen_optimizer: optimizer_instance},
scheduler={chosen_scheduler: scheduler_instance}
)
# Construct a base filename from the selected registry names.
base_filename = f"{chosen_model}_{chosen_criterion}_{chosen_optimizer}_{chosen_scheduler}"
# Determine the output directory relative to this script.
base_dir = os.path.join(script_path, "config/jsons", "train" if is_training else "predict")
os.makedirs(base_dir, exist_ok=True)
filename = f"{base_filename}.json"
full_path = os.path.join(base_dir, filename)
counter = 1
# Append a counter if a file with the same name exists.
while os.path.exists(full_path):
filename = f"{base_filename}_{counter}.json"
full_path = os.path.join(base_dir, filename)
counter += 1
# Save the configuration as a JSON file.
config.save_json(full_path)
print(f"\nConfiguration saved to: {full_path}")
if __name__ == "__main__":
main()

@@ -0,0 +1,14 @@
from config.config import Config
from pprint import pprint
config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
pprint(config, indent=4)
print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV.json')
pprint(config, indent=4)
print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV_1.json')
pprint(config, indent=4)