parent
5d984dc7a9
commit
4a501ea31a
@ -0,0 +1,41 @@
|
|||||||
|
import os
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
from config.config import Config
|
||||||
|
from core.data import *
|
||||||
|
from core.segmentator import CellSegmentator
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
config_path = 'config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json'
|
||||||
|
# config_path = 'config/templates/predict/ModelV.json'
|
||||||
|
|
||||||
|
config = Config.load_json(config_path)
|
||||||
|
# config = Config.load_json(config_path)
|
||||||
|
|
||||||
|
if config.dataset_config.wandb.use_wandb:
|
||||||
|
# Initialize W&B
|
||||||
|
wandb.init(config=config.asdict(), **config.dataset_config.wandb.asdict())
|
||||||
|
|
||||||
|
# How many batches to wait before logging training status
|
||||||
|
wandb.config.log_interval = 10
|
||||||
|
|
||||||
|
segmentator = CellSegmentator(config)
|
||||||
|
segmentator.create_dataloaders()
|
||||||
|
|
||||||
|
# Watch parameters & gradients of model
|
||||||
|
if config.dataset_config.wandb.use_wandb:
|
||||||
|
wandb.watch(segmentator._model, log="all", log_graph=True)
|
||||||
|
|
||||||
|
segmentator.run()
|
||||||
|
|
||||||
|
weights_dir = "weights" if not config.dataset_config.wandb.use_wandb else wandb.run.dir # type: ignore
|
||||||
|
saving_path = os.path.join(
|
||||||
|
weights_dir, os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
|
||||||
|
)
|
||||||
|
segmentator.save_checkpoint(saving_path)
|
||||||
|
|
||||||
|
if config.dataset_config.wandb.use_wandb:
|
||||||
|
wandb.save(saving_path)
|
||||||
|
|
||||||
|
|
@ -1,10 +0,0 @@
|
|||||||
from config.config import Config
|
|
||||||
from pprint import pprint
|
|
||||||
|
|
||||||
|
|
||||||
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
|
|
||||||
pprint(config, indent=4)
|
|
||||||
|
|
||||||
print('\n\n')
|
|
||||||
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV.json')
|
|
||||||
pprint(config, indent=4)
|
|
Loading…
Reference in new issue