from pydantic import BaseModel, model_validator, field_validator from typing import Any, Dict, Optional, Union import os class DatasetCommonConfig(BaseModel): """ Common configuration fields shared by both training and testing. """ device: str = "cuda0" # Device used for training/testing (e.g., 'cpu' or 'cuda') use_tta: bool = False # Flag to use Test-Time Augmentation (TTA) predictions_dir: str = "." # Directory to save predictions @model_validator(mode="after") def validate_common(self) -> "DatasetCommonConfig": """ Validates that device is non-empty. """ if not self.device: raise ValueError("device must be provided and non-empty") return self class TrainingSplitInfo(BaseModel): """ Configuration for training mode when data is NOT pre-split (is_split is False). Contains: - split_seed: Seed for splitting. - all_data_dir: Directory containing all data. """ split_seed: int = 0 # Seed for splitting if data is not pre-split all_data_dir: str = "." # Directory containing all data if not pre-split class TrainingPreSplitInfo(BaseModel): """ Configuration for training mode when data is pre-split (is_split is True). Contains: - train_dir, valid_dir, test_dir: Directories for training, validation, and testing data. - train_size, valid_size, test_size: Data split ratios or counts. - train_offset, valid_offset, test_offset: Offsets for respective splits. """ train_dir: str = "." # Directory for training data if data is pre-split valid_dir: str = "" # Directory for validation data if data is pre-split test_dir: str = "" # Directory for testing data if data is pre-split train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic) valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic) test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic) train_offset: int = 0 # Offset for training data valid_offset: int = 0 # Offset for validation data test_offset: int = 0 # Offset for testing data @field_validator("train_size", "valid_size", "test_size", mode="before") def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]: """ Validates size values: - If provided as a float, must be in the range (0, 1]. - If provided as an int, must be non-negative. """ if isinstance(v, float): if not (0 <= v <= 1): raise ValueError("When provided as a float, size must be in the range (0, 1]") elif isinstance(v, int): if v < 0: raise ValueError("When provided as an int, size must be non-negative") else: raise ValueError("Size must be either an int or a float") return v class DatasetTrainingConfig(BaseModel): """ Main training configuration. Contains: - is_split: Determines whether data is pre-split. - pre_split: Configuration for when data is NOT pre-split. - split: Configuration for when data is pre-split. - Other training parameters: batch_size, num_epochs, val_freq, use_amp, pretrained_weights. Both pre_split and split objects are always created, but only the one corresponding to is_split is validated. """ is_split: bool = False # Whether the data is already split into train/validation sets pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo() split: TrainingSplitInfo = TrainingSplitInfo() batch_size: int = 1 # Batch size for training num_epochs: int = 100 # Number of training epochs val_freq: int = 1 # Frequency of validation during training use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP) pretrained_weights: str = "" # Path to pretrained weights for training @model_validator(mode="after") def validate_split_info(self) -> "DatasetTrainingConfig": """ Conditionally validates the nested split objects: - If is_split is True, validates pre_split (train_dir must be non-empty and exist; if provided, valid_dir and test_dir must exist). - If is_split is False, validates split (all_data_dir must be non-empty and exist). """ if not self.is_split: if not self.split.all_data_dir: raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in pre_split") if not os.path.exists(self.split.all_data_dir): raise ValueError(f"Path for all_data_dir does not exist: {self.split.all_data_dir}") else: if not self.pre_split.train_dir: raise ValueError("When is_split is True, train_dir must be provided and non-empty in split") if not os.path.exists(self.pre_split.train_dir): raise ValueError(f"Path for train_dir does not exist: {self.pre_split.train_dir}") if self.pre_split.valid_dir and not os.path.exists(self.pre_split.valid_dir): raise ValueError(f"Path for valid_dir does not exist: {self.pre_split.valid_dir}") if self.pre_split.test_dir and not os.path.exists(self.pre_split.test_dir): raise ValueError(f"Path for test_dir does not exist: {self.pre_split.test_dir}") return self @model_validator(mode="after") def validate_numeric_fields(self) -> "DatasetTrainingConfig": """ Validates numeric fields: - batch_size and num_epochs must be > 0. - val_freq must be >= 0. """ if self.batch_size <= 0: raise ValueError("batch_size must be > 0") if self.num_epochs <= 0: raise ValueError("num_epochs must be > 0") if self.val_freq < 0: raise ValueError("val_freq must be >= 0") return self @model_validator(mode="after") def validate_pretrained(self) -> "DatasetTrainingConfig": """ Validates that pretrained_weights is provided and exists. """ if self.pretrained_weights and not os.path.exists(self.pretrained_weights): raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}") return self class DatasetTestingConfig(BaseModel): """ Configuration fields used only in testing mode. """ test_dir: str = "." # Test data directory; must be non-empty test_size: Union[int, float] = 1.0 # Testing data size (int for static, float in (0,1] for dynamic) test_offset: int = 0 # Offset for testing data use_ensemble: bool = False # Flag to use ensemble mode in testing ensemble_pretrained_weights1: str = "." ensemble_pretrained_weights2: str = "." pretrained_weights: str = "." @field_validator("test_size", mode="before") def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]: """ Validates the test_size value. """ if isinstance(v, float): if not (0 < v <= 1): raise ValueError("When provided as a float, test_size must be in the range (0, 1]") elif isinstance(v, int): if v < 0: raise ValueError("When provided as an int, test_size must be non-negative") else: raise ValueError("test_size must be either an int or a float") return v @model_validator(mode="after") def validate_testing(self) -> "DatasetTestingConfig": """ Validates the testing configuration: - test_dir must be non-empty and exist. - If use_ensemble is True, both ensemble_pretrained_weights1 and ensemble_pretrained_weights2 must be provided and exist. - If use_ensemble is False, pretrained_weights must be provided and exist. """ if not self.test_dir: raise ValueError("In testing configuration, test_dir must be provided and non-empty") if not os.path.exists(self.test_dir): raise ValueError(f"Path for test_dir does not exist: {self.test_dir}") if self.use_ensemble: for field in ["ensemble_pretrained_weights1", "ensemble_pretrained_weights2"]: value = getattr(self, field) if not value: raise ValueError(f"When use_ensemble is True, {field} must be provided and non-empty") if not os.path.exists(value): raise ValueError(f"Path for {field} does not exist: {value}") else: if not self.pretrained_weights: raise ValueError("When use_ensemble is False, pretrained_weights must be provided and non-empty") if not os.path.exists(self.pretrained_weights): raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}") if self.test_offset < 0: raise ValueError("test_offset must be >= 0") return self class DatasetConfig(BaseModel): """ Main dataset configuration that groups fields into nested models for a structured and readable JSON. """ is_training: bool = True # Flag indicating whether the configuration is for training (True) or testing (False) common: DatasetCommonConfig = DatasetCommonConfig() training: DatasetTrainingConfig = DatasetTrainingConfig() testing: DatasetTestingConfig = DatasetTestingConfig() @model_validator(mode="after") def validate_config(self) -> "DatasetConfig": """ Validates the overall dataset configuration: """ if self.is_training: if self.training is None: raise ValueError("Training configuration must be provided when is_training is True") # Check predictions_dir if training.split.test_dir and test_size are set if self.training.pre_split.test_dir and self.training.pre_split.test_size > 0: if not self.common.predictions_dir: raise ValueError("predictions_dir must be provided when training.split.test_dir and test_size are non-zero") if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir): raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}") else: if self.testing is None: raise ValueError("Testing configuration must be provided when is_training is False") if self.testing.test_dir and self.testing.test_size > 0: if not self.common.predictions_dir: raise ValueError("predictions_dir must be provided when testing.test_dir and test_size are non-zero") if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir): raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}") return self def model_dump(self, **kwargs) -> Dict[str, Any]: """ Dumps only the relevant configuration depending on the is_training flag. Only the nested configuration (training or testing) along with common fields is returned. """ if self.is_training: return { "is_training": self.is_training, "common": self.common.model_dump(), "training": self.training.model_dump() if self.training else {} } else: return { "is_training": self.is_training, "common": self.common.model_dump(), "testing": self.testing.model_dump() if self.testing else {} }