You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
3.8 KiB
111 lines
3.8 KiB
import json
|
|
from typing import Any, Dict, Optional
|
|
from pydantic import BaseModel
|
|
|
|
from .dataset_config import DatasetConfig
|
|
|
|
|
|
__all__ = ["Config", "ComponentConfig"]
|
|
|
|
|
|
class ComponentConfig(BaseModel):
    # A named component together with the parameter model that configures it.
    # `name` is the key later handed to a registry lookup when a config is
    # loaded; `params` holds that component's parameter object.
    name: str
    params: BaseModel

    def dump(self) -> Dict[str, Any]:
        """
        Serializes the component into a plain dictionary.

        The parameters are dumped through their own ``model_dump()`` so the
        concrete parameter object is the one that decides what gets emitted.

        Returns:
            dict: ``{"name": ..., "params": ...}`` where ``params`` is the
            dumped parameter model, or the raw value when it is not a
            ``BaseModel`` instance.
        """
        raw_params = self.params
        serialized = (
            raw_params.model_dump()
            if isinstance(raw_params, BaseModel)
            else raw_params
        )
        return {"name": self.name, "params": serialized}
|
|
|
|
|
|
|
|
class Config(BaseModel):
    # Top-level configuration: a mandatory model component and dataset
    # configuration, plus optional criterion / optimizer / scheduler
    # components.
    model: ComponentConfig
    dataset_config: DatasetConfig
    criterion: Optional[ComponentConfig] = None
    optimizer: Optional[ComponentConfig] = None
    scheduler: Optional[ComponentConfig] = None

    def save_json(self, file_path: str, indent: int = 4) -> None:
        """
        Saves the configuration to a JSON file using dumps of each individual field.

        Optional sections that are ``None`` are left out of the file entirely.

        Args:
            file_path (str): Destination path for the JSON file.
            indent (int): Indentation level for the JSON file.

        Raises:
            TypeError: If a dumped value cannot be encoded as JSON. The
                destination file is not touched in that case.
        """
        config_dump: Dict[str, Any] = {
            "model": self.model.dump(),
            "dataset_config": self.dataset_config.model_dump(),
        }
        for section in ("criterion", "optimizer", "scheduler"):
            component = getattr(self, section)
            if component is not None:
                config_dump[section] = component.dump()

        # Serialize BEFORE opening the file: mode "w" truncates immediately,
        # so encoding first keeps an existing config intact when a value turns
        # out not to be JSON-serializable.
        payload = json.dumps(config_dump, indent=indent)

        with open(file_path, "w", encoding="utf-8") as f:
            f.write(payload)

    @classmethod
    def load_json(cls, file_path: str) -> "Config":
        """
        Loads a configuration from a JSON file and re-instantiates each section,
        looking up the parameter class of every component through the matching
        registry by the component's name.

        A section (or its ``params``) that is missing or explicitly ``null``
        in the file is treated as absent instead of raising.

        Args:
            file_path (str): Path to the JSON file.

        Returns:
            Config: An instance of Config with the proper parameter classes.

        Raises:
            ValueError: If the file carries no usable ``model`` section.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Parse dataset_config using its Pydantic model; `or {}` also covers
        # an explicit null in the file.
        dataset_config = DatasetConfig(**(data.get("dataset_config") or {}))

        # Helper function to parse registry fields.
        def parse_field(
            component_data: Optional[Dict[str, Any]], registry_getter
        ) -> Optional[ComponentConfig]:
            # Missing, null, or empty section -> nothing to build.
            if not component_data:
                return None

            name = component_data.get("name")
            if name is None:
                return None

            # The registry returns the parameter class registered under `name`;
            # instantiating it validates the stored parameters.
            params_cls = registry_getter(name)
            params = params_cls(**(component_data.get("params") or {}))
            return ComponentConfig(name=name, params=params)

        # Kept as a function-local import, as in the original.
        # NOTE(review): presumably this avoids a circular import with `core` - confirm.
        from core import (
            ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
        )

        parsed_model = parse_field(data.get("model"), ModelRegistry.get_model_params)
        parsed_criterion = parse_field(data.get("criterion"), CriterionRegistry.get_criterion_params)
        parsed_optimizer = parse_field(data.get("optimizer"), OptimizerRegistry.get_optimizer_params)
        parsed_scheduler = parse_field(data.get("scheduler"), SchedulerRegistry.get_scheduler_params)

        if parsed_model is None:
            raise ValueError('Failed to load model information')

        return cls(
            model=parsed_model,
            dataset_config=dataset_config,
            criterion=parsed_criterion,
            optimizer=parsed_optimizer,
            scheduler=parsed_scheduler,
        )
|