You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

108 lines
4.3 KiB

import os
from typing import Tuple
from config.config import *
from config.dataset_config import DatasetConfig
from core import (
ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
)
def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str:
    """
    Prompt the user with a numbered list of options and return the selected one.

    The options are printed with 1-based numbers, and the user is asked
    again until an integer within that range is entered.

    Args:
        prompt_message: Heading printed once, before the numbered options.
        options: Candidate values to choose from; must not be empty.

    Returns:
        The element of ``options`` whose number the user entered.

    Raises:
        ValueError: If ``options`` is empty. No input could ever be valid in
            that case, so the prompt would otherwise repeat forever.
    """
    if not options:
        raise ValueError("prompt_choice() requires at least one option.")

    print(prompt_message)
    for i, option in enumerate(options, start=1):
        print(f"{i}. {option}")

    while True:
        # Keep the try body limited to the conversion that can actually fail.
        try:
            choice = int(input("Enter your choice (number): "))
        except ValueError:
            print("Please enter a valid number.")
            continue

        if 1 <= choice <= len(options):
            # Displayed numbers are 1-based; the sequence is 0-based.
            return options[choice - 1]
        print("Invalid choice. Please try again.")
def main():
    """
    Interactively assemble a configuration template and save it as JSON.

    The user is first asked whether the template is for training and then
    picks a model. In training mode a criterion, an optimizer and a scheduler
    are picked as well. Each selected component is stored together with a
    parameter object obtained by calling, with no arguments, whatever the
    matching registry's ``get_*_params`` returns for that name.

    The file is written below ``config/templates/train`` or
    ``config/templates/predict`` next to this script. Its name is built from
    the selected component names; a numeric suffix is added when a file with
    that name already exists.
    """
    # Output paths are resolved relative to the location of this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Only "y" / "yes" (case-insensitive, surrounding whitespace ignored)
    # switch on training mode; any other answer means prediction mode.
    answer = input("Is this training mode? (y/n): ").strip().lower()
    is_training = answer in ("y", "yes")

    dataset_config = DatasetConfig(is_training=is_training)

    # The model is selected in both modes.
    chosen_model = prompt_choice(
        "\nSelect a model:", ModelRegistry.get_available_models()
    )
    model_params = ModelRegistry.get_model_params(chosen_model)()

    if is_training:
        # Training templates additionally need a criterion, an optimizer and
        # a scheduler, asked for in that order.
        chosen_criterion = prompt_choice(
            "\nSelect a criterion:", CriterionRegistry.get_available_criterions()
        )
        criterion_params = CriterionRegistry.get_criterion_params(chosen_criterion)()

        chosen_optimizer = prompt_choice(
            "\nSelect an optimizer:", OptimizerRegistry.get_available_optimizers()
        )
        optimizer_params = OptimizerRegistry.get_optimizer_params(chosen_optimizer)()

        chosen_scheduler = prompt_choice(
            "\nSelect a scheduler:", SchedulerRegistry.get_available_schedulers()
        )
        scheduler_params = SchedulerRegistry.get_scheduler_params(chosen_scheduler)()

        config = Config(
            model=ComponentConfig(name=chosen_model, params=model_params),
            dataset_config=dataset_config,
            criterion=ComponentConfig(name=chosen_criterion, params=criterion_params),
            optimizer=ComponentConfig(name=chosen_optimizer, params=optimizer_params),
            scheduler=ComponentConfig(name=chosen_scheduler, params=scheduler_params),
        )
        base_filename = f"{chosen_model}_{chosen_criterion}_{chosen_optimizer}_{chosen_scheduler}"
        mode_dir = "train"
    else:
        # Prediction templates carry only the model and the dataset settings.
        config = Config(
            model=ComponentConfig(name=chosen_model, params=model_params),
            dataset_config=dataset_config,
        )
        base_filename = f"{chosen_model}"
        mode_dir = "predict"

    base_dir = os.path.join(script_dir, "config/templates", mode_dir)
    os.makedirs(base_dir, exist_ok=True)

    # Start with the plain name; while it is taken, try _1, _2, ... in turn.
    full_path = os.path.join(base_dir, f"{base_filename}.json")
    suffix = 0
    while os.path.exists(full_path):
        suffix += 1
        full_path = os.path.join(base_dir, f"{base_filename}_{suffix}.json")

    config.save_json(full_path)
    print(f"\nConfiguration saved to: {full_path}")
# Run the interactive generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()