You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

140 lines
4.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import json
from typing import Any
from pydantic import BaseModel
from .wandb_config import WandbConfig
from .dataset_config import DatasetConfig
__all__ = ["Config", "ComponentConfig"]
class ComponentConfig(BaseModel):
    # A named component together with the pydantic model holding its parameters.
    name: str
    params: BaseModel

    def dump(self) -> dict[str, Any]:
        """
        Serialize this component into a plain dictionary.

        Returns:
            A dictionary with two keys: "name" (the component name) and
            "params". "params" is expanded with ``model_dump()`` when it is a
            pydantic model and passed through unchanged otherwise.
        """
        payload = self.params
        if isinstance(payload, BaseModel):
            # Expand the nested model explicitly instead of relying on the
            # field's declared (base) type.
            payload = payload.model_dump()
        return {"name": self.name, "params": payload}
class Config(BaseModel):
    # Top-level configuration: a required model component, dataset and wandb
    # settings, and optional criterion / optimizer / scheduler components.
    model: ComponentConfig
    dataset_config: DatasetConfig
    wandb_config: WandbConfig
    criterion: ComponentConfig | None = None
    optimizer: ComponentConfig | None = None
    scheduler: ComponentConfig | None = None

    def asdict(self) -> dict[str, Any]:
        """
        Produce a plain dict of this config, including nested
        ComponentConfig, DatasetConfig and WandbConfig entries. Useful for
        saving to file or passing to experiment loggers
        (e.g. wandb.init(config=...)).

        Returns:
            A dict with keys 'model', 'dataset_config', then (if set)
            'criterion', 'optimizer', 'scheduler', and finally 'wandb'.
        """
        # NOTE(review): model_dump() is called in its default (Python) mode, so
        # the result is JSON-serializable only if the parameter models hold
        # JSON-native values -- confirm before relying on save_json() for
        # models with e.g. Path or Enum fields.
        data: dict[str, Any] = {
            "model": self.model.dump(),
            "dataset_config": self.dataset_config.model_dump(),
        }
        # Optional components are omitted entirely when unset.
        for key in ("criterion", "optimizer", "scheduler"):
            component = getattr(self, key)
            if component is not None:
                data[key] = component.dump()
        # Stored under 'wandb' (not 'wandb_config'); load_json reads the same key.
        data["wandb"] = self.wandb_config.model_dump()
        return data

    def save_json(self, file_path: str, indent: int = 4) -> None:
        """
        Save this config to a JSON file.

        Args:
            file_path: Path to write the JSON file.
            indent: JSON indent level.

        Raises:
            TypeError: If the config contains a value ``json`` cannot encode.
            OSError: If the file cannot be opened or written.
        """
        # Serialize before opening the file: opening in "w" mode truncates it,
        # so a serialization failure must not clobber an existing config.
        payload = json.dumps(self.asdict(), indent=indent)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(payload)

    @classmethod
    def load_json(cls, file_path: str) -> "Config":
        """
        Load a configuration from a JSON file and re-instantiate each section,
        using the registries to look up the parameter class for each component
        name.

        A section (or a component's "params") that is missing or explicitly
        ``null`` in the file is treated as empty.

        Args:
            file_path: Path to the JSON file.

        Returns:
            An instance of Config with the proper parameter classes.

        Raises:
            ValueError: If the file has no "model" section with a "name".
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.

        Errors raised by the registries or by the parameter-model constructors
        propagate unchanged.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        def section(key: str) -> dict[str, Any]:
            # `or {}` also covers a key that is present with a null value,
            # which `data.get(key, {})` would return as None.
            return data.get(key) or {}

        # Parse dataset_config and wandb_config using their Pydantic models.
        dataset_config = DatasetConfig(**section("dataset_config"))
        wandb_config = WandbConfig(**section("wandb"))

        # Helper to rebuild one component from its serialized form.
        # `registry_getter` maps a component name to the callable (presumably
        # the pydantic params class) that builds its parameters.
        def parse_field(
            component_data: dict[str, Any], registry_getter
        ) -> ComponentConfig | None:
            name = component_data.get("name")
            if name is None:
                return None
            params_cls = registry_getter(name)
            params = params_cls(**(component_data.get("params") or {}))
            return ComponentConfig(name=name, params=params)

        # Imported at call time rather than at module level -- presumably to
        # avoid a circular import with `core`; confirm before hoisting.
        from core import (
            ModelRegistry,
            CriterionRegistry,
            OptimizerRegistry,
            SchedulerRegistry,
        )

        parsed_model = parse_field(
            section("model"), ModelRegistry.get_model_params
        )
        parsed_criterion = parse_field(
            section("criterion"), CriterionRegistry.get_criterion_params
        )
        parsed_optimizer = parse_field(
            section("optimizer"), OptimizerRegistry.get_optimizer_params
        )
        parsed_scheduler = parse_field(
            section("scheduler"), SchedulerRegistry.get_scheduler_params
        )

        # The model is the only mandatory component.
        if parsed_model is None:
            raise ValueError("Failed to load model information")

        return cls(
            model=parsed_model,
            dataset_config=dataset_config,
            criterion=parsed_criterion,
            optimizer=parsed_optimizer,
            scheduler=parsed_scheduler,
            wandb_config=wandb_config,
        )