add id field to wandb and remove dublicate code from train

master
laynholt 1 month ago
parent 7e4643eb84
commit f2456126ef

@ -11,6 +11,7 @@ class WandbConfig(BaseModel):
group: str | None = None # WandB group name
entity: str | None = None # WandB entity (user or team)
name: str | None = None # Name of the run
id: str | None = None # Id of the run
tags: list[str] | None = None # List of tags for the run
notes: str | None = None # Notes or description for the run
save_code: bool = True # Whether to save the code to WandB

@ -382,12 +382,7 @@ class CellSegmentator:
self._model.load_state_dict(self._best_weights)
if self._test_dataloader is not None:
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
self.__print_with_logging(test_metrics, 0)
save_path = self._dataset_setup.common.predictions_dir
os.makedirs(save_path, exist_ok=True)
self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
self.evaluate(save_results=save_results, only_masks=only_masks)
def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None:
@ -693,6 +688,8 @@ class CellSegmentator:
logger.info(f"├─ Entity: {self._wandb_config.entity}")
if self._wandb_config.name:
logger.info(f"├─ Run name: {self._wandb_config.name}")
if self._wandb_config.id:
logger.info(f"├─ Run id: {self._wandb_config.id}")
if self._wandb_config.tags:
logger.info(f"├─ Tags: {', '.join(self._wandb_config.tags)}")
if self._wandb_config.notes:

Loading…
Cancel
Save