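"""Interactive generator for JSON configuration templates.

Asks whether the configuration is for training or prediction, lets the user pick
a model (and, in training mode, a criterion, optimizer and scheduler) from the
registries in `core`, instantiates the selected parameter classes with their
default values, and saves the assembled Config as a JSON file under
`config/templates/train/` or `config/templates/predict/` next to this script.

Note: this summary assumes the script is run from a location where the `config`
and `core` packages are importable (e.g. the project root).
"""
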
import os

from pydantic import BaseModel
from typing import Any, Dict, Tuple, Type, Union, List

from config.config import Config
from config.dataset_config import DatasetConfig

from core import (
    ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
)


def instantiate_params(param: Any) -> Union[BaseModel, List[BaseModel]]:
    """
    Instantiates the parameter class(es) with default values.

    If 'param' is a tuple, instantiate each class and return a list of instances.
    Otherwise, instantiate the single class and return the instance.
    """
    if isinstance(param, tuple):
        return [cls() for cls in param]
    else:
        return param()


def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str:
    """
    Prompt the user with a list of options and return the selected option.
    """
    print(prompt_message)
    for i, option in enumerate(options, start=1):
        print(f"{i}. {option}")
    while True:
        try:
            choice = int(input("Enter your choice (number): "))
            if 1 <= choice <= len(options):
                return options[choice - 1]
            else:
                print("Invalid choice. Please try again.")
        except ValueError:
            print("Please enter a valid number.")


def main():
    # Determine the directory of this script.
    script_path = os.path.dirname(os.path.abspath(__file__))

    # Ask the user whether this is training mode.
    training_input = input("Is this training mode? (y/n): ").strip().lower()
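    # Any answer other than "y"/"yes" is treated as prediction mode.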
    is_training = training_input in ("y", "yes")

    # Create a default DatasetConfig based on the training mode.
    # The DatasetConfig.default_config method fills in required fields with zero-values.
    dataset_config = DatasetConfig(is_training=is_training)

    # Prompt the user to select a model.
    model_options = ModelRegistry.get_available_models()
    chosen_model = prompt_choice("\nSelect a model:", model_options)
    model_param_class = ModelRegistry.get_model_params(chosen_model)
    model_instance = instantiate_params(model_param_class)

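    # Prediction mode only needs the model parameters; training mode additionally
    # collects criterion, optimizer, and scheduler settings.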
    if not is_training:
        config = Config(
            model={chosen_model: model_instance},
            dataset_config=dataset_config
        )

        # Construct a base filename from the selected registry names.
        base_filename = f"{chosen_model}"

    else:
        # Prompt the user to select a criterion.
        criterion_options = CriterionRegistry.get_available_criterions()
        chosen_criterion = prompt_choice("\nSelect a criterion:", criterion_options)
        criterion_param_class = CriterionRegistry.get_criterion_params(chosen_criterion)
        criterion_instance = instantiate_params(criterion_param_class)

        # Prompt the user to select an optimizer.
        optimizer_options = OptimizerRegistry.get_available_optimizers()
        chosen_optimizer = prompt_choice("\nSelect an optimizer:", optimizer_options)
        optimizer_param_class = OptimizerRegistry.get_optimizer_params(chosen_optimizer)
        optimizer_instance = instantiate_params(optimizer_param_class)

        # Prompt the user to select a scheduler.
        scheduler_options = SchedulerRegistry.get_available_schedulers()
        chosen_scheduler = prompt_choice("\nSelect a scheduler:", scheduler_options)
        scheduler_param_class = SchedulerRegistry.get_scheduler_params(chosen_scheduler)
        scheduler_instance = instantiate_params(scheduler_param_class)

        # Assemble the overall configuration using the registry names as keys.
        config = Config(
            model={chosen_model: model_instance},
            dataset_config=dataset_config,
            criterion={chosen_criterion: criterion_instance},
            optimizer={chosen_optimizer: optimizer_instance},
            scheduler={chosen_scheduler: scheduler_instance}
        )

        # Construct a base filename from the selected registry names.
        base_filename = f"{chosen_model}_{chosen_criterion}_{chosen_optimizer}_{chosen_scheduler}"

    # Determine the output directory relative to this script.
    base_dir = os.path.join(script_path, "config/templates", "train" if is_training else "predict")
    os.makedirs(base_dir, exist_ok=True)

    filename = f"{base_filename}.json"
    full_path = os.path.join(base_dir, filename)
    counter = 1

    # Append a counter if a file with the same name exists.
    while os.path.exists(full_path):
        filename = f"{base_filename}_{counter}.json"
        full_path = os.path.join(base_dir, filename)
        counter += 1

    # Save the configuration as a JSON file.
    config.save_json(full_path)

    print(f"\nConfiguration saved to: {full_path}")


if __name__ == "__main__":
    main()