main script implemented

master
laynholt 2 months ago
parent 5d984dc7a9
commit 4a501ea31a

@@ -37,27 +37,40 @@ class Config(BaseModel):
     optimizer: Optional[ComponentConfig] = None
     scheduler: Optional[ComponentConfig] = None
 
-    def save_json(self, file_path: str, indent: int = 4) -> None:
+    def asdict(self) -> Dict[str, Any]:
         """
-        Saves the configuration to a JSON file using dumps of each individual field.
+        Produce a JSON-serializable dict of this config, including nested
+        ComponentConfig and DatasetConfig entries. Useful for saving to file
+        or passing to experiment loggers (e.g. wandb.init(config=...)).
 
-        Args:
-            file_path (str): Destination path for the JSON file.
-            indent (int): Indentation level for the JSON file.
+        Returns:
+            A dict with keys 'model', 'dataset_config', and (if set)
+            'criterion', 'optimizer', 'scheduler'.
         """
-        config_dump = {
+        data: Dict[str, Any] = {
             "model": self.model.dump(),
-            "dataset_config": self.dataset_config.model_dump()
+            "dataset_config": self.dataset_config.model_dump(),
         }
         if self.criterion is not None:
-            config_dump["criterion"] = self.criterion.dump()
+            data["criterion"] = self.criterion.dump()
         if self.optimizer is not None:
-            config_dump["optimizer"] = self.optimizer.dump()
+            data["optimizer"] = self.optimizer.dump()
         if self.scheduler is not None:
-            config_dump["scheduler"] = self.scheduler.dump()
+            data["scheduler"] = self.scheduler.dump()
+        return data
+
+    def save_json(self, file_path: str, indent: int = 4) -> None:
+        """
+        Save this config to a JSON file.
+
+        Args:
+            file_path: Path to write the JSON file.
+            indent: JSON indent level.
+        """
+        config_dict = self.asdict()
         with open(file_path, "w", encoding="utf-8") as f:
-            f.write(json.dumps(config_dump, indent=indent))
+            f.write(json.dumps(config_dict, indent=indent))
 
     @classmethod
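The refactor splits serialization into a reusable asdict() plus a thin save_json() wrapper, so the same dict can feed both json.dumps and wandb.init. A minimal usage sketch, assuming a hypothetical template path (Config.load_json is the loader used by the new main script below):

    from config.config import Config

    config = Config.load_json("config/templates/train/example.json")  # hypothetical template
    payload = config.asdict()            # plain dict: 'model', 'dataset_config', optional parts
    config.save_json("run_config.json")  # writes the same dict via json.dumps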

@@ -214,6 +214,34 @@ class DatasetTestingConfig(BaseModel):
         return self
 
 
+class WandbConfig(BaseModel):
+    """
+    Configuration for Weights & Biases logging.
+    """
+    use_wandb: bool = False           # Whether to enable WandB logging
+    project: Optional[str] = None     # WandB project name
+    entity: Optional[str] = None      # WandB entity (user or team)
+    name: Optional[str] = None        # Name of the run
+    tags: Optional[list[str]] = None  # List of tags for the run
+    notes: Optional[str] = None       # Notes or description for the run
+    save_code: bool = True            # Whether to save the code to WandB
+
+    @model_validator(mode="after")
+    def validate_wandb(self) -> "WandbConfig":
+        # mode="after" validators receive the model instance, hence 'self'
+        if self.use_wandb:
+            if not self.project:
+                raise ValueError("When use_wandb=True, 'project' must be provided")
+            if not self.entity:
+                raise ValueError("When use_wandb=True, 'entity' must be provided")
+        return self
+
+    def asdict(self) -> Dict[str, Any]:
+        """
+        Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
+        """
+        return self.model_dump(exclude_none=True, exclude={"use_wandb"})
+
+
 class DatasetConfig(BaseModel):
     """
     Main dataset configuration that groups fields into nested models for a structured and readable JSON.
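The after-mode validator makes a half-configured W&B setup fail fast at load time, while asdict() yields keyword arguments ready for wandb.init. A behavior sketch, with made-up project/entity values:

    from pydantic import ValidationError
    from config.config import WandbConfig

    cfg = WandbConfig(use_wandb=True, project="model-v", entity="my-team")
    print(cfg.asdict())  # {'project': 'model-v', 'entity': 'my-team', 'save_code': True}

    try:
        WandbConfig(use_wandb=True)  # missing project/entity
    except ValidationError as err:
        print(err)  # wraps: When use_wandb=True, 'project' must be provided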
@@ -222,6 +250,7 @@ class DatasetConfig(BaseModel):
     common: DatasetCommonConfig = DatasetCommonConfig()
     training: DatasetTrainingConfig = DatasetTrainingConfig()
     testing: DatasetTestingConfig = DatasetTestingConfig()
+    wandb: WandbConfig = WandbConfig()
 
     @model_validator(mode="after")
     def validate_config(self) -> "DatasetConfig":
@@ -256,11 +285,13 @@ class DatasetConfig(BaseModel):
             return {
                 "is_training": self.is_training,
                 "common": self.common.model_dump(),
-                "training": self.training.model_dump() if self.training else {}
+                "training": self.training.model_dump() if self.training else {},
+                "wandb": self.wandb.model_dump()
             }
         else:
             return {
                 "is_training": self.is_training,
                 "common": self.common.model_dump(),
-                "testing": self.testing.model_dump() if self.testing else {}
+                "testing": self.testing.model_dump() if self.testing else {},
+                "wandb": self.wandb.model_dump()
             }
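Since the new wandb field has defaults, every serialized dataset config now carries a wandb block in both branches. For an untouched WandbConfig, the training-mode dump looks roughly like this (common/training contents elided):

    {
        "is_training": True,
        "common": {...},    # DatasetCommonConfig fields
        "training": {...},  # DatasetTrainingConfig fields
        "wandb": {
            "use_wandb": False,
            "project": None,
            "entity": None,
            "name": None,
            "tags": None,
            "notes": None,
            "save_code": True
        }
    }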

@@ -592,6 +592,21 @@ class CellSegmentator:
             logger.info(f"  ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
             logger.info(f"  └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")
 
+        wandb_cfg = config.dataset_config.wandb
+        if wandb_cfg.use_wandb:
+            logger.info("[W&B]")
+            logger.info(f"├─ Project: {wandb_cfg.project}")
+            logger.info(f"├─ Entity: {wandb_cfg.entity}")
+            if wandb_cfg.name:
+                logger.info(f"├─ Run name: {wandb_cfg.name}")
+            if wandb_cfg.tags:
+                logger.info(f"├─ Tags: {', '.join(wandb_cfg.tags)}")
+            if wandb_cfg.notes:
+                logger.info(f"├─ Notes: {wandb_cfg.notes}")
+            logger.info(f"└─ Save code: {'yes' if wandb_cfg.save_code else 'no'}")
+        else:
+            logger.info("[W&B] Logging disabled")
+
         logger.info("===================================")
@@ -705,7 +720,8 @@ class CellSegmentator:
             tablefmt="fancy_grid"
         )
         print(table, "\n")
-        wandb.log(results, step=step)
+        if self._dataset_setup.wandb.use_wandb:
+            wandb.log(results, step=step)
 
     def __run_epoch(self,
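The same use_wandb guard will be needed at every other metrics call site; one way to keep it from spreading is a small helper on CellSegmentator (hypothetical, not part of this commit):

    def _log_wandb(self, metrics: dict, step: int) -> None:
        # Silently skip when W&B logging is disabled in the dataset config
        if self._dataset_setup.wandb.use_wandb:
            wandb.log(metrics, step=step)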

@@ -0,0 +1,41 @@
+import os
+
+import wandb
+
+from config.config import Config
+from core.data import *
+from core.segmentator import CellSegmentator
+
+
+if __name__ == "__main__":
+    config_path = 'config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json'
+    # config_path = 'config/templates/predict/ModelV.json'
+
+    config = Config.load_json(config_path)
+
+    if config.dataset_config.wandb.use_wandb:
+        # Initialize W&B
+        wandb.init(config=config.asdict(), **config.dataset_config.wandb.asdict())
+        # How many batches to wait before logging training status
+        wandb.config.log_interval = 10
+
+    segmentator = CellSegmentator(config)
+    segmentator.create_dataloaders()
+
+    # Watch parameters & gradients of model
+    if config.dataset_config.wandb.use_wandb:
+        wandb.watch(segmentator._model, log="all", log_graph=True)
+
+    segmentator.run()
+
+    weights_dir = "weights" if not config.dataset_config.wandb.use_wandb else wandb.run.dir  # type: ignore
+    saving_path = os.path.join(
+        weights_dir, os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
+    )
+    segmentator.save_checkpoint(saving_path)
+    if config.dataset_config.wandb.use_wandb:
+        wandb.save(saving_path)
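Storing the checkpoint under wandb.run.dir means W&B syncs the weights with the run's files, and the explicit wandb.save marks the file for upload. To enable tracking for an existing template without editing JSON by hand, the config objects can also be updated and re-saved in code; a sketch assuming the pydantic models are not frozen, with hypothetical project/entity values:

    config = Config.load_json(config_path)
    config.dataset_config.wandb.use_wandb = True
    config.dataset_config.wandb.project = "model-v"   # hypothetical project
    config.dataset_config.wandb.entity = "my-team"    # hypothetical entity
    config.save_json(config_path)  # persists the wandb block back into the template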

@ -1,10 +0,0 @@
from config.config import Config
from pprint import pprint
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
pprint(config, indent=4)
print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV.json')
pprint(config, indent=4)