|
|
import json
|
|
|
from typing import Any
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
from .wandb_config import WandbConfig
|
|
|
from .dataset_config import DatasetConfig
|
|
|
|
|
|
|
|
|
__all__ = ["Config", "ComponentConfig"]
|
|
|
|
|
|
|
|
|
class ComponentConfig(BaseModel):
|
|
|
name: str
|
|
|
params: BaseModel
|
|
|
|
|
|
def dump(self) -> dict[str, Any]:
|
|
|
"""
|
|
|
Recursively serializes the component into a dictionary.
|
|
|
|
|
|
Returns:
|
|
|
dict: A dictionary containing the component name and its serialized parameters.
|
|
|
"""
|
|
|
if isinstance(self.params, BaseModel):
|
|
|
params_dump = self.params.model_dump()
|
|
|
else:
|
|
|
params_dump = self.params
|
|
|
return {"name": self.name, "params": params_dump}
|
|
|
|
|
|
|
|
|
class Config(BaseModel):
|
|
|
model: ComponentConfig
|
|
|
dataset_config: DatasetConfig
|
|
|
wandb_config: WandbConfig
|
|
|
criterion: ComponentConfig | None = None
|
|
|
optimizer: ComponentConfig | None = None
|
|
|
scheduler: ComponentConfig | None = None
|
|
|
|
|
|
def asdict(self) -> dict[str, Any]:
|
|
|
"""
|
|
|
Produce a JSON‐serializable dict of this config, including nested
|
|
|
ComponentConfig and DatasetConfig entries. Useful for saving to file
|
|
|
or passing to experiment loggers (e.g. wandb.init(config=...)).
|
|
|
|
|
|
Returns:
|
|
|
A dict with keys 'model', 'dataset_config', and (if set)
|
|
|
'criterion', 'optimizer', 'scheduler'.
|
|
|
"""
|
|
|
data: dict[str, Any] = {
|
|
|
"model": self.model.dump(),
|
|
|
"dataset_config": self.dataset_config.model_dump(),
|
|
|
}
|
|
|
if self.criterion is not None:
|
|
|
data["criterion"] = self.criterion.dump()
|
|
|
if self.optimizer is not None:
|
|
|
data["optimizer"] = self.optimizer.dump()
|
|
|
if self.scheduler is not None:
|
|
|
data["scheduler"] = self.scheduler.dump()
|
|
|
data["wandb"] = self.wandb_config.model_dump()
|
|
|
return data
|
|
|
|
|
|
def save_json(self, file_path: str, indent: int = 4) -> None:
|
|
|
"""
|
|
|
Save this config to a JSON file.
|
|
|
|
|
|
Args:
|
|
|
file_path: Path to write the JSON file.
|
|
|
indent: JSON indent level.
|
|
|
"""
|
|
|
config_dict = self.asdict()
|
|
|
with open(file_path, "w", encoding="utf-8") as f:
|
|
|
f.write(json.dumps(config_dict, indent=indent))
|
|
|
|
|
|
@classmethod
|
|
|
def load_json(cls, file_path: str) -> "Config":
|
|
|
"""
|
|
|
Loads a configuration from a JSON file and re-instantiates each section using
|
|
|
the registry keys to recover the original parameter class(es).
|
|
|
|
|
|
Args:
|
|
|
file_path (str): Path to the JSON file.
|
|
|
|
|
|
Returns:
|
|
|
Config: An instance of Config with the proper parameter classes.
|
|
|
"""
|
|
|
with open(file_path, "r", encoding="utf-8") as f:
|
|
|
data = json.load(f)
|
|
|
|
|
|
# Parse dataset_config and wandb_config using its Pydantic model.
|
|
|
dataset_config = DatasetConfig(**data.get("dataset_config", {}))
|
|
|
wandb_config = WandbConfig(**data.get("wandb", {}))
|
|
|
|
|
|
# Helper function to parse registry fields.
|
|
|
def parse_field(
|
|
|
component_data: dict[str, Any], registry_getter
|
|
|
) -> ComponentConfig | None:
|
|
|
name = component_data.get("name")
|
|
|
params_data = component_data.get("params", {})
|
|
|
|
|
|
if name is not None:
|
|
|
expected = registry_getter(name)
|
|
|
params = expected(**params_data)
|
|
|
return ComponentConfig(name=name, params=params)
|
|
|
return None
|
|
|
|
|
|
from core import (
|
|
|
ModelRegistry,
|
|
|
CriterionRegistry,
|
|
|
OptimizerRegistry,
|
|
|
SchedulerRegistry,
|
|
|
)
|
|
|
|
|
|
parsed_model = parse_field(
|
|
|
data.get("model", {}),
|
|
|
lambda key: ModelRegistry.get_model_params(key),
|
|
|
)
|
|
|
parsed_criterion = parse_field(
|
|
|
data.get("criterion", {}),
|
|
|
lambda key: CriterionRegistry.get_criterion_params(key),
|
|
|
)
|
|
|
parsed_optimizer = parse_field(
|
|
|
data.get("optimizer", {}),
|
|
|
lambda key: OptimizerRegistry.get_optimizer_params(key),
|
|
|
)
|
|
|
parsed_scheduler = parse_field(
|
|
|
data.get("scheduler", {}),
|
|
|
lambda key: SchedulerRegistry.get_scheduler_params(key),
|
|
|
)
|
|
|
|
|
|
if parsed_model is None:
|
|
|
raise ValueError("Failed to load model information")
|
|
|
|
|
|
return cls(
|
|
|
model=parsed_model,
|
|
|
dataset_config=dataset_config,
|
|
|
criterion=parsed_criterion,
|
|
|
optimizer=parsed_optimizer,
|
|
|
scheduler=parsed_scheduler,
|
|
|
wandb_config=wandb_config,
|
|
|
)
|