From 662889c7a71646a91059506ed0b371495382f823 Mon Sep 17 00:00:00 2001 From: laynholt Date: Fri, 9 May 2025 23:33:24 +0000 Subject: [PATCH] add roi_size parameter to training --- config/dataset_config.py | 4 ++++ core/data/transforms/__init__.py | 9 ++++++--- main.py | 3 ++- uv.lock | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/config/dataset_config.py b/config/dataset_config.py index cda221d..db71de1 100644 --- a/config/dataset_config.py +++ b/config/dataset_config.py @@ -70,6 +70,7 @@ class DatasetTrainingConfig(BaseModel): test_offset: int = 0 # Offset for testing data batch_size: int = 1 # Batch size for training + roi_size: int = 512 # The size of the square window for cropping num_epochs: int = 100 # Number of training epochs val_freq: int = 1 # Frequency of validation during training @@ -123,11 +124,14 @@ class DatasetTrainingConfig(BaseModel): """ Validates numeric fields: - batch_size and num_epochs must be > 0. + - roi_size must be > 0. - val_freq must be >= 0. - offsets must be >= 0. """ if self.batch_size <= 0: raise ValueError("batch_size must be > 0") + if self.roi_size <= 0: + raise ValueError("roi_size must be > 0") if self.num_epochs <= 0: raise ValueError("num_epochs must be > 0") if self.val_freq < 0: diff --git a/core/data/transforms/__init__.py b/core/data/transforms/__init__.py index 74335b8..03e538a 100644 --- a/core/data/transforms/__init__.py +++ b/core/data/transforms/__init__.py @@ -13,7 +13,7 @@ __all__ = [ ] -def get_train_transforms(): +def get_train_transforms(roi_size: int = 512): """ Returns the transformation pipeline for training data. @@ -27,6 +27,9 @@ def get_train_transforms(): 7. Apply additional intensity perturbations (noise, contrast, smoothing, histogram shift, and sharpening). 8. Convert the data types to the desired format. + Args: + roi_size (int): + The size of the square window for cropping. (Default 512). Returns: Compose: The composed transformation pipeline for training. """ @@ -55,9 +58,9 @@ def get_train_transforms(): keep_size=False, ), # Pad spatial dimensions to ensure a size of 512. - SpatialPadd(keys=["image", "mask"], spatial_size=512), + SpatialPadd(keys=["image", "mask"], spatial_size=roi_size), # Randomly crop a region of interest of size 512. - RandSpatialCropd(keys=["image", "mask"], roi_size=512, random_size=False), + RandSpatialCropd(keys=["image", "mask"], roi_size=roi_size, random_size=False), # Randomly flip the image and label along an axis. RandAxisFlipd(keys=["image", "mask"], prob=0.5), # Randomly rotate the image and label by 90 degrees. diff --git a/main.py b/main.py index 3bf82ed..b1415f3 100644 --- a/main.py +++ b/main.py @@ -59,7 +59,8 @@ def main(): segmentator = CellSegmentator(config) segmentator.create_dataloaders( - train_transforms=get_train_transforms() if mode == "train" else None, + train_transforms=get_train_transforms( + roi_size=config.dataset_config.training.roi_size) if mode == "train" else None, valid_transforms=get_valid_transforms() if mode == "train" else None, test_transforms=get_test_transforms() if mode in ("train", "test") else None, predict_transforms=get_predict_transforms() if mode == "predict" else None diff --git a/uv.lock b/uv.lock index 98089bf..0b7f59e 100644 --- a/uv.lock +++ b/uv.lock @@ -607,7 +607,7 @@ wheels = [ [[package]] name = "model-v" -version = "0.1.0" +version = "1.0.0" source = { virtual = "." } dependencies = [ { name = "colorlog" },