minor fixes

Branch: master
Author: laynholt (3 months ago)
parent e00331b503
commit 33ce003657

@@ -36,36 +36,10 @@ class TrainingPreSplitInfo(BaseModel):
     Configuration for training mode when data is pre-split (is_split is True).
     Contains:
     - train_dir, valid_dir, test_dir: Directories for training, validation, and testing data.
-    - train_size, valid_size, test_size: Data split ratios or counts.
-    - train_offset, valid_offset, test_offset: Offsets for respective splits.
     """
     train_dir: str = "."  # Directory for training data if data is pre-split
     valid_dir: str = ""   # Directory for validation data if data is pre-split
     test_dir: str = ""    # Directory for testing data if data is pre-split
-    train_size: Union[int, float] = 0.7  # Training data size (int for static, float in (0,1] for dynamic)
-    valid_size: Union[int, float] = 0.2  # Validation data size (int for static, float in (0,1] for dynamic)
-    test_size: Union[int, float] = 0.1   # Testing data size (int for static, float in (0,1] for dynamic)
-    train_offset: int = 0  # Offset for training data
-    valid_offset: int = 0  # Offset for validation data
-    test_offset: int = 0   # Offset for testing data
-
-    @field_validator("train_size", "valid_size", "test_size", mode="before")
-    def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
-        """
-        Validates size values:
-        - If provided as a float, must be in the range (0, 1].
-        - If provided as an int, must be non-negative.
-        """
-        if isinstance(v, float):
-            if not (0 <= v <= 1):
-                raise ValueError("When provided as a float, size must be in the range (0, 1]")
-        elif isinstance(v, int):
-            if v < 0:
-                raise ValueError("When provided as an int, size must be non-negative")
-        else:
-            raise ValueError("Size must be either an int or a float")
-        return v


 class DatasetTrainingConfig(BaseModel):
@@ -75,6 +49,8 @@ class DatasetTrainingConfig(BaseModel):
     - is_split: Determines whether data is pre-split.
     - pre_split: Configuration for when data is NOT pre-split.
     - split: Configuration for when data is pre-split.
+    - train_size, valid_size, test_size: Data split ratios or counts.
+    - train_offset, valid_offset, test_offset: Offsets for respective splits.
     - Other training parameters: batch_size, num_epochs, val_freq, use_amp, pretrained_weights.

     Both pre_split and split objects are always created, but only the one corresponding
@@ -84,12 +60,37 @@ class DatasetTrainingConfig(BaseModel):
     pre_split: TrainingPreSplitInfo = TrainingPreSplitInfo()
     split: TrainingSplitInfo = TrainingSplitInfo()

+    train_size: Union[int, float] = 0.7  # Training data size (int for static, float in (0,1] for dynamic)
+    valid_size: Union[int, float] = 0.2  # Validation data size (int for static, float in (0,1] for dynamic)
+    test_size: Union[int, float] = 0.1   # Testing data size (int for static, float in (0,1] for dynamic)
+    train_offset: int = 0  # Offset for training data
+    valid_offset: int = 0  # Offset for validation data
+    test_offset: int = 0   # Offset for testing data
+
     batch_size: int = 1    # Batch size for training
     num_epochs: int = 100  # Number of training epochs
     val_freq: int = 1      # Frequency of validation during training
     use_amp: bool = False  # Flag to use Automatic Mixed Precision (AMP)
     pretrained_weights: str = ""  # Path to pretrained weights for training

+    @field_validator("train_size", "valid_size", "test_size", mode="before")
+    def validate_sizes(cls, v: Union[int, float]) -> Union[int, float]:
+        """
+        Validates size values:
+        - If provided as a float, must be in the range (0, 1].
+        - If provided as an int, must be non-negative.
+        """
+        if isinstance(v, float):
+            if not (0 <= v <= 1):
+                raise ValueError("When provided as a float, size must be in the range (0, 1]")
+        elif isinstance(v, int):
+            if v < 0:
+                raise ValueError("When provided as an int, size must be non-negative")
+        else:
+            raise ValueError("Size must be either an int or a float")
+        return v
+
     @model_validator(mode="after")
     def validate_split_info(self) -> "DatasetTrainingConfig":
         """
@@ -211,18 +212,17 @@ class DatasetConfig(BaseModel):
         if self.is_training:
             if self.training is None:
                 raise ValueError("Training configuration must be provided when is_training is True")
-            # Check predictions_dir if training.split.test_dir and test_size are set
-            if self.training.pre_split.test_dir and self.training.pre_split.test_size > 0:
-                if not self.common.predictions_dir:
-                    raise ValueError("predictions_dir must be provided when training.split.test_dir and test_size are non-zero")
+            if self.training.train_size == 0:
+                raise ValueError("train_size must be provided when is_training is True")
+            if self.training.test_size > 0 and not self.common.predictions_dir:
+                raise ValueError("predictions_dir must be provided when test_size is non-zero")
             if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
                 raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
         else:
             if self.testing is None:
                 raise ValueError("Testing configuration must be provided when is_training is False")
-            if self.testing.test_dir and self.testing.test_size > 0:
-                if not self.common.predictions_dir:
-                    raise ValueError("predictions_dir must be provided when testing.test_dir and test_size are non-zero")
+            if self.testing.test_size > 0 and not self.common.predictions_dir:
+                raise ValueError("predictions_dir must be provided when test_size is non-zero")
             if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
                 raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
         return self
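
For illustration, a hedged sketch of what the simplified check now rejects; everything outside the hunk (the DatasetConfig constructor shape, the common sub-config) is assumed:

    # Hypothetical sketch -- only is_training, training.test_size and common.predictions_dir
    # come from the hunk above; the surrounding constructor shape is assumed.
    try:
        DatasetConfig(
            is_training=True,
            training=DatasetTrainingConfig(test_size=0.1),  # non-zero test_size ...
            # ... while common.predictions_dir is left empty
        )
    except ValueError as err:  # pydantic's ValidationError subclasses ValueError
        print(err)             # "predictions_dir must be provided when test_size is non-zero"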

@@ -87,11 +87,11 @@ class CriterionRegistry:
         return entry["params"]

     @classmethod
-    def get_available_criterions(cls) -> List[str]:
+    def get_available_criterions(cls) -> Tuple[str, ...]:
         """
-        Returns a list of available loss function names in their original case.
+        Returns a tuple of available loss function names in their original case.

         Returns:
-            List[str]: List of available loss function names.
+            Tuple[str]: Tuple of available loss function names.
         """
-        return list(cls.__CRITERIONS.keys())
+        return tuple(cls.__CRITERIONS.keys())
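
Because the getters now return tuples, callers receive an immutable snapshot of the registry keys; a brief sketch (the registered names themselves are illustrative):

    names = CriterionRegistry.get_available_criterions()
    print(names)          # e.g. ("BCELoss", "CrossEntropyLoss", ...) -- depends on what is registered
    # names.append("X")   # no longer possible: a tuple cannot be mutated in place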

@@ -1,5 +1,5 @@
 from .base import *
-from typing import Literal
+from typing import List, Literal, Union
 from pydantic import BaseModel, ConfigDict
@@ -7,11 +7,11 @@ class BCELossParams(BaseModel):
     """
     Class for handling parameters for both `nn.BCELoss` and `nn.BCEWithLogitsLoss`.
     """
-    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    model_config = ConfigDict(frozen=True)

-    weight: Optional[torch.Tensor] = None  # Sample weights
+    weight: Optional[List[Union[int, float]]] = None  # Sample weights
     reduction: Literal["none", "mean", "sum"] = "mean"  # Reduction method
-    pos_weight: Optional[torch.Tensor] = None  # Used only for BCEWithLogitsLoss
+    pos_weight: Optional[List[Union[int, float]]] = None  # Used only for BCEWithLogitsLoss

     def asdict(self, with_logits: bool = False) -> Dict[str, Any]:
         """
@@ -31,6 +31,15 @@ class BCELossParams(BaseModel):
         if not with_logits:
             loss_kwargs.pop("pos_weight", None)  # Remove pos_weight if using BCELoss

+        weight = loss_kwargs.get("weight")
+        pos_weight = loss_kwargs.get("pos_weight")
+
+        if weight is not None:
+            loss_kwargs["weight"] = torch.Tensor(weight)
+
+        if pos_weight is not None:
+            loss_kwargs["pos_weight"] = torch.Tensor(pos_weight)
+
         return {k: v for k, v in loss_kwargs.items() if v is not None}  # Remove None values
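
A minimal sketch of the reworked asdict, assuming BCELossParams is importable as defined above: weights are stored as plain lists in the (now tensor-free) config model and converted to tensors only when the kwargs are materialized, with None entries dropped.

    import torch.nn as nn

    params = BCELossParams(weight=[0.3, 0.7], pos_weight=[2.0])

    kwargs = params.asdict(with_logits=True)  # weight/pos_weight arrive as torch tensors, reduction stays "mean"
    criterion = nn.BCEWithLogitsLoss(**kwargs)

    plain = params.asdict()                   # pos_weight is popped before conversion, so this fits nn.BCELoss
    criterion = nn.BCELoss(**plain)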

@@ -1,5 +1,5 @@
 from .base import *
-from typing import Literal
+from typing import List, Literal, Union
 from pydantic import BaseModel, ConfigDict
@@ -7,9 +7,9 @@ class CrossEntropyLossParams(BaseModel):
     """
     Class for handling parameters for `nn.CrossEntropyLoss`.
     """
-    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    model_config = ConfigDict(frozen=True)

-    weight: Optional[torch.Tensor] = None
+    weight: Optional[List[Union[int, float]]] = None
     ignore_index: int = -100
     reduction: Literal["none", "mean", "sum"] = "mean"
     label_smoothing: float = 0.0
@@ -22,6 +22,11 @@ class CrossEntropyLossParams(BaseModel):
             Dict[str, Any]: Dictionary of parameters for nn.CrossEntropyLoss.
         """
         loss_kwargs = self.model_dump()
+
+        weight = loss_kwargs.get("weight")
+        if weight is not None:
+            loss_kwargs["weight"] = torch.Tensor(weight)
+
         return {k: v for k, v in loss_kwargs.items() if v is not None}  # Remove None values
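
The same conversion pattern applies here; a short hedged sketch (imports as in the previous sketch):

    ce_kwargs = CrossEntropyLossParams(weight=[1.0, 2.0, 0.5]).asdict()
    criterion = nn.CrossEntropyLoss(**ce_kwargs)  # weight becomes a float tensor; the remaining defaults pass through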

@@ -75,11 +75,11 @@ class ModelRegistry:
         return entry["params"]

     @classmethod
-    def get_available_models(cls) -> List[str]:
+    def get_available_models(cls) -> Tuple[str, ...]:
         """
-        Returns a list of available model names in their original case.
+        Returns a tuple of available model names in their original case.

         Returns:
-            List[str]: List of available model names.
+            Tuple[str]: Tuple of available model names.
         """
-        return list(cls.__MODELS.keys())
+        return tuple(cls.__MODELS.keys())

@@ -82,11 +82,11 @@ class OptimizerRegistry:
         return entry["params"]

     @classmethod
-    def get_available_optimizers(cls) -> List[str]:
+    def get_available_optimizers(cls) -> Tuple[str, ...]:
         """
-        Returns a list of available optimizer names in their original case.
+        Returns a tuple of available optimizer names in their original case.

         Returns:
-            List[str]: List of available optimizer names.
+            Tuple[str]: Tuple of available optimizer names.
         """
-        return list(cls.__OPTIMIZERS.keys())
+        return tuple(cls.__OPTIMIZERS.keys())

@@ -1,4 +1,3 @@
-import torch
 from typing import Any, Dict, Tuple
 from pydantic import BaseModel, ConfigDict

@@ -1,4 +1,3 @@
-import torch
 from typing import Any, Dict, Tuple
 from pydantic import BaseModel, ConfigDict

@@ -1,4 +1,3 @@
-import torch
 from typing import Any, Dict
 from pydantic import BaseModel, ConfigDict

@@ -86,11 +86,11 @@ class SchedulerRegistry:
         return entry["params"]

     @classmethod
-    def get_available_schedulers(cls) -> List[str]:
+    def get_available_schedulers(cls) -> Tuple[str, ...]:
         """
-        Returns a list of available scheduler names in their original case.
+        Returns a tuple of available scheduler names in their original case.

         Returns:
-            List[str]: List of available scheduler names.
+            Tuple[str]: Tuple of available scheduler names.
         """
-        return list(cls.__SCHEDULERS.keys())
+        return tuple(cls.__SCHEDULERS.keys())

@@ -1,6 +1,6 @@
 import os
 from pydantic import BaseModel
-from typing import Any, Dict, Type, Union, List
+from typing import Any, Dict, Tuple, Type, Union, List
 from config.config import Config
 from config.dataset_config import DatasetConfig
@@ -22,7 +22,7 @@ def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]:
     else:
         return param()

-def prompt_choice(prompt_message: str, options: List[str]) -> str:
+def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str:
     """
     Prompt the user with a list of options and return the selected option.
     """
