minor fixes

master
laynholt 3 months ago
parent e00331b503
commit 33ce003657

@@ -36,36 +36,10 @@ class TrainingPreSplitInfo(BaseModel):
Configuration for training mode when data is pre-split (is_split is True).
Contains:
- train_dir, valid_dir, test_dir: Directories for training, validation, and testing data.
- train_size, valid_size, test_size: Data split ratios or counts.
- train_offset, valid_offset, test_offset: Offsets for respective splits.
"""
train_dir: str = "." # Directory for training data if data is pre-split
valid_dir: str = "" # Directory for validation data if data is pre-split
test_dir: str = "" # Directory for testing data if data is pre-split
train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic)
train_offset: int = 0 # Offset for training data
valid_offset: int = 0 # Offset for validation data
test_offset: int = 0 # Offset for testing data
@field_validator("train_size", "valid_size", "test_size", mode="before")
def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
"""
Validates size values:
- If provided as a float, must be in the range (0, 1].
- If provided as an int, must be non-negative.
"""
if isinstance(v, float):
if not (0 <= v <= 1):
raise ValueError("When provided as a float, size must be in the range (0, 1]")
elif isinstance(v, int):
if v < 0:
raise ValueError("When provided as an int, size must be non-negative")
else:
raise ValueError("Size must be either an int or a float")
return v
class DatasetTrainingConfig(BaseModel):
@@ -75,6 +49,8 @@ class DatasetTrainingConfig(BaseModel):
- is_split: Determines whether data is pre-split.
- pre_split: Configuration for when data is NOT pre-split.
- split: Configuration for when data is pre-split.
- train_size, valid_size, test_size: Data split ratios or counts.
- train_offset, valid_offset, test_offset: Offsets for respective splits.
- Other training parameters: batch_size, num_epochs, val_freq, use_amp, pretrained_weights.
Both pre_split and split objects are always created, but only the one corresponding
@@ -84,12 +60,37 @@ class DatasetTrainingConfig(BaseModel):
pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo()
split: TrainingSplitInfo = TrainingSplitInfo()
train_size: Union[int, float] = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
valid_size: Union[int, float] = 0.2 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: Union[int, float] = 0.1 # Testing data size (int for static, float in (0,1] for dynamic)
train_offset: int = 0 # Offset for training data
valid_offset: int = 0 # Offset for validation data
test_offset: int = 0 # Offset for testing data
batch_size: int = 1 # Batch size for training
num_epochs: int = 100 # Number of training epochs
val_freq: int = 1 # Frequency of validation during training
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
pretrained_weights: str = "" # Path to pretrained weights for training
@field_validator("train_size", "valid_size", "test_size", mode="before")
def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
"""
Validates size values:
- If provided as a float, must be in the range (0, 1].
- If provided as an int, must be non-negative.
"""
if isinstance(v, float):
if not (0 < v <= 1):
raise ValueError("When provided as a float, size must be in the range (0, 1]")
elif isinstance(v, int):
if v < 0:
raise ValueError("When provided as an int, size must be non-negative")
else:
raise ValueError("Size must be either an int or a float")
return v
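To make the relocated validator concrete, a minimal usage sketch follows; it assumes the other validators on DatasetTrainingConfig (such as validate_split_info below) accept the defaults for the fields not shown:

# Hedged sketch, not part of the diff: floats are treated as split ratios, ints as absolute sample counts.
DatasetTrainingConfig(train_size=0.7, valid_size=0.2, test_size=0.1)  # dynamic split by ratio
DatasetTrainingConfig(train_size=800, valid_size=100, test_size=100)  # static split by count
DatasetTrainingConfig(train_size=1.5)  # rejected: float sizes must lie in (0, 1]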
@model_validator(mode="after")
def validate_split_info(self) -> "DatasetTrainingConfig":
"""
@@ -211,18 +212,17 @@ class DatasetConfig(BaseModel):
if self.is_training:
if self.training is None:
raise ValueError("Training configuration must be provided when is_training is True")
# Check predictions_dir if training.split.test_dir and test_size are set
if self.training.pre_split.test_dir and self.training.pre_split.test_size > 0:
if not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when training.split.test_dir and test_size are non-zero")
if self.training.train_size == 0:
raise ValueError("train_size must be provided when is_training is True")
if self.training.test_size > 0 and not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when test_size is non-zero")
if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
else:
if self.testing is None:
raise ValueError("Testing configuration must be provided when is_training is False")
if self.testing.test_dir and self.testing.test_size > 0:
if not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when testing.test_dir and test_size are non-zero")
if self.testing.test_size > 0 and not self.common.predictions_dir:
raise ValueError("predictions_dir must be provided when test_size is non-zero")
if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
return self

@@ -87,11 +87,11 @@ class CriterionRegistry:
return entry["params"]
@classmethod
def get_available_criterions(cls) -> List[str]:
def get_available_criterions(cls) -> Tuple[str, ...]:
"""
Returns a list of available loss function names in their original case.
Returns a tuple of available loss function names in their original case.
Returns:
List[str]: List of available loss function names.
Tuple[str, ...]: Tuple of available loss function names.
"""
return list(cls.__CRITERIONS.keys())
return tuple(cls.__CRITERIONS.keys())
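A brief, hedged usage sketch of the new return type (the names shown are illustrative, not the actual registry contents); the same list-to-tuple change is applied to ModelRegistry, OptimizerRegistry, and SchedulerRegistry below:

names = CriterionRegistry.get_available_criterions()
assert isinstance(names, tuple)  # immutable snapshot of the registered names
print(names)                     # e.g. ("BCELoss", "CrossEntropyLoss")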

@@ -1,5 +1,5 @@
from .base import *
from typing import Literal
from typing import List, Literal, Union
from pydantic import BaseModel, ConfigDict
@@ -7,11 +7,11 @@ class BCELossParams(BaseModel):
"""
Class for handling parameters for both `nn.BCELoss` and `nn.BCEWithLogitsLoss`.
"""
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
model_config = ConfigDict(frozen=True)
weight: Optional[torch.Tensor] = None # Sample weights
weight: Optional[List[Union[int, float]]] = None # Sample weights
reduction: Literal["none", "mean", "sum"] = "mean" # Reduction method
pos_weight: Optional[torch.Tensor] = None # Used only for BCEWithLogitsLoss
pos_weight: Optional[List[Union[int, float]]] = None # Used only for BCEWithLogitsLoss
def asdict(self, with_logits: bool = False) -> Dict[str, Any]:
"""
@@ -31,6 +31,15 @@ class BCELossParams(BaseModel):
if not with_logits:
loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss
weight = loss_kwargs.get("weight")
pos_weight = loss_kwargs.get("pos_weight")
if weight is not None:
loss_kwargs["weight"] = torch.Tensor(weight)
if pos_weight is not None:
loss_kwargs["pos_weight"] = torch.Tensor(pos_weight)
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
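Because weight and pos_weight are now plain Python lists and only become tensors inside asdict(), a hedged round-trip sketch (values and import style are illustrative):

from torch import nn

params = BCELossParams(weight=[1.0, 2.0], pos_weight=[3.0])
kwargs = params.asdict(with_logits=True)    # lists converted to torch.Tensor here
criterion = nn.BCEWithLogitsLoss(**kwargs)  # weight/pos_weight arrive as tensors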

@@ -1,5 +1,5 @@
from .base import *
from typing import Literal
from typing import List, Literal, Union
from pydantic import BaseModel, ConfigDict
@@ -7,9 +7,9 @@ class CrossEntropyLossParams(BaseModel):
"""
Class for handling parameters for `nn.CrossEntropyLoss`.
"""
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
model_config = ConfigDict(frozen=True)
weight: Optional[torch.Tensor] = None
weight: Optional[List[Union[int, float]]] = None
ignore_index: int = -100
reduction: Literal["none", "mean", "sum"] = "mean"
label_smoothing: float = 0.0
@@ -22,6 +22,11 @@ class CrossEntropyLossParams(BaseModel):
Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss.
"""
loss_kwargs = self.model_dump()
weight = loss_kwargs.get("weight")
if weight is not None:
loss_kwargs["weight"] = torch.Tensor(weight)
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
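Analogously for the cross-entropy parameters, a hedged sketch that assumes the conversion lives in an asdict() method named as in BCELossParams:

from torch import nn

ce_params = CrossEntropyLossParams(weight=[0.3, 0.7], label_smoothing=0.05)
criterion = nn.CrossEntropyLoss(**ce_params.asdict())  # the weight list becomes a tensor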

@@ -75,11 +75,11 @@ class ModelRegistry:
return entry["params"]
@classmethod
def get_available_models(cls) -> List[str]:
def get_available_models(cls) -> Tuple[str, ...]:
"""
Returns a list of available model names in their original case.
Returns a tuple of available model names in their original case.
Returns:
List[str]: List of available model names.
Tuple[str, ...]: Tuple of available model names.
"""
return list(cls.__MODELS.keys())
return tuple(cls.__MODELS.keys())

@@ -82,11 +82,11 @@ class OptimizerRegistry:
return entry["params"]
@classmethod
def get_available_optimizers(cls) -> List[str]:
def get_available_optimizers(cls) -> Tuple[str, ...]:
"""
Returns a list of available optimizer names in their original case.
Returns a tuple of available optimizer names in their original case.
Returns:
List[str]: List of available optimizer names.
Tuple[str, ...]: Tuple of available optimizer names.
"""
return list(cls.__OPTIMIZERS.keys())
return tuple(cls.__OPTIMIZERS.keys())

@@ -1,4 +1,3 @@
import torch
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict

@@ -1,4 +1,3 @@
import torch
from typing import Any, Dict, Tuple
from pydantic import BaseModel, ConfigDict

@@ -1,4 +1,3 @@
import torch
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict

@@ -86,11 +86,11 @@ class SchedulerRegistry:
return entry["params"]
@classmethod
def get_available_schedulers(cls) -> List[str]:
def get_available_schedulers(cls) -> Tuple[str, ...]:
"""
Returns a list of available scheduler names in their original case.
Returns a tuple of available scheduler names in their original case.
Returns:
List[str]: List of available scheduler names.
Tuple[str, ...]: Tuple of available scheduler names.
"""
return list(cls.__SCHEDULERS.keys())
return tuple(cls.__SCHEDULERS.keys())

@@ -1,6 +1,6 @@
import os
from pydantic import BaseModel
from typing import Any, Dict, Type, Union, List
from typing import Any, Dict, Tuple, Type, Union, List
from config.config import Config
from config.dataset_config import DatasetConfig
@@ -22,7 +22,7 @@ def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]:
else:
return param()
def prompt_choice(prompt_message: str, options: List[str]) -> str:
def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str:
"""
Prompt the user with a list of options and return the selected option.
"""
