diff --git a/config/dataset_config.py b/config/dataset_config.py index ccbfdfa..8ef5901 100644 --- a/config/dataset_config.py +++ b/config/dataset_config.py @@ -36,36 +36,10 @@ class TrainingPreSplitInfo(BaseModel): Configuration for training mode when data is pre-split (is_split is True). Contains: - train_dir, valid_dir, test_dir: Directories for training, validation, and testing data. - - train_size, valid_size, test_size: Data split ratios or counts. - - train_offset, valid_offset, test_offset: Offsets for respective splits. """ train_dir: str = "." # Directory for training data if data is pre-split valid_dir: str = "" # Directory for validation data if data is pre-split test_dir: str = "" # Directory for testing data if data is pre-split - train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic) - valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic) - test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic) - train_offset: int = 0 # Offset for training data - valid_offset: int = 0 # Offset for validation data - test_offset: int = 0 # Offset for testing data - - - @field_validator("train_size", "valid_size", "test_size", mode="before") - def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]: - """ - Validates size values: - - If provided as a float, must be in the range (0, 1]. - - If provided as an int, must be non-negative. - """ - if isinstance(v, float): - if not (0 <= v <= 1): - raise ValueError("When provided as a float, size must be in the range (0, 1]") - elif isinstance(v, int): - if v < 0: - raise ValueError("When provided as an int, size must be non-negative") - else: - raise ValueError("Size must be either an int or a float") - return v class DatasetTrainingConfig(BaseModel): @@ -75,6 +49,8 @@ class DatasetTrainingConfig(BaseModel): - is_split: Determines whether data is pre-split. - pre_split: Configuration for when data is NOT pre-split. - split: Configuration for when data is pre-split. + - train_size, valid_size, test_size: Data split ratios or counts. + - train_offset, valid_offset, test_offset: Offsets for respective splits. - Other training parameters: batch_size, num_epochs, val_freq, use_amp, pretrained_weights. Both pre_split and split objects are always created, but only the one corresponding @@ -84,12 +60,37 @@ class DatasetTrainingConfig(BaseModel): pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo() split: TrainingSplitInfo = TrainingSplitInfo() + train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic) + valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic) + test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic) + train_offset: int = 0 # Offset for training data + valid_offset: int = 0 # Offset for validation data + test_offset: int = 0 # Offset for testing data + batch_size: int = 1 # Batch size for training num_epochs: int = 100 # Number of training epochs val_freq: int = 1 # Frequency of validation during training use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP) pretrained_weights: str = "" # Path to pretrained weights for training + + @field_validator("train_size", "valid_size", "test_size", mode="before") + def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]: + """ + Validates size values: + - If provided as a float, must be in the range (0, 1]. + - If provided as an int, must be non-negative. + """ + if isinstance(v, float): + if not (0 <= v <= 1): + raise ValueError("When provided as a float, size must be in the range (0, 1]") + elif isinstance(v, int): + if v < 0: + raise ValueError("When provided as an int, size must be non-negative") + else: + raise ValueError("Size must be either an int or a float") + return v + @model_validator(mode="after") def validate_split_info(self) -> "DatasetTrainingConfig": """ @@ -211,18 +212,17 @@ class DatasetConfig(BaseModel): if self.is_training: if self.training is None: raise ValueError("Training configuration must be provided when is_training is True") - # Check predictions_dir if training.split.test_dir and test_size are set - if self.training.pre_split.test_dir and self.training.pre_split.test_size > 0: - if not self.common.predictions_dir: - raise ValueError("predictions_dir must be provided when training.split.test_dir and test_size are non-zero") + if self.training.train_size == 0: + raise ValueError("train_size must be provided when is_training is True") + if self.training.test_size > 0 and not self.common.predictions_dir: + raise ValueError("predictions_dir must be provided when test_size is non-zero") if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir): raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}") else: if self.testing is None: raise ValueError("Testing configuration must be provided when is_training is False") - if self.testing.test_dir and self.testing.test_size > 0: - if not self.common.predictions_dir: - raise ValueError("predictions_dir must be provided when testing.test_dir and test_size are non-zero") + if self.testing.test_size > 0 and not self.common.predictions_dir: + raise ValueError("predictions_dir must be provided when test_size is non-zero") if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir): raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}") return self diff --git a/core/criteria/__init__.py b/core/criteria/__init__.py index 86c32c7..d8e1cdc 100644 --- a/core/criteria/__init__.py +++ b/core/criteria/__init__.py @@ -87,11 +87,11 @@ class CriterionRegistry: return entry["params"] @classmethod - def get_available_criterions(cls) -> List[str]: + def get_available_criterions(cls) -> Tuple[str, ...]: """ - Returns a list of available loss function names in their original case. + Returns a tuple of available loss function names in their original case. Returns: - List[str]: List of available loss function names. + Tuple[str]: Tuple of available loss function names. """ - return list(cls.__CRITERIONS.keys()) + return tuple(cls.__CRITERIONS.keys()) diff --git a/core/criteria/bce.py b/core/criteria/bce.py index 8dd2a62..451fb43 100644 --- a/core/criteria/bce.py +++ b/core/criteria/bce.py @@ -1,5 +1,5 @@ from .base import * -from typing import Literal +from typing import List, Literal, Union from pydantic import BaseModel, ConfigDict @@ -7,11 +7,11 @@ class BCELossParams(BaseModel): """ Class for handling parameters for both `nn.BCELoss` and `nn.BCEWithLogitsLoss`. """ - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + model_config = ConfigDict(frozen=True) - weight: Optional[torch.Tensor] = None # Sample weights + weight: Optional[List[Union[int, float]]] = None # Sample weights reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method - pos_weight: Optional[torch.Tensor] = None # Used only for BCEWithLogitsLoss + pos_weight: Optional[List[Union[int, float]]] = None # Used only for BCEWithLogitsLoss def asdict(self, with_logits: bool = False) -> Dict[str, Any]: """ @@ -31,6 +31,15 @@ class BCELossParams(BaseModel): if not with_logits: loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss + weight = loss_kwargs.get("weight") + pos_weight = loss_kwargs.get("pos_weight") + + if weight is not None: + loss_kwargs["weight"] = torch.Tensor(weight) + + if pos_weight is not None: + loss_kwargs["pos_weight"] = torch.Tensor(pos_weight) + return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values diff --git a/core/criteria/ce.py b/core/criteria/ce.py index fe4a086..00e641b 100644 --- a/core/criteria/ce.py +++ b/core/criteria/ce.py @@ -1,5 +1,5 @@ from .base import * -from typing import Literal +from typing import List, Literal, Union from pydantic import BaseModel, ConfigDict @@ -7,9 +7,9 @@ class CrossEntropyLossParams(BaseModel): """ Class for handling parameters for `nn.CrossEntropyLoss`. """ - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + model_config = ConfigDict(frozen=True) - weight: Optional[torch.Tensor] = None + weight: Optional[List[Union[int, float]]] = None ignore_index: int = -100 reduction: Literal["none", "mean", "sum"] = "mean" label_smoothing: float = 0.0 @@ -22,6 +22,11 @@ class CrossEntropyLossParams(BaseModel): Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss. """ loss_kwargs = self.model_dump() + + weight = loss_kwargs.get("weight") + if weight is not None: + loss_kwargs["weight"] = torch.Tensor(weight) + return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values diff --git a/core/models/__init__.py b/core/models/__init__.py index 35b6ec1..75ba2a8 100644 --- a/core/models/__init__.py +++ b/core/models/__init__.py @@ -75,11 +75,11 @@ class ModelRegistry: return entry["params"] @classmethod - def get_available_models(cls) -> List[str]: + def get_available_models(cls) -> Tuple[str, ...]: """ - Returns a list of available model names in their original case. + Returns a tuple of available model names in their original case. Returns: - List[str]: List of available model names. + Tuple[str]: Tuple of available model names. """ - return list(cls.__MODELS.keys()) + return tuple(cls.__MODELS.keys()) diff --git a/core/optimizers/__init__.py b/core/optimizers/__init__.py index 8e77a6b..2841b3c 100644 --- a/core/optimizers/__init__.py +++ b/core/optimizers/__init__.py @@ -82,11 +82,11 @@ class OptimizerRegistry: return entry["params"] @classmethod - def get_available_optimizers(cls) -> List[str]: + def get_available_optimizers(cls) -> Tuple[str, ...]: """ - Returns a list of available optimizer names in their original case. + Returns a tuple of available optimizer names in their original case. Returns: - List[str]: List of available optimizer names. + Tuple[str]: Tuple of available optimizer names. """ - return list(cls.__OPTIMIZERS.keys()) + return tuple(cls.__OPTIMIZERS.keys()) diff --git a/core/optimizers/adam.py b/core/optimizers/adam.py index e70620a..1919c77 100644 --- a/core/optimizers/adam.py +++ b/core/optimizers/adam.py @@ -1,4 +1,3 @@ -import torch from typing import Any, Dict, Tuple from pydantic import BaseModel, ConfigDict diff --git a/core/optimizers/adamw.py b/core/optimizers/adamw.py index 3293f7f..755e3c0 100644 --- a/core/optimizers/adamw.py +++ b/core/optimizers/adamw.py @@ -1,4 +1,3 @@ -import torch from typing import Any, Dict, Tuple from pydantic import BaseModel, ConfigDict diff --git a/core/optimizers/sgd.py b/core/optimizers/sgd.py index d2e0e0e..0de3cc3 100644 --- a/core/optimizers/sgd.py +++ b/core/optimizers/sgd.py @@ -1,4 +1,3 @@ -import torch from typing import Any, Dict from pydantic import BaseModel, ConfigDict diff --git a/core/schedulers/__init__.py b/core/schedulers/__init__.py index 5025196..2e92f1d 100644 --- a/core/schedulers/__init__.py +++ b/core/schedulers/__init__.py @@ -86,11 +86,11 @@ class SchedulerRegistry: return entry["params"] @classmethod - def get_available_schedulers(cls) -> List[str]: + def get_available_schedulers(cls) -> Tuple[str, ...]: """ - Returns a list of available scheduler names in their original case. + Returns a tuple of available scheduler names in their original case. Returns: - List[str]: List of available scheduler names. + Tuple[str]: Tuple of available scheduler names. """ - return list(cls.__SCHEDULERS.keys()) + return tuple(cls.__SCHEDULERS.keys()) diff --git a/generate_config.py b/generate_config.py index 8172794..054eaaa 100644 --- a/generate_config.py +++ b/generate_config.py @@ -1,6 +1,6 @@ import os from pydantic import BaseModel -from typing import Any, Dict, Type, Union, List +from typing import Any, Dict, Tuple, Type, Union, List from config.config import Config from config.dataset_config import DatasetConfig @@ -22,7 +22,7 @@ def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]: else: return param() -def prompt_choice(prompt_message: str, options: List[str]) -> str: +def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str: """ Prompt the user with a list of options and return the selected option. """