You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
3.9 KiB
97 lines
3.9 KiB
import json
|
|
from typing import Any, Dict, List, Optional, Union
|
|
from pydantic import BaseModel
|
|
|
|
from .dataset_config import DatasetConfig
|
|
|
|
class Config(BaseModel):
    """Top-level configuration aggregating the model, dataset settings, and
    optional training components (criterion, optimizer, scheduler).

    Each component section maps a registry key to the Pydantic parameter
    model(s) used to configure that component.
    """

    # Registry key -> parameter model(s) for the model section (required).
    model: Dict[str, Union[BaseModel, List[BaseModel]]]
    # Dataset settings, parsed by its own Pydantic model.
    dataset_config: DatasetConfig
    # Optional sections; omitted from the saved JSON when None.
    criterion: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
    optimizer: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None
    scheduler: Optional[Dict[str, Union[BaseModel, List[BaseModel]]]] = None

    @staticmethod
    def __dump_field(value: Any) -> Any:
        """Recursively dump a field if it is a BaseModel or a list/dict of
        BaseModels; any other value is returned unchanged.
        """
        if isinstance(value, BaseModel):
            return value.model_dump()
        if isinstance(value, list):
            return [Config.__dump_field(item) for item in value]
        if isinstance(value, dict):
            return {k: Config.__dump_field(v) for k, v in value.items()}
        return value

    def save_json(self, file_path: str, indent: int = 4) -> None:
        """Saves the configuration to a JSON file using dumps of each
        individual field.

        Args:
            file_path (str): Destination path for the JSON file.
            indent (int): Indentation level for the JSON file.
        """
        config_dump: Dict[str, Any] = {
            "model": self.__dump_field(self.model),
            "dataset_config": self.dataset_config.model_dump(),
        }
        # Optional sections are serialized only when set, so the saved file
        # records whether a section was configured at all.
        for section_name in ("criterion", "optimizer", "scheduler"):
            section = getattr(self, section_name)
            if section is not None:
                config_dump[section_name] = self.__dump_field(section)

        with open(file_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(config_dump, indent=indent))

    @classmethod
    def load_json(cls, file_path: str) -> "Config":
        """Loads a configuration from a JSON file and re-instantiates each
        section using the registry keys to recover the original parameter
        class(es).

        Args:
            file_path (str): Path to the JSON file.

        Returns:
            Config: An instance of Config with the proper parameter classes.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Parse dataset_config using its Pydantic model.
        dataset_config = DatasetConfig(**data.get("dataset_config", {}))

        # Imported here, not at module level — presumably to avoid a circular
        # import with `core`; confirm before hoisting to the top of the file.
        from core import (
            ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
        )

        def parse_field(
            field_data: Dict[str, Any], registry_getter
        ) -> Dict[str, Union[BaseModel, List[BaseModel]]]:
            """Re-instantiate each entry of a section via its registry class."""
            result: Dict[str, Union[BaseModel, List[BaseModel]]] = {}
            for key, value in field_data.items():
                expected = registry_getter(key)
                # If the registry returns a tuple, then we expect a list of
                # dictionaries — one per parameter class.
                if isinstance(expected, tuple):
                    result[key] = [
                        cls_param(**item) for cls_param, item in zip(expected, value)
                    ]
                else:
                    result[key] = expected(**value)
            return result

        def parse_optional(section_name: str, registry_getter):
            """Parse an optional section, preserving absence as None.

            Bug fix: the previous implementation parsed missing sections as
            empty dicts, so a save() -> load() round trip turned None into {}.
            """
            if section_name not in data:
                return None
            return parse_field(data[section_name], registry_getter)

        return cls(
            model=parse_field(data.get("model", {}), ModelRegistry.get_model_params),
            dataset_config=dataset_config,
            criterion=parse_optional("criterion", CriterionRegistry.get_criterion_params),
            optimizer=parse_optional("optimizer", OptimizerRegistry.get_optimizer_params),
            scheduler=parse_optional("scheduler", SchedulerRegistry.get_scheduler_params),
        )