You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
251 lines
12 KiB
251 lines
12 KiB
from pydantic import BaseModel, model_validator, field_validator
|
|
from typing import Any, Dict, Optional, Union
|
|
import os
|
|
|
|
|
|
class DatasetCommonConfig(BaseModel):
    """
    Common configuration fields shared by both training and testing.
    """
    # Device used for training/testing. Fixed default: "cuda0" is not a valid
    # device string (valid forms are e.g. 'cpu', 'cuda', 'cuda:0').
    device: str = "cuda:0"
    use_tta: bool = False  # Flag to use Test-Time Augmentation (TTA)
    predictions_dir: str = "."  # Directory to save predictions

    @model_validator(mode="after")
    def validate_common(self) -> "DatasetCommonConfig":
        """
        Validates that device is non-empty.

        Raises:
            ValueError: If ``device`` is an empty string.
        """
        if not self.device:
            raise ValueError("device must be provided and non-empty")
        return self
|
|
|
|
|
|
class TrainingSplitInfo(BaseModel):
    """
    Settings used in training mode when the data is NOT pre-split
    (i.e. ``is_split`` is False) and must be split at runtime.

    Fields:
        split_seed:   Random seed applied when splitting the data.
        all_data_dir: Directory that holds the full, unsplit dataset.
    """
    split_seed: int = 0  # seed for the runtime train/valid/test split
    all_data_dir: str = "."  # location of the complete dataset
|
|
|
|
class TrainingPreSplitInfo(BaseModel):
    """
    Settings used in training mode when the data IS already pre-split
    (i.e. ``is_split`` is True).

    Fields:
        train_dir: Directory with the training subset.
        valid_dir: Directory with the validation subset (optional).
        test_dir:  Directory with the testing subset (optional).
    """
    train_dir: str = "."  # training data location (required when pre-split)
    valid_dir: str = ""  # validation data location; may be left empty
    test_dir: str = ""  # testing data location; may be left empty
