commit e00331b503ae7983d74a1aaba4f41a4dfdb7286f Author: laynholt Date: Wed Mar 26 16:53:46 2025 +0000 definitions of basic classes, as well as the ability to create config files diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9d851fe --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*.pyc +__pycache__/ +**/__pycache__/ + +.vscode/ +*.json \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..d890918 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,3 @@ +from .config import Config + +__all__ = ["Config"] \ No newline at end of file diff --git a/config/config.py b/config/config.py new file mode 100644 index 0000000..6971ea6 --- /dev/null +++ b/config/config.py @@ -0,0 +1,96 @@ +import json +from typing import Any, Dict, List, Optional, Union +from pydantic import BaseModel + +from .dataset_config import DatasetConfig + +class Config(BaseModel): + model: Dict[str, Union[BaseModel, List[BaseModel]]] + dataset_config: DatasetConfig + criterion: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None + optimizer: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None + scheduler: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None + + @staticmethod + def __dump_field(value: Any) -> Any: + """ + Recursively dumps a field if it is a BaseModel or a list/dict of BaseModels. + """ + if isinstance(value, BaseModel): + return value.model_dump() + elif isinstance(value, list): + return [Config.__dump_field(item) for item in value] + elif isinstance(value, dict): + return {k: Config.__dump_field(v) for k, v in value.items()} + else: + return value + + def save_json(self, file_path: str, indent: int = 4) -> None: + """ + Saves the configuration to a JSON file using dumps of each individual field. + + Args: + file_path (str): Destination path for the JSON file. + indent (int): Indentation level for the JSON file. + """ + config_dump = { + "model": self.__dump_field(self.model), + "dataset_config": self.dataset_config.model_dump() + } + if self.criterion is not None: + config_dump.update({"criterion": self.__dump_field(self.criterion)}) + if self.optimizer is not None: + config_dump.update({"optimizer": self.__dump_field(self.optimizer)}) + if self.scheduler is not None: + config_dump.update({"scheduler": self.__dump_field(self.scheduler)}) + + with open(file_path, "w", encoding="utf-8") as f: + f.write(json.dumps(config_dump, indent=indent)) + + + @classmethod + def load_json(cls, file_path: str) -> "Config": + """ + Loads a configuration from a JSON file and re-instantiates each section using + the registry keys to recover the original parameter class(es). + + Args: + file_path (str): Path to the JSON file. + + Returns: + Config: An instance of Config with the proper parameter classes. + """ + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + # Parse dataset_config using its Pydantic model. + dataset_config = DatasetConfig(**data.get("dataset_config", {})) + + # Helper function to parse registry fields. + def parse_field(field_data: Dict[str, Any], registry_getter) -> Dict[str, Union[BaseModel, List[BaseModel]]]: + result = {} + for key, value in field_data.items(): + expected = registry_getter(key) + # If the registry returns a tuple, then we expect a list of dictionaries. 
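+                # Example: CriterionRegistry.get_criterion_params("BCE_MSE_Loss") returns the tuple
+                # (BCELossParams, MSELossParams), so the stored JSON value must be a list of two
+                # dicts; a single params class is re-instantiated from a plain dict of kwargs.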
+ if isinstance(expected, tuple): + result[key] = [cls_param(**item) for cls_param, item in zip(expected, value)] + else: + result[key] = expected(**value) + return result + + from core import ( + ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry + ) + + parsed_model = parse_field(data.get("model", {}), lambda key: ModelRegistry.get_model_params(key)) + parsed_criterion = parse_field(data.get("criterion", {}), lambda key: CriterionRegistry.get_criterion_params(key)) + parsed_optimizer = parse_field(data.get("optimizer", {}), lambda key: OptimizerRegistry.get_optimizer_params(key)) + parsed_scheduler = parse_field(data.get("scheduler", {}), lambda key: SchedulerRegistry.get_scheduler_params(key)) + + return cls( + model=parsed_model, + dataset_config=dataset_config, + criterion=parsed_criterion, + optimizer=parsed_optimizer, + scheduler=parsed_scheduler + ) diff --git a/config/dataset_config.py b/config/dataset_config.py new file mode 100644 index 0000000..ccbfdfa --- /dev/null +++ b/config/dataset_config.py @@ -0,0 +1,246 @@ +from pydantic import BaseModel, model_validator, field_validator +from typing import Any, Dict, Optional, Union +import os + + +class DatasetCommonConfig(BaseModel): + """ + Common configuration fields shared by both training and testing. + """ + device: str = "cuda0" # Device used for training/testing (e.g., 'cpu' or 'cuda') + use_tta: bool = False # Flag to use Test-Time Augmentation (TTA) + predictions_dir: str = "." # Directory to save predictions + + @model_validator(mode="after") + def validate_common(self) -> "DatasetCommonConfig": + """ + Validates that device is non-empty. + """ + if not self.device: + raise ValueError("device must be provided and non-empty") + return self + + +class TrainingSplitInfo(BaseModel): + """ + Configuration for training mode when data is NOT pre-split (is_split is False). + Contains: + - split_seed: Seed for splitting. + - all_data_dir: Directory containing all data. + """ + split_seed: int = 0 # Seed for splitting if data is not pre-split + all_data_dir: str = "." # Directory containing all data if not pre-split + +class TrainingPreSplitInfo(BaseModel): + """ + Configuration for training mode when data is pre-split (is_split is True). + Contains: + - train_dir, valid_dir, test_dir: Directories for training, validation, and testing data. + - train_size, valid_size, test_size: Data split ratios or counts. + - train_offset, valid_offset, test_offset: Offsets for respective splits. + """ + train_dir: str = "." # Directory for training data if data is pre-split + valid_dir: str = "" # Directory for validation data if data is pre-split + test_dir: str = "" # Directory for testing data if data is pre-split + train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic) + valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic) + test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic) + train_offset: int = 0 # Offset for training data + valid_offset: int = 0 # Offset for validation data + test_offset: int = 0 # Offset for testing data + + + @field_validator("train_size", "valid_size", "test_size", mode="before") + def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]: + """ + Validates size values: + - If provided as a float, must be in the range (0, 1]. + - If provided as an int, must be non-negative. 
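+        Example: 0.7 keeps 70% of the samples, while 500 keeps exactly 500 samples.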
+ """ + if isinstance(v, float): + if not (0 <= v <= 1): + raise ValueError("When provided as a float, size must be in the range (0, 1]") + elif isinstance(v, int): + if v < 0: + raise ValueError("When provided as an int, size must be non-negative") + else: + raise ValueError("Size must be either an int or a float") + return v + + +class DatasetTrainingConfig(BaseModel): + """ + Main training configuration. + Contains: + - is_split: Determines whether data is pre-split. + - pre_split: Configuration for when data is NOT pre-split. + - split: Configuration for when data is pre-split. + - Other training parameters: batch_size, num_epochs, val_freq, use_amp, pretrained_weights. + + Both pre_split and split objects are always created, but only the one corresponding + to is_split is validated. + """ + is_split: bool = False # Whether the data is already split into train/validation sets + pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo() + split: TrainingSplitInfo = TrainingSplitInfo() + + batch_size: int = 1 # Batch size for training + num_epochs: int = 100 # Number of training epochs + val_freq: int = 1 # Frequency of validation during training + use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP) + pretrained_weights: str = "" # Path to pretrained weights for training + + @model_validator(mode="after") + def validate_split_info(self) -> "DatasetTrainingConfig": + """ + Conditionally validates the nested split objects: + - If is_split is True, validates pre_split (train_dir must be non-empty and exist; if provided, valid_dir and test_dir must exist). + - If is_split is False, validates split (all_data_dir must be non-empty and exist). + """ + if not self.is_split: + if not self.split.all_data_dir: + raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in pre_split") + if not os.path.exists(self.split.all_data_dir): + raise ValueError(f"Path for all_data_dir does not exist: {self.split.all_data_dir}") + else: + if not self.pre_split.train_dir: + raise ValueError("When is_split is True, train_dir must be provided and non-empty in split") + if not os.path.exists(self.pre_split.train_dir): + raise ValueError(f"Path for train_dir does not exist: {self.pre_split.train_dir}") + if self.pre_split.valid_dir and not os.path.exists(self.pre_split.valid_dir): + raise ValueError(f"Path for valid_dir does not exist: {self.pre_split.valid_dir}") + if self.pre_split.test_dir and not os.path.exists(self.pre_split.test_dir): + raise ValueError(f"Path for test_dir does not exist: {self.pre_split.test_dir}") + return self + + @model_validator(mode="after") + def validate_numeric_fields(self) -> "DatasetTrainingConfig": + """ + Validates numeric fields: + - batch_size and num_epochs must be > 0. + - val_freq must be >= 0. + """ + if self.batch_size <= 0: + raise ValueError("batch_size must be > 0") + if self.num_epochs <= 0: + raise ValueError("num_epochs must be > 0") + if self.val_freq < 0: + raise ValueError("val_freq must be >= 0") + return self + + @model_validator(mode="after") + def validate_pretrained(self) -> "DatasetTrainingConfig": + """ + Validates that pretrained_weights is provided and exists. + """ + if self.pretrained_weights and not os.path.exists(self.pretrained_weights): + raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}") + return self + + +class DatasetTestingConfig(BaseModel): + """ + Configuration fields used only in testing mode. + """ + test_dir: str = "." 
# Test data directory; must be non-empty + test_size: Union[int, float] = 1.0 # Testing data size (int for static, float in (0,1] for dynamic) + test_offset: int = 0 # Offset for testing data + use_ensemble: bool = False # Flag to use ensemble mode in testing + ensemble_pretrained_weights1: str = "." + ensemble_pretrained_weights2: str = "." + pretrained_weights: str = "." + + @field_validator("test_size", mode="before") + def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]: + """ + Validates the test_size value. + """ + if isinstance(v, float): + if not (0 < v <= 1): + raise ValueError("When provided as a float, test_size must be in the range (0, 1]") + elif isinstance(v, int): + if v < 0: + raise ValueError("When provided as an int, test_size must be non-negative") + else: + raise ValueError("test_size must be either an int or a float") + return v + + @model_validator(mode="after") + def validate_testing(self) -> "DatasetTestingConfig": + """ + Validates the testing configuration: + - test_dir must be non-empty and exist. + - If use_ensemble is True, both ensemble_pretrained_weights1 and ensemble_pretrained_weights2 must be provided and exist. + - If use_ensemble is False, pretrained_weights must be provided and exist. + """ + if not self.test_dir: + raise ValueError("In testing configuration, test_dir must be provided and non-empty") + if not os.path.exists(self.test_dir): + raise ValueError(f"Path for test_dir does not exist: {self.test_dir}") + if self.use_ensemble: + for field in ["ensemble_pretrained_weights1", "ensemble_pretrained_weights2"]: + value = getattr(self, field) + if not value: + raise ValueError(f"When use_ensemble is True, {field} must be provided and non-empty") + if not os.path.exists(value): + raise ValueError(f"Path for {field} does not exist: {value}") + else: + if not self.pretrained_weights: + raise ValueError("When use_ensemble is False, pretrained_weights must be provided and non-empty") + if not os.path.exists(self.pretrained_weights): + raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}") + if self.test_offset < 0: + raise ValueError("test_offset must be >= 0") + return self + + +class DatasetConfig(BaseModel): + """ + Main dataset configuration that groups fields into nested models for a structured and readable JSON. 
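+
+    The common block always applies; training is used when is_training is True and
+    testing when it is False (model_dump below drops the unused section).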
+ """ + is_training: bool = True # Flag indicating whether the configuration is for training (True) or testing (False) + common: DatasetCommonConfig = DatasetCommonConfig() + training: DatasetTrainingConfig = DatasetTrainingConfig() + testing: DatasetTestingConfig = DatasetTestingConfig() + + @model_validator(mode="after") + def validate_config(self) -> "DatasetConfig": + """ + Validates the overall dataset configuration: + """ + if self.is_training: + if self.training is None: + raise ValueError("Training configuration must be provided when is_training is True") + # Check predictions_dir if training.split.test_dir and test_size are set + if self.training.pre_split.test_dir and self.training.pre_split.test_size > 0: + if not self.common.predictions_dir: + raise ValueError("predictions_dir must be provided when training.split.test_dir and test_size are non-zero") + if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir): + raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}") + else: + if self.testing is None: + raise ValueError("Testing configuration must be provided when is_training is False") + if self.testing.test_dir and self.testing.test_size > 0: + if not self.common.predictions_dir: + raise ValueError("predictions_dir must be provided when testing.test_dir and test_size are non-zero") + if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir): + raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}") + return self + + def model_dump(self, **kwargs) -> Dict[str, Any]: + """ + Dumps only the relevant configuration depending on the is_training flag. + Only the nested configuration (training or testing) along with common fields is returned. 
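+
+        For example, when is_training is True the result has the shape
+        {"is_training": True, "common": {...}, "training": {...}}; the testing
+        section is omitted (and vice versa in testing mode).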
+ """ + if self.is_training: + return { + "is_training": self.is_training, + "common": self.common.model_dump(), + "training": self.training.model_dump() if self.training else {} + } + else: + return { + "is_training": self.is_training, + "common": self.common.model_dump(), + "testing": self.testing.model_dump() if self.testing else {} + } diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..1c28630 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,7 @@ +from .models import ModelRegistry +from .criteria import CriterionRegistry +from .optimizers import OptimizerRegistry +from .schedulers import SchedulerRegistry + + +__all__ = ["ModelRegistry", "CriterionRegistry", "OptimizerRegistry", "SchedulerRegistry"] diff --git a/core/criteria/__init__.py b/core/criteria/__init__.py new file mode 100644 index 0000000..86c32c7 --- /dev/null +++ b/core/criteria/__init__.py @@ -0,0 +1,97 @@ +from typing import Dict, Final, Tuple, Type, List, Any, Union +from pydantic import BaseModel + +from .base import BaseLoss +from .ce import CrossEntropyLoss, CrossEntropyLossParams +from .bce import BCELoss, BCELossParams +from .mse import MSELoss, MSELossParams +from .mse_with_bce import BCE_MSE_Loss + +__all__ = [ + "CriterionRegistry", + "CrossEntropyLoss", "BCELoss", "MSELoss", "BCE_MSE_Loss", + "CrossEntropyLossParams", "BCELossParams", "MSELossParams" +] + +class CriterionRegistry: + """Registry of loss functions and their parameter classes with case-insensitive lookup.""" + + __CRITERIONS: Final[Dict[str, Dict[str, Any]]] = { + "CrossEntropyLoss": { + "class": CrossEntropyLoss, + "params": CrossEntropyLossParams, + }, + "BCELoss": { + "class": BCELoss, + "params": BCELossParams, + }, + "MSELoss": { + "class": MSELoss, + "params": MSELossParams, + }, + "BCE_MSE_Loss": { + "class": BCE_MSE_Loss, + "params": (BCELossParams, MSELossParams), + }, + } + + @classmethod + def __get_entry(cls, name: str) -> Dict[str, Any]: + """ + Private method to retrieve the criterion entry from the registry using case-insensitive lookup. + + Args: + name (str): The name of the loss function. + + Returns: + Dict[str, Any]: A dictionary containing the keys 'class' and 'params'. + + Raises: + ValueError: If the loss function is not found. + """ + name_lower = name.lower() + mapping = {key.lower(): key for key in cls.__CRITERIONS} + original_key = mapping.get(name_lower) + if original_key is None: + raise ValueError( + f"Criterion '{name}' not found! Available options: {list(cls.__CRITERIONS.keys())}" + ) + return cls.__CRITERIONS[original_key] + + @classmethod + def get_criterion_class(cls, name: str) -> Type[BaseLoss]: + """ + Retrieves the loss function class by name (case-insensitive). + + Args: + name (str): Name of the loss function. + + Returns: + Type[BaseLoss]: The loss function class. + """ + entry = cls.__get_entry(name) + return entry["class"] + + @classmethod + def get_criterion_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]: + """ + Retrieves the loss function parameter class (or classes) by name (case-insensitive). + + Args: + name (str): Name of the loss function. + + Returns: + Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The loss function parameter class or a tuple of parameter classes. + """ + entry = cls.__get_entry(name) + return entry["params"] + + @classmethod + def get_available_criterions(cls) -> List[str]: + """ + Returns a list of available loss function names in their original case. + + Returns: + List[str]: List of available loss function names. 
+ """ + return list(cls.__CRITERIONS.keys()) diff --git a/core/criteria/base.py b/core/criteria/base.py new file mode 100644 index 0000000..4d8196f --- /dev/null +++ b/core/criteria/base.py @@ -0,0 +1,44 @@ +import abc +import torch +import torch.nn as nn +from typing import Dict, Any, Optional +from monai.metrics.cumulative_average import CumulativeAverage + + +class BaseLoss(abc.ABC): + """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction.""" + + def __init__(self): + """ + """ + super().__init__() + + + @abc.abstractmethod + def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Computes the loss between true labels and prediction outputs. + + Args: + outputs (torch.Tensor): Model predictions. + target (torch.Tensor): Ground truth. + + Returns: + torch.Tensor: The total loss value. + """ + + + @abc.abstractmethod + def get_loss_metrics(self) -> Dict[str, float]: + """ + Retrieves the tracked loss metrics. + + Returns: + Dict[str, float]: A dictionary containing the loss name and average loss value. + """ + + + @abc.abstractmethod + def reset_metrics(self): + """Resets the stored loss metrics.""" + \ No newline at end of file diff --git a/core/criteria/bce.py b/core/criteria/bce.py new file mode 100644 index 0000000..8dd2a62 --- /dev/null +++ b/core/criteria/bce.py @@ -0,0 +1,98 @@ +from .base import * +from typing import Literal +from pydantic import BaseModel, ConfigDict + + +class BCELossParams(BaseModel): + """ + Class for handling parameters for both `nn.BCELoss` and `nn.BCEWithLogitsLoss`. + """ + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + weight: Optional[torch.Tensor] = None # Sample weights + reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method + pos_weight: Optional[torch.Tensor] = None # Used only for BCEWithLogitsLoss + + def asdict(self, with_logits: bool = False) -> Dict[str, Any]: + """ + Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` and `nn.BCELoss`. + + - If `with_logits=False`, `pos_weight` is **removed** to avoid errors. + - Ensures only the valid parameters are passed based on the loss function. + + Args: + with_logits (bool): If `True`, includes `pos_weight` (for `nn.BCEWithLogitsLoss`). + If `False`, removes `pos_weight` (for `nn.BCELoss`). + + Returns: + Dict[str, Any]: Filtered dictionary of parameters. + """ + loss_kwargs = self.model_dump() + if not with_logits: + loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss + + return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values + + +class BCELoss(BaseLoss): + """ + Custom loss function wrapper for `nn.BCELoss and nn.BCEWithLogitsLoss` with tracking of loss metrics. + """ + + def __init__(self, bce_params: Optional[BCELossParams] = None, with_logits: bool = False): + """ + Initializes the loss function with optional BCELoss parameters. + + Args: + bce_params (Optional[Dict[str, Any]]): Parameters for nn.BCELoss (default: None). 
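+            with_logits (bool): If True, nn.BCEWithLogitsLoss is used and pos_weight is kept;
+                otherwise nn.BCELoss is used and pos_weight is dropped (default: False).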
+ """ + super().__init__() + _bce_params = bce_params.asdict(with_logits=with_logits) if bce_params is not None else {} + + # Initialize loss functions with user-provided parameters or PyTorch defaults + self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params) + + # Using CumulativeAverage from MONAI to track loss metrics + self.loss_bce_metric = CumulativeAverage() + + + def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Computes the loss between true labels and prediction outputs. + + Args: + outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W). + target (torch.Tensor): Ground truth labels in one-hot format. + + Returns: + torch.Tensor: The total loss value. + """ + # Ensure target is on the same device as outputs + assert ( + target.device == outputs.device + ), ( + "Target tensor must be moved to the same device as outputs " + "before calling forward()." + ) + + loss = self.bce_loss(outputs, target) + self.loss_bce_metric.append(loss.item()) + + return loss + + + def get_loss_metrics(self) -> Dict[str, float]: + """ + Retrieves the tracked loss metrics. + + Returns: + Dict[str, float]: A dictionary containing the average BCE loss. + """ + return { + "loss": round(self.loss_bce_metric.aggregate().item(), 4), + } + + + def reset_metrics(self): + """Resets the stored loss metrics.""" + self.loss_bce_metric.reset() diff --git a/core/criteria/ce.py b/core/criteria/ce.py new file mode 100644 index 0000000..fe4a086 --- /dev/null +++ b/core/criteria/ce.py @@ -0,0 +1,90 @@ +from .base import * +from typing import Literal +from pydantic import BaseModel, ConfigDict + + +class CrossEntropyLossParams(BaseModel): + """ + Class for handling parameters for `nn.CrossEntropyLoss`. + """ + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + weight: Optional[torch.Tensor] = None + ignore_index: int = -100 + reduction: Literal["none", "mean", "sum"] = "mean" + label_smoothing: float = 0.0 + + def asdict(self): + """ + Returns a dictionary of valid parameters for `nn.CrossEntropyLoss`. + + Returns: + Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss. + """ + loss_kwargs = self.model_dump() + return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values + + + +class CrossEntropyLoss(BaseLoss): + """ + Custom loss function wrapper for `nn.CrossEntropyLoss` with tracking of loss metrics. + """ + + def __init__(self, ce_params: Optional[CrossEntropyLossParams] = None): + """ + Initializes the loss function with optional CrossEntropyLoss parameters. + + Args: + ce_params (Optional[Dict[str, Any]]): Parameters for nn.CrossEntropyLoss (default: None). + """ + super().__init__() + _ce_params = ce_params.asdict() if ce_params is not None else {} + + # Initialize loss functions with user-provided parameters or PyTorch defaults + self.ce_loss = nn.CrossEntropyLoss(**_ce_params) + + # Using CumulativeAverage from MONAI to track loss metrics + self.loss_ce_metric = CumulativeAverage() + + + def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Computes the loss between true labels and prediction outputs. + + Args: + outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W). + target (torch.Tensor): Ground truth labels of shape (batch_size, H, W). + + Returns: + torch.Tensor: The total loss value. 
+ """ + # Ensure target is on the same device as outputs + assert ( + target.device == outputs.device + ), ( + "Target tensor must be moved to the same device as outputs " + "before calling forward()." + ) + + loss = self.ce_loss(outputs, target) + self.loss_ce_metric.append(loss.item()) + + return loss + + + def get_loss_metrics(self) -> Dict[str, float]: + """ + Retrieves the tracked loss metrics. + + Returns: + Dict[str, float]: A dictionary containing the average CrossEntropy loss. + """ + return { + "loss": round(self.loss_ce_metric.aggregate().item(), 4), + } + + + def reset_metrics(self): + """Resets the stored loss metrics.""" + self.loss_ce_metric.reset() diff --git a/core/criteria/mse.py b/core/criteria/mse.py new file mode 100644 index 0000000..df9fd13 --- /dev/null +++ b/core/criteria/mse.py @@ -0,0 +1,83 @@ +from .base import * +from typing import Literal +from pydantic import BaseModel, ConfigDict + + +class MSELossParams(BaseModel): + """ + Class for MSE loss parameters, compatible with `nn.MSELoss`. + """ + model_config = ConfigDict(frozen=True) + + reduction: Literal["none", "mean", "sum"] = "mean" + + def asdict(self) -> Dict[str, Any]: + """ + Returns a dictionary of valid parameters for `nn.MSELoss`. + + Returns: + Dict[str, Any]: Dictionary of parameters for `nn.MSELoss`. + """ + loss_kwargs = self.model_dump() + return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values + + +class MSELoss(BaseLoss): + """ + Custom loss function wrapper for `nn.MSELoss` with tracking of loss metrics. + """ + + def __init__(self, mse_params: Optional[MSELossParams] = None): + """ + Initializes the loss function with optional MSELoss parameters. + + Args: + mse_params (Optional[MSELossParams]): Parameters for `nn.MSELoss` (default: None). + """ + super().__init__() + _mse_params = mse_params.asdict() if mse_params is not None else {} + + # Initialize MSE loss with user-provided parameters or PyTorch defaults + self.mse_loss = nn.MSELoss(**_mse_params) + + # Using CumulativeAverage from MONAI to track loss metrics + self.loss_mse_metric = CumulativeAverage() + + def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Computes the loss between true values and predictions. + + Args: + outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W). + target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W). + + Returns: + torch.Tensor: The total loss value. + """ + # Ensure target is on the same device as outputs + assert ( + target.device == outputs.device + ), ( + "Target tensor must be moved to the same device as outputs " + "before calling forward()." + ) + + loss = self.mse_loss(outputs, target) + self.loss_mse_metric.append(loss.item()) + + return loss + + def get_loss_metrics(self) -> Dict[str, float]: + """ + Retrieves the tracked loss metrics. + + Returns: + Dict[str, float]: A dictionary containing the average MSE loss. 
+ """ + return { + "loss": round(self.loss_mse_metric.aggregate().item(), 4), + } + + def reset_metrics(self): + """Resets the stored loss metrics.""" + self.loss_mse_metric.reset() diff --git a/core/criteria/mse_with_bce.py b/core/criteria/mse_with_bce.py new file mode 100644 index 0000000..f36cf6f --- /dev/null +++ b/core/criteria/mse_with_bce.py @@ -0,0 +1,105 @@ +from .base import * +from .bce import BCELossParams +from .mse import MSELossParams + + +class BCE_MSE_Loss(BaseLoss): + """ + Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction. + """ + + def __init__( + self, + num_classes: int, + bce_params: Optional[BCELossParams] = None, + mse_params: Optional[MSELossParams] = None, + bce_with_logits: bool = False, + ): + """ + Initializes the loss function with optional BCE and MSE parameters. + + Args: + num_classes (int): Number of output classes, used for target shifting. + bce_params (Optional[BCELossParams]): Parameters for BCEWithLogitsLoss or BCELoss (default: None). + mse_params (Optional[MSELossParams]): Parameters for MSELoss (default: None). + bce_with_logits (bool): If True, uses BCEWithLogitsLoss; otherwise, uses BCELoss. + """ + super().__init__() + + self.num_classes = num_classes + + # Process BCE parameters + _bce_params = bce_params.asdict(bce_with_logits) if bce_params is not None else {} + + # Choose BCE loss function + self.bce_loss = ( + nn.BCEWithLogitsLoss(**_bce_params) if bce_with_logits else nn.BCELoss(**_bce_params) + ) + + # Process MSE parameters + _mse_params = mse_params.asdict() if mse_params is not None else {} + + # Initialize MSE loss + self.mse_loss = nn.MSELoss(**_mse_params) + + # Using CumulativeAverage from MONAI to track loss metrics + self.loss_bce_metric = CumulativeAverage() + self.loss_mse_metric = CumulativeAverage() + + def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Computes the loss between true labels and prediction outputs. + + Args: + outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W). + target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W). + + Returns: + torch.Tensor: The total loss value. + """ + # Ensure target is on the same device as outputs + assert ( + target.device == outputs.device + ), ( + "Target tensor must be moved to the same device as outputs " + "before calling forward()." + ) + + # Cell Recognition Loss + cellprob_loss = self.bce_loss( + outputs[:, -self.num_classes:], target[:, self.num_classes:2 * self.num_classes].float() + ) + + # Cell Distinction Loss + gradflow_loss = 0.5 * self.mse_loss( + outputs[:, :2 * self.num_classes], 5.0 * target[:, 2 * self.num_classes:] + ) + + # Total loss + total_loss = cellprob_loss + gradflow_loss + + # Track individual losses + self.loss_bce_metric.append(cellprob_loss.item()) + self.loss_mse_metric.append(gradflow_loss.item()) + + return total_loss + + def get_loss_metrics(self) -> Dict[str, float]: + """ + Retrieves the tracked loss metrics. + + Returns: + Dict[str, float]: A dictionary containing the average BCE and MSE loss. 
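+                The "loss" entry is the sum of the BCE and MSE running averages.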
+ """ + return { + "bce_loss": round(self.loss_bce_metric.aggregate().item(), 4), + "mse_loss": round(self.loss_mse_metric.aggregate().item(), 4), + "loss": round( + self.loss_bce_metric.aggregate().item() + self.loss_mse_metric.aggregate().item(), 4 + ), + } + + def reset_metrics(self): + """Resets the stored loss metrics.""" + self.loss_bce_metric.reset() + self.loss_mse_metric.reset() diff --git a/core/models/__init__.py b/core/models/__init__.py new file mode 100644 index 0000000..35b6ec1 --- /dev/null +++ b/core/models/__init__.py @@ -0,0 +1,85 @@ +import torch.nn as nn +from typing import Dict, Final, Tuple, Type, Any, List, Union +from pydantic import BaseModel + +from .model_v import ModelV, ModelVParams + + +__all__ = [ + "ModelRegistry", + "ModelV", + "ModelVParams" +] + + +class ModelRegistry: + """Registry for models and their parameter classes with case-insensitive lookup.""" + + # Single dictionary storing both model classes and parameter classes. + __MODELS: Final[Dict[str, Dict[str, Type[Any]]]] = { + "ModelV": { + "class": ModelV, + "params": ModelVParams, + }, + } + + @classmethod + def __get_entry(cls, name: str) -> Dict[str, Type[Any]]: + """ + Private method to retrieve the model entry from the registry using case-insensitive lookup. + + Args: + name (str): The name of the model. + + Returns: + Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'. + + Raises: + ValueError: If the model is not found. + """ + name_lower = name.lower() + mapping = {key.lower(): key for key in cls.__MODELS} + original_key = mapping.get(name_lower) + if original_key is None: + raise ValueError( + f"Model '{name}' not found! Available options: {list(cls.__MODELS.keys())}" + ) + return cls.__MODELS[original_key] + + @classmethod + def get_model_class(cls, name: str) -> Type[nn.Module]: + """ + Retrieves the model class by name (case-insensitive). + + Args: + name (str): Name of the model. + + Returns: + Type[nn.Module]: The model class. + """ + entry = cls.__get_entry(name) + return entry["class"] + + @classmethod + def get_model_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]: + """ + Retrieves the model parameter class by name (case-insensitive). + + Args: + name (str): Name of the model. + + Returns: + Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The model parameter class or a tuple of parameter classes. + """ + entry = cls.__get_entry(name) + return entry["params"] + + @classmethod + def get_available_models(cls) -> List[str]: + """ + Returns a list of available model names in their original case. + + Returns: + List[str]: List of available model names. 
+ """ + return list(cls.__MODELS.keys()) diff --git a/core/models/model_v.py b/core/models/model_v.py new file mode 100644 index 0000000..6474dc1 --- /dev/null +++ b/core/models/model_v.py @@ -0,0 +1,117 @@ +from typing import List, Optional + +import torch +import torch.nn as nn + +from segmentation_models_pytorch import MAnet +from segmentation_models_pytorch.base.modules import Activation + +from pydantic import BaseModel, ConfigDict + +__all__ = ["ModelV"] + + +class ModelVParams(BaseModel): + model_config = ConfigDict(frozen=True) + + encoder_name: str = "mit_b5" # Default encoder + encoder_weights: Optional[str] = "imagenet" # Pre-trained weights + decoder_channels: List[int] = [1024, 512, 256, 128, 64] # Decoder configuration + decoder_pab_channels: int = 256 # Decoder Pyramid Attention Block channels + in_channels: int = 3 # Number of input channels + out_classes: int = 3 # Number of output classes + + def asdict(self): + """ + Returns a dictionary of valid parameters for `nn.ModelV`. + + Returns: + Dict[str, Any]: Dictionary of parameters for nn.ModelV. + """ + loss_kwargs = self.model_dump() + return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values + + +class ModelV(MAnet): + """ModelV model""" + + def __init__(self, params: ModelVParams) -> None: + # Initialize the MAnet model with provided parameters + super().__init__(**params.asdict()) + + # Remove the default segmentation head as it's not used in this architecture + self.segmentation_head = None + + # Modify all activation functions in the encoder and decoder from ReLU to Mish + _convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True)) + _convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True)) + + # Add custom segmentation heads for different segmentation tasks + self.cellprob_head = DeepSegmentationHead( + in_channels=params.decoder_channels[-1], out_channels=params.out_classes + ) + self.gradflow_head = DeepSegmentationHead( + in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the network""" + # Ensure the input shape is correct + self.check_input_shape(x) + + # Encode the input and then decode it + features = self.encoder(x) + decoder_output = self.decoder(*features) + + # Generate masks for cell probability and gradient flows + cellprob_mask = self.cellprob_head(decoder_output) + gradflow_mask = self.gradflow_head(decoder_output) + + # Concatenate the masks for output + masks = torch.cat([gradflow_mask, cellprob_mask], dim=1) + + return masks + + +class DeepSegmentationHead(nn.Sequential): + """Custom segmentation head for generating specific masks""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + activation: Optional[str] = None, + upsampling: int = 1, + ) -> None: + # Define a sequence of layers for the segmentation head + layers: List[nn.Module] = [ + nn.Conv2d( + in_channels, + in_channels // 2, + kernel_size=kernel_size, + padding=kernel_size // 2, + ), + nn.Mish(inplace=True), + nn.BatchNorm2d(in_channels // 2), + nn.Conv2d( + in_channels // 2, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + ), + nn.UpsamplingBilinear2d(scale_factor=upsampling) + if upsampling > 1 + else nn.Identity(), + Activation(activation) if activation else nn.Identity(), + ] + super().__init__(*layers) + + +def _convert_activations(module: nn.Module, from_activation: type, to_activation: nn.Module) -> None: + 
"""Recursively convert activation functions in a module""" + for name, child in module.named_children(): + if isinstance(child, from_activation): + setattr(module, name, to_activation) + else: + _convert_activations(child, from_activation, to_activation) diff --git a/core/optimizers/__init__.py b/core/optimizers/__init__.py new file mode 100644 index 0000000..8e77a6b --- /dev/null +++ b/core/optimizers/__init__.py @@ -0,0 +1,92 @@ +import torch.optim as optim +from pydantic import BaseModel +from typing import Dict, Final, Tuple, Type, List, Any, Union + +from .adam import AdamParams +from .adamw import AdamWParams +from .sgd import SGDParams + +__all__ = [ + "OptimizerRegistry", + "AdamParams", "AdamWParams", "SGDParams" +] + +class OptimizerRegistry: + """Registry for optimizers and their parameter classes with case-insensitive lookup.""" + + # Single dictionary storing both optimizer classes and parameter classes. + __OPTIMIZERS: Final[Dict[str, Dict[str, Type[Any]]]] = { + "SGD": { + "class": optim.SGD, + "params": SGDParams, + }, + "Adam": { + "class": optim.Adam, + "params": AdamParams, + }, + "AdamW": { + "class": optim.AdamW, + "params": AdamWParams, + }, + } + + @classmethod + def __get_entry(cls, name: str) -> Dict[str, Type[Any]]: + """ + Private method to retrieve the optimizer entry from the registry using case-insensitive lookup. + + Args: + name (str): The name of the optimizer. + + Returns: + Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'. + + Raises: + ValueError: If the optimizer is not found. + """ + name_lower = name.lower() + mapping = {key.lower(): key for key in cls.__OPTIMIZERS} + original_key = mapping.get(name_lower) + if original_key is None: + raise ValueError( + f"Optimizer '{name}' not found! Available options: {list(cls.__OPTIMIZERS.keys())}" + ) + return cls.__OPTIMIZERS[original_key] + + @classmethod + def get_optimizer_class(cls, name: str) -> Type[optim.Optimizer]: + """ + Retrieves the optimizer class by name (case-insensitive). + + Args: + name (str): Name of the optimizer. + + Returns: + Type[optim.Optimizer]: The optimizer class. + """ + entry = cls.__get_entry(name) + return entry["class"] + + @classmethod + def get_optimizer_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]: + """ + Retrieves the optimizer parameter class by name (case-insensitive). + + Args: + name (str): Name of the optimizer. + + Returns: + Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The optimizer parameter class or a tuple of parameter classes. + """ + entry = cls.__get_entry(name) + return entry["params"] + + @classmethod + def get_available_optimizers(cls) -> List[str]: + """ + Returns a list of available optimizer names in their original case. + + Returns: + List[str]: List of available optimizer names. 
+ """ + return list(cls.__OPTIMIZERS.keys()) diff --git a/core/optimizers/adam.py b/core/optimizers/adam.py new file mode 100644 index 0000000..e70620a --- /dev/null +++ b/core/optimizers/adam.py @@ -0,0 +1,18 @@ +import torch +from typing import Any, Dict, Tuple +from pydantic import BaseModel, ConfigDict + + +class AdamParams(BaseModel): + """Configuration for `torch.optim.Adam` optimizer.""" + model_config = ConfigDict(frozen=True) + + lr: float = 1e-3 # Learning rate + betas: Tuple[float, float] = (0.9, 0.999) # Coefficients for computing running averages + eps: float = 1e-8 # Term added to denominator for numerical stability + weight_decay: float = 0.0 # L2 regularization + amsgrad: bool = False # Whether to use the AMSGrad variant + + def asdict(self) -> Dict[str, Any]: + """Returns a dictionary of valid parameters for `torch.optim.Adam`.""" + return self.model_dump() \ No newline at end of file diff --git a/core/optimizers/adamw.py b/core/optimizers/adamw.py new file mode 100644 index 0000000..3293f7f --- /dev/null +++ b/core/optimizers/adamw.py @@ -0,0 +1,17 @@ +import torch +from typing import Any, Dict, Tuple +from pydantic import BaseModel, ConfigDict + +class AdamWParams(BaseModel): + """Configuration for `torch.optim.AdamW` optimizer.""" + model_config = ConfigDict(frozen=True) + + lr: float = 1e-3 # Learning rate + betas: Tuple[float, ...] = (0.9, 0.999) # Adam coefficients + eps: float = 1e-8 # Numerical stability + weight_decay: float = 1e-2 # L2 penalty (AdamW uses decoupled weight decay) + amsgrad: bool = False # Whether to use the AMSGrad variant + + def asdict(self) -> Dict[str, Any]: + """Returns a dictionary of valid parameters for `torch.optim.AdamW`.""" + return self.model_dump() \ No newline at end of file diff --git a/core/optimizers/sgd.py b/core/optimizers/sgd.py new file mode 100644 index 0000000..d2e0e0e --- /dev/null +++ b/core/optimizers/sgd.py @@ -0,0 +1,18 @@ +import torch +from typing import Any, Dict +from pydantic import BaseModel, ConfigDict + + +class SGDParams(BaseModel): + """Configuration for `torch.optim.SGD` optimizer.""" + model_config = ConfigDict(frozen=True) + + lr: float = 1e-3 # Learning rate + momentum: float = 0.0 # Momentum factor + dampening: float = 0.0 # Dampening for momentum + weight_decay: float = 0.0 # L2 penalty + nesterov: bool = False # Enables Nesterov momentum + + def asdict(self) -> Dict[str, Any]: + """Returns a dictionary of valid parameters for `torch.optim.SGD`.""" + return self.model_dump() \ No newline at end of file diff --git a/core/schedulers/__init__.py b/core/schedulers/__init__.py new file mode 100644 index 0000000..5025196 --- /dev/null +++ b/core/schedulers/__init__.py @@ -0,0 +1,96 @@ +import torch.optim.lr_scheduler as lr_scheduler +from typing import Dict, Final, Tuple, Type, List, Any, Union +from pydantic import BaseModel + +from .step import StepLRParams +from .multi_step import MultiStepLRParams +from .exponential import ExponentialLRParams +from .cosine_annealing import CosineAnnealingLRParams + +__all__ = [ + "SchedulerRegistry", + "StepLRParams", "MultiStepLRParams", "ExponentialLRParams", "CosineAnnealingLRParams" +] + +class SchedulerRegistry: + """Registry for learning rate schedulers and their parameter classes with case-insensitive lookup.""" + + __SCHEDULERS: Final[Dict[str, Dict[str, Type[Any]]]] = { + "Step": { + "class": lr_scheduler.StepLR, + "params": StepLRParams, + }, + "Exponential": { + "class": lr_scheduler.ExponentialLR, + "params": ExponentialLRParams, + }, + "MultiStep": { + 
"class": lr_scheduler.MultiStepLR, + "params": MultiStepLRParams, + }, + "CosineAnnealing": { + "class": lr_scheduler.CosineAnnealingLR, + "params": CosineAnnealingLRParams, + }, + } + + @classmethod + def __get_entry(cls, name: str) -> Dict[str, Type[Any]]: + """ + Private method to retrieve the scheduler entry from the registry using case-insensitive lookup. + + Args: + name (str): The name of the scheduler. + + Returns: + Dict[str, Type[Any]]: A dictionary containing the keys 'class' and 'params'. + + Raises: + ValueError: If the scheduler is not found. + """ + name_lower = name.lower() + mapping = {key.lower(): key for key in cls.__SCHEDULERS} + original_key = mapping.get(name_lower) + if original_key is None: + raise ValueError( + f"Scheduler '{name}' not found! Available options: {list(cls.__SCHEDULERS.keys())}" + ) + return cls.__SCHEDULERS[original_key] + + @classmethod + def get_scheduler_class(cls, name: str) -> Type[lr_scheduler.LRScheduler]: + """ + Retrieves the scheduler class by name (case-insensitive). + + Args: + name (str): Name of the scheduler. + + Returns: + Type[lr_scheduler.LRScheduler]: The scheduler class. + """ + entry = cls.__get_entry(name) + return entry["class"] + + @classmethod + def get_scheduler_params(cls, name: str) -> Union[Type[BaseModel], Tuple[Type[BaseModel]]]: + """ + Retrieves the scheduler parameter class by name (case-insensitive). + + Args: + name (str): Name of the scheduler. + + Returns: + Union[Type[BaseModel], Tuple[Type[BaseModel]]]: The scheduler parameter class or a tuple of parameter classes. + """ + entry = cls.__get_entry(name) + return entry["params"] + + @classmethod + def get_available_schedulers(cls) -> List[str]: + """ + Returns a list of available scheduler names in their original case. + + Returns: + List[str]: List of available scheduler names. 
+ """ + return list(cls.__SCHEDULERS.keys()) diff --git a/core/schedulers/cosine_annealing.py b/core/schedulers/cosine_annealing.py new file mode 100644 index 0000000..c086ffe --- /dev/null +++ b/core/schedulers/cosine_annealing.py @@ -0,0 +1,14 @@ +from typing import Any, Dict +from pydantic import BaseModel, ConfigDict + + +class CosineAnnealingLRParams(BaseModel): + """Configuration for `torch.optim.lr_scheduler.CosineAnnealingLR`.""" + model_config = ConfigDict(frozen=True) + + T_max: int = 100 # Maximum number of iterations + eta_min: float = 0.0 # Minimum learning rate + + def asdict(self) -> Dict[str, Any]: + """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.CosineAnnealingLR`.""" + return self.model_dump() diff --git a/core/schedulers/exponential.py b/core/schedulers/exponential.py new file mode 100644 index 0000000..d708a7e --- /dev/null +++ b/core/schedulers/exponential.py @@ -0,0 +1,13 @@ +from typing import Any, Dict +from pydantic import BaseModel, ConfigDict + + +class ExponentialLRParams(BaseModel): + """Configuration for `torch.optim.lr_scheduler.ExponentialLR`.""" + model_config = ConfigDict(frozen=True) + + gamma: float = 0.95 # Multiplicative factor of learning rate decay + + def asdict(self) -> Dict[str, Any]: + """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.ExponentialLR`.""" + return self.model_dump() \ No newline at end of file diff --git a/core/schedulers/multi_step.py b/core/schedulers/multi_step.py new file mode 100644 index 0000000..1b02788 --- /dev/null +++ b/core/schedulers/multi_step.py @@ -0,0 +1,14 @@ +from typing import Any, Dict, Tuple +from pydantic import BaseModel, ConfigDict + + +class MultiStepLRParams(BaseModel): + """Configuration for `torch.optim.lr_scheduler.MultiStepLR`.""" + model_config = ConfigDict(frozen=True) + + milestones: Tuple[int, ...] = (30, 80) # List of epoch indices for LR decay + gamma: float = 0.1 # Multiplicative factor of learning rate decay + + def asdict(self) -> Dict[str, Any]: + """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.MultiStepLR`.""" + return self.model_dump() \ No newline at end of file diff --git a/core/schedulers/step.py b/core/schedulers/step.py new file mode 100644 index 0000000..1bf4c56 --- /dev/null +++ b/core/schedulers/step.py @@ -0,0 +1,14 @@ +from typing import Any, Dict +from pydantic import BaseModel, ConfigDict + + +class StepLRParams(BaseModel): + """Configuration for `torch.optim.lr_scheduler.StepLR`.""" + model_config = ConfigDict(frozen=True) + + step_size: int = 30 # Period of learning rate decay + gamma: float = 0.1 # Multiplicative factor of learning rate decay + + def asdict(self) -> Dict[str, Any]: + """Returns a dictionary of valid parameters for `torch.optim.lr_scheduler.StepLR`.""" + return self.model_dump() \ No newline at end of file diff --git a/core/segmentator.py b/core/segmentator.py new file mode 100644 index 0000000..e69de29 diff --git a/generate_config.py b/generate_config.py new file mode 100644 index 0000000..8172794 --- /dev/null +++ b/generate_config.py @@ -0,0 +1,120 @@ +import os +from pydantic import BaseModel +from typing import Any, Dict, Type, Union, List + +from config.config import Config +from config.dataset_config import DatasetConfig + +from core import ( + ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry +) + + +def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]: + """ + Instantiates the parameter class(es) with default values. 
+ + If 'param' is a tuple, instantiate each class and return a list of instances. + Otherwise, instantiate the single class and return the instance. + """ + if isinstance(param, tuple): + return [cls() for cls in param] + else: + return param() + +def prompt_choice(prompt_message: str, options: List[str]) -> str: + """ + Prompt the user with a list of options and return the selected option. + """ + print(prompt_message) + for i, option in enumerate(options, start=1): + print(f"{i}. {option}") + while True: + try: + choice = int(input("Enter your choice (number): ")) + if 1 <= choice <= len(options): + return options[choice - 1] + else: + print("Invalid choice. Please try again.") + except ValueError: + print("Please enter a valid number.") + +def main(): + # Determine the directory of this script. + script_path = os.path.dirname(os.path.abspath(__file__)) + + # Ask the user whether this is training mode. + training_input = input("Is this training mode? (y/n): ").strip().lower() + is_training = training_input in ("y", "yes") + + # Create a default DatasetConfig based on the training mode. + # The DatasetConfig.default_config method fills in required fields with zero-values. + dataset_config = DatasetConfig(is_training=is_training) + + # Prompt the user to select a model. + model_options = ModelRegistry.get_available_models() + chosen_model = prompt_choice("\nSelect a model:", model_options) + model_param_class = ModelRegistry.get_model_params(chosen_model) + model_instance = instantiate_params(model_param_class) + + if is_training is False: + config = Config( + model={chosen_model: model_instance}, + dataset_config=dataset_config + ) + + # Construct a base filename from the selected registry names. + base_filename = f"{chosen_model}" + + else: + # Prompt the user to select a criterion. + criterion_options = CriterionRegistry.get_available_criterions() + chosen_criterion = prompt_choice("\nSelect a criterion:", criterion_options) + criterion_param_class = CriterionRegistry.get_criterion_params(chosen_criterion) + criterion_instance = instantiate_params(criterion_param_class) + + # Prompt the user to select an optimizer. + optimizer_options = OptimizerRegistry.get_available_optimizers() + chosen_optimizer = prompt_choice("\nSelect an optimizer:", optimizer_options) + optimizer_param_class = OptimizerRegistry.get_optimizer_params(chosen_optimizer) + optimizer_instance = instantiate_params(optimizer_param_class) + + # Prompt the user to select a scheduler. + scheduler_options = SchedulerRegistry.get_available_schedulers() + chosen_scheduler = prompt_choice("\nSelect a scheduler:", scheduler_options) + scheduler_param_class = SchedulerRegistry.get_scheduler_params(chosen_scheduler) + scheduler_instance = instantiate_params(scheduler_param_class) + + # Assemble the overall configuration using the registry names as keys. + config = Config( + model={chosen_model: model_instance}, + dataset_config=dataset_config, + criterion={chosen_criterion: criterion_instance}, + optimizer={chosen_optimizer: optimizer_instance}, + scheduler={chosen_scheduler: scheduler_instance} + ) + + # Construct a base filename from the selected registry names. + base_filename = f"{chosen_model}_{chosen_criterion}_{chosen_optimizer}_{chosen_scheduler}" + + # Determine the output directory relative to this script. 
+ base_dir = os.path.join(script_path, "config/jsons", "train" if is_training else "predict") + os.makedirs(base_dir, exist_ok=True) + + filename = f"{base_filename}.json" + full_path = os.path.join(base_dir, filename) + counter = 1 + + # Append a counter if a file with the same name exists. + while os.path.exists(full_path): + filename = f"{base_filename}_{counter}.json" + full_path = os.path.join(base_dir, filename) + counter += 1 + + # Save the configuration as a JSON file. + config.save_json(full_path) + + print(f"\nConfiguration saved to: {full_path}") + +if __name__ == "__main__": + main() diff --git a/train.py b/train.py new file mode 100644 index 0000000..f03ef41 --- /dev/null +++ b/train.py @@ -0,0 +1,14 @@ +from config.config import Config +from pprint import pprint + + +config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json') +pprint(config, indent=4) + +print('\n\n') +config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV.json') +pprint(config, indent=4) + +print('\n\n') +config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV_1.json') +pprint(config, indent=4) \ No newline at end of file
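
Usage sketch: the registries and Config above compose as follows. This is a minimal illustrative example using default parameter values; it assumes the project root is on the import path, and the output file name is arbitrary (generate_config.py builds the same structure interactively).

    import os

    from config import Config
    from config.dataset_config import DatasetConfig
    from core import (
        ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
    )

    # Default parameter models are obtained through the registries (lookup is case-insensitive).
    model_params = ModelRegistry.get_model_params("ModelV")()
    criterion_params = CriterionRegistry.get_criterion_params("BCELoss")()
    optimizer_params = OptimizerRegistry.get_optimizer_params("AdamW")()
    scheduler_params = SchedulerRegistry.get_scheduler_params("CosineAnnealing")()

    config = Config(
        model={"ModelV": model_params},
        dataset_config=DatasetConfig(is_training=True),
        criterion={"BCELoss": criterion_params},
        optimizer={"AdamW": optimizer_params},
        scheduler={"CosineAnnealing": scheduler_params},
    )

    # Save and reload: load_json uses the dictionary keys to recover the parameter classes.
    os.makedirs("config/jsons/train", exist_ok=True)
    config.save_json("config/jsons/train/example.json")
    restored = Config.load_json("config/jsons/train/example.json")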