You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

42 lines
1.3 KiB

import os
import wandb
from config.config import Config
from core.data import *
from core.segmentator import CellSegmentator
# Config used when no path is given on the command line. For inference, pass
# e.g. 'config/templates/predict/ModelV.json' instead.
DEFAULT_CONFIG_PATH = 'config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json'


def _parse_args(argv=None):
    """Parse the script's command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with a single ``config_path`` attribute that
        defaults to ``DEFAULT_CONFIG_PATH``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Run CellSegmentator from a JSON config template."
    )
    parser.add_argument(
        "config_path",
        nargs="?",
        default=DEFAULT_CONFIG_PATH,
        help="Path to a JSON config template (default: %(default)s).",
    )
    return parser.parse_args(argv)


def _checkpoint_path(weights_dir: str, predictions_dir: str) -> str:
    """Return ``<weights_dir>/<name>.pth``, where ``<name>`` is the last component of *predictions_dir*.

    *predictions_dir* is normalised first. A trailing separator such as
    ``"out/run1/"`` therefore still yields ``run1.pth`` rather than a bare
    ``.pth`` file.
    """
    run_name = os.path.basename(os.path.normpath(predictions_dir))
    return os.path.join(weights_dir, run_name + ".pth")


def main(argv=None):
    """Load a config, run the segmentator and save its checkpoint.

    Behaviour depends on ``config.dataset_config.wandb.use_wandb``:

    * If it is truthy, a Weights & Biases run is initialised with the config.
      The model is watched, the checkpoint is written into the W&B run
      directory and uploaded with ``wandb.save``, and the run is finished.
    * Otherwise the checkpoint is written to ``./weights``.

    Args:
        argv: Optional argument list (see ``_parse_args``); ``None`` reads
            ``sys.argv``.
    """
    args = _parse_args(argv)
    config = Config.load_json(args.config_path)
    use_wandb = config.dataset_config.wandb.use_wandb

    if use_wandb:
        # Initialize W&B
        wandb.init(config=config.asdict(), **config.dataset_config.wandb.asdict())
        # How many batches to wait before logging training status
        wandb.config.log_interval = 10

    segmentator = CellSegmentator(config)
    segmentator.create_dataloaders()

    if use_wandb:
        # Watch parameters & gradients of the model.
        # NOTE(review): reads the private `_model` attribute; switch to a
        # public accessor if CellSegmentator provides one.
        wandb.watch(segmentator._model, log="all", log_graph=True)

    segmentator.run()

    weights_dir = wandb.run.dir if use_wandb else "weights"  # type: ignore
    # The local "weights" directory (relative to the CWD) may not exist yet.
    os.makedirs(weights_dir, exist_ok=True)
    saving_path = _checkpoint_path(
        weights_dir, config.dataset_config.common.predictions_dir
    )
    segmentator.save_checkpoint(saving_path)

    if use_wandb:
        wandb.save(saving_path)
        # Flush pending uploads (including the checkpoint) and close the run.
        wandb.finish()


if __name__ == "__main__":
    main()