|
|
|
|
|
|
class DatasetTrainingConfig(BaseModel):
    """
    Main training configuration.

    Contains:
    - is_split: Determines whether data is pre-split.
    - pre_split: Directories used when data IS pre-split (is_split is True).
    - split: Settings used when data is NOT pre-split (is_split is False).
    - train_size, valid_size, test_size: Data split ratios or counts.
    - train_offset, valid_offset, test_offset: Offsets for respective splits.
    - Other training parameters: batch_size, num_epochs, val_freq, use_amp, pretrained_weights.

    Both pre_split and split objects are always created, but only the one corresponding
    to is_split is validated.
    """
    is_split: bool = False  # Whether the data is already split into train/validation sets
    pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo()  # validated when is_split is True
    split: TrainingSplitInfo = TrainingSplitInfo()  # validated when is_split is False

    train_size: Union[int, float] = 0.7  # Training data size (int for static, float in [0, 1] for dynamic)
    valid_size: Union[int, float] = 0.2  # Validation data size (int for static, float in [0, 1] for dynamic)
    test_size: Union[int, float] = 0.1  # Testing data size (int for static, float in [0, 1] for dynamic)
    train_offset: int = 0  # Offset for training data
    valid_offset: int = 0  # Offset for validation data
    test_offset: int = 0  # Offset for testing data

    batch_size: int = 1  # Batch size for training
    num_epochs: int = 100  # Number of training epochs
    val_freq: int = 1  # Frequency of validation during training
    use_amp: bool = False  # Flag to use Automatic Mixed Precision (AMP)
    pretrained_weights: str = ""  # Optional path to pretrained weights for training

    @field_validator("train_size", "valid_size", "test_size", mode="before")
    def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
        """
        Validates size values:
        - If provided as a float, must be in the range [0, 1].
        - If provided as an int, must be non-negative.

        Raises:
            ValueError: If the value is out of range or not numeric.
        """
        if isinstance(v, float):
            # 0.0 is accepted (e.g. a split may legitimately be empty); the
            # previous message wrongly claimed the accepted range was (0, 1].
            if not (0 <= v <= 1):
                raise ValueError("When provided as a float, size must be in the range [0, 1]")
        elif isinstance(v, int):
            if v < 0:
                raise ValueError("When provided as an int, size must be non-negative")
        else:
            raise ValueError("Size must be either an int or a float")
        return v

    @model_validator(mode="after")
    def validate_split_info(self) -> "DatasetTrainingConfig":
        """
        Conditionally validates the nested split objects:
        - If is_split is True, validates pre_split (train_dir must be non-empty and exist;
          if provided, valid_dir and test_dir must exist).
        - If is_split is False, validates split (all_data_dir must be non-empty and exist).

        Also checks that dynamically defined sizes do not exceed the whole dataset.
        """
        # NOTE: if int and float sizes are mixed, the int counts participate in
        # this sum as well — presumably mixing static and dynamic sizes is not
        # an intended configuration; TODO confirm with callers.
        if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)):
            if (self.train_size + self.valid_size + self.test_size) > 1:
                raise ValueError("The total sample size with dynamically defined sizes must be <= 1")

        if not self.is_split:
            # all_data_dir lives on `split` — previous messages named the wrong object.
            if not self.split.all_data_dir:
                raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in split")
            if not os.path.exists(self.split.all_data_dir):
                raise ValueError(f"Path for all_data_dir does not exist: {self.split.all_data_dir}")
        else:
            # train_dir/valid_dir/test_dir live on `pre_split`.
            if not self.pre_split.train_dir:
                raise ValueError("When is_split is True, train_dir must be provided and non-empty in pre_split")
            if not os.path.exists(self.pre_split.train_dir):
                raise ValueError(f"Path for train_dir does not exist: {self.pre_split.train_dir}")
            if self.pre_split.valid_dir and not os.path.exists(self.pre_split.valid_dir):
                raise ValueError(f"Path for valid_dir does not exist: {self.pre_split.valid_dir}")
            if self.pre_split.test_dir and not os.path.exists(self.pre_split.test_dir):
                raise ValueError(f"Path for test_dir does not exist: {self.pre_split.test_dir}")
        return self

    @model_validator(mode="after")
    def validate_numeric_fields(self) -> "DatasetTrainingConfig":
        """
        Validates numeric fields:
        - batch_size and num_epochs must be > 0.
        - val_freq must be >= 0.
        - train_offset, valid_offset and test_offset must be >= 0.
        """
        if self.batch_size <= 0:
            raise ValueError("batch_size must be > 0")
        if self.num_epochs <= 0:
            raise ValueError("num_epochs must be > 0")
        if self.val_freq < 0:
            raise ValueError("val_freq must be >= 0")
        # Offsets must be non-negative, consistent with DatasetTestingConfig.
        if self.train_offset < 0:
            raise ValueError("train_offset must be >= 0")
        if self.valid_offset < 0:
            raise ValueError("valid_offset must be >= 0")
        if self.test_offset < 0:
            raise ValueError("test_offset must be >= 0")
        return self

    @model_validator(mode="after")
    def validate_pretrained(self) -> "DatasetTrainingConfig":
        """
        Validates that pretrained_weights, when provided, points to an existing path.
        An empty string means no pretrained weights are used.
        """
        if self.pretrained_weights and not os.path.exists(self.pretrained_weights):
            raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")
        return self
