diff --git a/config/config.py b/config/config.py
index 6f8b4e0..acd6ee3 100644
--- a/config/config.py
+++ b/config/config.py
@@ -37,27 +37,40 @@ class Config(BaseModel):
     optimizer: Optional[ComponentConfig] = None
     scheduler: Optional[ComponentConfig] = None
 
-    def save_json(self, file_path: str, indent: int = 4) -> None:
+    def asdict(self) -> Dict[str, Any]:
         """
-        Saves the configuration to a JSON file using dumps of each individual field.
-
-        Args:
-            file_path (str): Destination path for the JSON file.
-            indent (int): Indentation level for the JSON file.
+        Produce a JSON-serializable dict of this config, including nested
+        ComponentConfig and DatasetConfig entries. Useful for saving to file
+        or passing to experiment loggers (e.g. wandb.init(config=...)).
+
+        Returns:
+            A dict with keys 'model', 'dataset_config', and (if set)
+            'criterion', 'optimizer', 'scheduler'.
         """
-        config_dump = {
+        data: Dict[str, Any] = {
             "model": self.model.dump(),
-            "dataset_config": self.dataset_config.model_dump()
+            "dataset_config": self.dataset_config.model_dump(),
         }
         if self.criterion is not None:
-            config_dump["criterion"] = self.criterion.dump()
+            data["criterion"] = self.criterion.dump()
         if self.optimizer is not None:
-            config_dump["optimizer"] = self.optimizer.dump()
+            data["optimizer"] = self.optimizer.dump()
         if self.scheduler is not None:
-            config_dump["scheduler"] = self.scheduler.dump()
-
+            data["scheduler"] = self.scheduler.dump()
+        return data
+
+
+    def save_json(self, file_path: str, indent: int = 4) -> None:
+        """
+        Save this config to a JSON file.
+
+        Args:
+            file_path: Path to write the JSON file.
+            indent: JSON indent level.
+        """
+        config_dict = self.asdict()
         with open(file_path, "w", encoding="utf-8") as f:
-            f.write(json.dumps(config_dump, indent=indent))
+            f.write(json.dumps(config_dict, indent=indent))
 
 
     @classmethod
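
Note: a quick sketch of the dict shape asdict() produces, with nested
contents elided; 'criterion', 'optimizer' and 'scheduler' appear only when
set on the config (the template path is the one used in main.py below):

    >>> config = Config.load_json('config/templates/predict/ModelV.json')
    >>> config.asdict()
    {'model': {...},
     'dataset_config': {'is_training': ..., 'common': {...},
                        'testing': {...}, 'wandb': {...}}}
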
diff --git a/config/dataset_config.py b/config/dataset_config.py
index 6c62103..4ea53d2 100644
--- a/config/dataset_config.py
+++ b/config/dataset_config.py
@@ -214,6 +214,34 @@ class DatasetTestingConfig(BaseModel):
         return self
 
 
+class WandbConfig(BaseModel):
+    """
+    Configuration for Weights & Biases logging.
+    """
+    use_wandb: bool = False           # Whether to enable WandB logging
+    project: Optional[str] = None     # WandB project name
+    entity: Optional[str] = None      # WandB entity (user or team)
+    name: Optional[str] = None        # Name of the run
+    tags: Optional[list[str]] = None  # List of tags for the run
+    notes: Optional[str] = None       # Notes or description for the run
+    save_code: bool = True            # Whether to save the code to WandB
+
+    @model_validator(mode="after")
+    def validate_wandb(self) -> "WandbConfig":
+        if self.use_wandb:
+            if not self.project:
+                raise ValueError("When use_wandb=True, 'project' must be provided")
+            if not self.entity:
+                raise ValueError("When use_wandb=True, 'entity' must be provided")
+        return self
+
+    def asdict(self) -> Dict[str, Any]:
+        """
+        Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
+        """
+        return self.model_dump(exclude_none=True, exclude={"use_wandb"})
+
+
 class DatasetConfig(BaseModel):
     """
     Main dataset configuration that groups fields into nested models for a structured and readable JSON.
@@ -222,6 +250,7 @@ class DatasetConfig(BaseModel):
     common: DatasetCommonConfig = DatasetCommonConfig()
     training: DatasetTrainingConfig = DatasetTrainingConfig()
     testing: DatasetTestingConfig = DatasetTestingConfig()
+    wandb: WandbConfig = WandbConfig()
 
     @model_validator(mode="after")
     def validate_config(self) -> "DatasetConfig":
@@ -256,11 +285,13 @@ class DatasetConfig(BaseModel):
             return {
                 "is_training": self.is_training,
                 "common": self.common.model_dump(),
-                "training": self.training.model_dump() if self.training else {}
+                "training": self.training.model_dump() if self.training else {},
+                "wandb": self.wandb.model_dump()
             }
         else:
             return {
                 "is_training": self.is_training,
                 "common": self.common.model_dump(),
-                "testing": self.testing.model_dump() if self.testing else {}
+                "testing": self.testing.model_dump() if self.testing else {},
+                "wandb": self.wandb.model_dump()
             }
diff --git a/core/segmentator.py b/core/segmentator.py
index f84795f..142484a 100644
--- a/core/segmentator.py
+++ b/core/segmentator.py
@@ -592,6 +592,21 @@ class CellSegmentator:
             logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
             logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")
 
+        wandb_cfg = config.dataset_config.wandb
+        if wandb_cfg.use_wandb:
+            logger.info("[W&B]")
+            logger.info(f"├─ Project: {wandb_cfg.project}")
+            logger.info(f"├─ Entity: {wandb_cfg.entity}")
+            if wandb_cfg.name:
+                logger.info(f"├─ Run name: {wandb_cfg.name}")
+            if wandb_cfg.tags:
+                logger.info(f"├─ Tags: {', '.join(wandb_cfg.tags)}")
+            if wandb_cfg.notes:
+                logger.info(f"├─ Notes: {wandb_cfg.notes}")
+            logger.info(f"└─ Save code: {'yes' if wandb_cfg.save_code else 'no'}")
+        else:
+            logger.info("[W&B] Logging disabled")
+
         logger.info("===================================")
 
 
@@ -705,7 +720,8 @@ class CellSegmentator:
                 tablefmt="fancy_grid"
             )
             print(table, "\n")
-            wandb.log(results, step=step)
+            if self._dataset_setup.wandb.use_wandb:
+                wandb.log(results, step=step)
 
 
     def __run_epoch(self,
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..eea6098
--- /dev/null
+++ b/main.py
@@ -0,0 +1,40 @@
+import os
+import wandb
+
+from config.config import Config
+from core.data import *
+from core.segmentator import CellSegmentator
+
+
+if __name__ == "__main__":
+    config_path = 'config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json'
+    # config_path = 'config/templates/predict/ModelV.json'
+
+    config = Config.load_json(config_path)
+
+    if config.dataset_config.wandb.use_wandb:
+        # Initialize W&B
+        wandb.init(config=config.asdict(), **config.dataset_config.wandb.asdict())
+
+        # How many batches to wait before logging training status
+        wandb.config.log_interval = 10
+
+    segmentator = CellSegmentator(config)
+    segmentator.create_dataloaders()
+
+    # Watch parameters & gradients of model
+    if config.dataset_config.wandb.use_wandb:
+        wandb.watch(segmentator._model, log="all", log_graph=True)
+
+    segmentator.run()
+
+    weights_dir = "weights" if not config.dataset_config.wandb.use_wandb else wandb.run.dir  # type: ignore
+    saving_path = os.path.join(
+        weights_dir, os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
+    )
+    segmentator.save_checkpoint(saving_path)
+
+    if config.dataset_config.wandb.use_wandb:
+        wandb.save(saving_path)
+
+
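
Note: enabling logging from a JSON template then amounts to adding a "wandb"
block to dataset_config; the field names follow WandbConfig above, while the
project/entity/name/tags values here are placeholders, not a real setup:

    "wandb": {
        "use_wandb": true,
        "project": "model-v",
        "entity": "my-team",
        "name": "ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing",
        "tags": ["ModelV", "AdamW"],
        "save_code": true
    }

Per the validator, "project" and "entity" are mandatory once "use_wandb" is
true; the remaining fields may be omitted.
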
diff --git a/train.py b/train.py
deleted file mode 100644
index 2ce5b9c..0000000
--- a/train.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from config.config import Config
-from pprint import pprint
-
-
-config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
-pprint(config, indent=4)
-
-print('\n\n')
-config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV.json')
-pprint(config, indent=4)
\ No newline at end of file
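
Note: the config smoke test that train.py provided is still reproducible from
a REPL; a minimal sketch, reusing the repo-relative template path from main.py
and the new asdict() helper in place of pprint-ing the raw model:

    from pprint import pprint
    from config.config import Config

    config = Config.load_json('config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
    pprint(config.asdict(), indent=4)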