You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
4.3 KiB
108 lines
4.3 KiB
import os
|
|
from typing import Tuple
|
|
|
|
from config.config import *
|
|
from config.dataset_config import DatasetConfig
|
|
|
|
from core import (
|
|
ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry
|
|
)
|
|
|
|
|
|
def prompt_choice(prompt_message: str, options: Tuple[str, ...]) -> str:
    """
    Prompt the user with a numbered list of options and return the selected one.

    The options are printed once, numbered from 1. The user is then asked
    repeatedly for a number until the entry is an integer inside the valid
    range.

    Args:
        prompt_message: Heading printed above the numbered options.
        options: Candidate values; the chosen element is returned unchanged.

    Returns:
        The element of ``options`` matching the number the user entered.

    Raises:
        ValueError: If ``options`` is empty. No entry could ever be valid in
            that case, so the prompt would otherwise loop forever.
    """
    if not options:
        raise ValueError("prompt_choice requires at least one option.")

    print(prompt_message)
    for i, option in enumerate(options, start=1):
        print(f"{i}. {option}")

    while True:
        raw = input("Enter your choice (number): ")
        # Keep the try body limited to the conversion so that only a
        # malformed number is reported as such.
        try:
            choice = int(raw)
        except ValueError:
            print("Please enter a valid number.")
            continue

        if 1 <= choice <= len(options):
            # The menu is 1-based; the sequence is 0-based.
            return options[choice - 1]
        print("Invalid choice. Please try again.")
|
|
|
|
def main():
    """
    Interactively assemble a configuration template and save it to disk.

    The user is first asked whether the template is for training. A model is
    always selected; in training mode a criterion, an optimizer and a
    scheduler are selected as well. For every selection the matching
    parameter class is looked up in its registry and instantiated without
    arguments. The resulting ``Config`` is handed to ``config.save_json`` with
    a path under ``config/templates/train`` or ``config/templates/predict``
    next to this script. A numeric suffix is added to the file name when a
    file with the plain name already exists.
    """
    # Folder containing this script; the output directory is resolved against it.
    here = os.path.dirname(os.path.abspath(__file__))

    # Only "y" / "yes" (case-insensitive, surrounding whitespace ignored)
    # switch training mode on; any other answer means prediction mode.
    answer = input("Is this training mode? (y/n): ").strip().lower()
    is_training = answer in ("y", "yes")

    dataset_config = DatasetConfig(is_training=is_training)

    # A model is required in both modes.
    model_name = prompt_choice("\nSelect a model:", ModelRegistry.get_available_models())
    model_params = ModelRegistry.get_model_params(model_name)()

    if is_training:
        # Training templates additionally need a criterion, optimizer and scheduler.
        criterion_name = prompt_choice(
            "\nSelect a criterion:", CriterionRegistry.get_available_criterions()
        )
        criterion_params = CriterionRegistry.get_criterion_params(criterion_name)()

        optimizer_name = prompt_choice(
            "\nSelect an optimizer:", OptimizerRegistry.get_available_optimizers()
        )
        optimizer_params = OptimizerRegistry.get_optimizer_params(optimizer_name)()

        scheduler_name = prompt_choice(
            "\nSelect a scheduler:", SchedulerRegistry.get_available_schedulers()
        )
        scheduler_params = SchedulerRegistry.get_scheduler_params(scheduler_name)()

        config = Config(
            model=ComponentConfig(name=model_name, params=model_params),
            dataset_config=dataset_config,
            criterion=ComponentConfig(name=criterion_name, params=criterion_params),
            optimizer=ComponentConfig(name=optimizer_name, params=optimizer_params),
            scheduler=ComponentConfig(name=scheduler_name, params=scheduler_params),
        )
        # The file stem lists every selected component name.
        stem = f"{model_name}_{criterion_name}_{optimizer_name}_{scheduler_name}"
    else:
        # Prediction templates carry only the model and the dataset settings.
        config = Config(
            model=ComponentConfig(name=model_name, params=model_params),
            dataset_config=dataset_config,
        )
        stem = f"{model_name}"

    # Templates are grouped by mode below config/templates next to this script.
    mode_dir = "train" if is_training else "predict"
    out_dir = os.path.join(here, "config/templates", mode_dir)
    os.makedirs(out_dir, exist_ok=True)

    # Try the plain name first, then <stem>_1.json, <stem>_2.json, ... until
    # a path is found that does not exist yet.
    target = os.path.join(out_dir, f"{stem}.json")
    suffix = 0
    while os.path.exists(target):
        suffix += 1
        target = os.path.join(out_dir, f"{stem}_{suffix}.json")

    config.save_json(target)

    print(f"\nConfiguration saved to: {target}")
|
|
|
|
# Run the interactive generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|