|
|
|
|
|
|
class DatasetTestingConfig(BaseModel):
    """
    Configuration fields that only apply in testing mode.
    """
    test_dir: str = "."  # test data directory; must be non-empty and exist
    test_size: Union[int, float] = 1.0  # testing data size (int = static count, float in (0, 1] = fraction)
    test_offset: int = 0  # offset into the testing data
    use_ensemble: bool = False  # whether to run an ensemble of two models
    ensemble_pretrained_weights1: str = "."  # first ensemble checkpoint path
    ensemble_pretrained_weights2: str = "."  # second ensemble checkpoint path
    pretrained_weights: str = "."  # single-model checkpoint path

    @field_validator("test_size", mode="before")
    def validate_test_size(cls, v: Union[int, float]) -> Union[int, float]:
        """
        Check that test_size is numeric and within its allowed range:
        floats must lie in (0, 1], ints must be non-negative.
        """
        if isinstance(v, float):
            if not (0 < v <= 1):
                raise ValueError("When provided as a float, test_size must be in the range (0, 1]")
            return v
        if isinstance(v, int):
            if v < 0:
                raise ValueError("When provided as an int, test_size must be non-negative")
            return v
        raise ValueError("test_size must be either an int or a float")

    @model_validator(mode="after")
    def validate_testing(self) -> "DatasetTestingConfig":
        """
        Check the testing setup as a whole:
        - test_dir must be non-empty and point to an existing path.
        - With use_ensemble, both ensemble weight paths must be set and exist.
        - Without use_ensemble, pretrained_weights must be set and exist.
        - test_offset must not be negative.
        """
        if not self.test_dir:
            raise ValueError("In testing configuration, test_dir must be provided and non-empty")
        if not os.path.exists(self.test_dir):
            raise ValueError(f"Path for test_dir does not exist: {self.test_dir}")

        if self.use_ensemble:
            weight_fields = (
                ("ensemble_pretrained_weights1", self.ensemble_pretrained_weights1),
                ("ensemble_pretrained_weights2", self.ensemble_pretrained_weights2),
            )
            for field, value in weight_fields:
                if not value:
                    raise ValueError(f"When use_ensemble is True, {field} must be provided and non-empty")
                if not os.path.exists(value):
                    raise ValueError(f"Path for {field} does not exist: {value}")
        else:
            if not self.pretrained_weights:
                raise ValueError("When use_ensemble is False, pretrained_weights must be provided and non-empty")
            if not os.path.exists(self.pretrained_weights):
                raise ValueError(f"Path for pretrained_weights does not exist: {self.pretrained_weights}")

        if self.test_offset < 0:
            raise ValueError("test_offset must be >= 0")
        return self
|
|
|
|
|
|
class DatasetConfig(BaseModel):
    """
    Main dataset configuration that groups fields into nested models for a structured and readable JSON.
    """
    is_training: bool = True  # True -> training configuration is active, False -> testing
    common: DatasetCommonConfig = DatasetCommonConfig()  # fields shared by both modes
    training: DatasetTrainingConfig = DatasetTrainingConfig()  # training-only fields
    testing: DatasetTestingConfig = DatasetTestingConfig()  # testing-only fields

    @model_validator(mode="after")
    def validate_config(self) -> "DatasetConfig":
        """
        Validates the overall dataset configuration:
        - In training mode, the training config and a non-zero train_size are required.
        - predictions_dir must be set whenever the active mode's test_size is non-zero,
          and must point to an existing path when set.
        """
        if self.is_training:
            if self.training is None:
                raise ValueError("Training configuration must be provided when is_training is True")
            if self.training.train_size == 0:
                raise ValueError("train_size must be provided when is_training is True")
            active_test_size = self.training.test_size
        else:
            if self.testing is None:
                raise ValueError("Testing configuration must be provided when is_training is False")
            active_test_size = self.testing.test_size

        # predictions_dir rules are identical for both modes; check them once.
        if active_test_size > 0 and not self.common.predictions_dir:
            raise ValueError("predictions_dir must be provided when test_size is non-zero")
        if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
            raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
        return self

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """
        Dumps only the relevant configuration depending on the is_training flag.
        Only the nested configuration (training or testing) along with common fields is returned.

        Keyword arguments (e.g. mode, exclude_none) are forwarded to the nested
        dumps; previously they were accepted but silently ignored.
        """
        section = "training" if self.is_training else "testing"
        nested = self.training if self.is_training else self.testing
        return {
            "is_training": self.is_training,
            "common": self.common.model_dump(**kwargs),
            section: nested.model_dump(**kwargs) if nested else {},
        }
|