diff --git a/README.md b/README.md index 9e391b0..d3a54bc 100644 --- a/README.md +++ b/README.md @@ -52,33 +52,36 @@ Your data directory must follow this hierarchy: ``` path_to_data_folder/ -├── images/ # Input images +├── images/ # Input images (any supported format) │ ├── img1.tif -│ ├── img2.tif -│ └── ... -└── masks/ # Ground-truth instance masks +│ ├── img2.png +│ └── … +└── masks/ # Ground-truth instance masks (any supported format) ├── mask1.tif - ├── mask2.tif - └── ... + ├── mask2.jpg + └── … ``` If your dataset contains multiple classes (e.g., class A and B) and you prefer not to duplicate images, you can organize masks into class-specific subdirectories: ``` path_to_data_folder/ -├── images/ -│ └── img1.tif +├── images/ # Input images (any supported format) +│ └── img1.bmp └── masks/ - ├── A/ # Masks for class A - │ ├── img1_mask.tif - │ └── ... - └── B/ # Masks for class B - ├── img1_mask.tif - └── ... + ├── A/ # Masks for class A (any supported format) + │ ├── img1_mask.png + │ └── … + └── B/ # Masks for class B (any supported format) + ├── img1_mask.jpeg + └── … ``` In this case, set the `masks_subdir` field in your dataset configuration to the name of the mask subdirectory (e.g., `"A"` or `"B"`). +**Supported file formats**: Image and mask files can have any of these extensions: +`tif`, `tiff`, `png`, `jpg`, `bmp`, `jpeg`. + **Mask format**: Instance masks should be provided for multi-label segmentation with channel-last ordering, i.e., each mask array must have shape `(H, W, C)`. --- diff --git a/config/dataset_config.py b/config/dataset_config.py index db71de1..0b1a646 100644 --- a/config/dataset_config.py +++ b/config/dataset_config.py @@ -10,6 +10,7 @@ class DatasetCommonConfig(BaseModel): seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations) device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda') use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP) + roi_size: int = 512 # The size of the square window for cropping masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars' predictions_dir: str = "." # Directory to save predictions pretrained_weights: str = "" # Path to pretrained weights @@ -21,6 +22,8 @@ class DatasetCommonConfig(BaseModel): """ if not self.device: raise ValueError("device must be provided and non-empty") + if self.roi_size <= 0: + raise ValueError("roi_size must be > 0") return self @@ -70,7 +73,6 @@ class DatasetTrainingConfig(BaseModel): test_offset: int = 0 # Offset for testing data batch_size: int = 1 # Batch size for training - roi_size: int = 512 # The size of the square window for cropping num_epochs: int = 100 # Number of training epochs val_freq: int = 1 # Frequency of validation during training @@ -124,14 +126,11 @@ class DatasetTrainingConfig(BaseModel): """ Validates numeric fields: - batch_size and num_epochs must be > 0. - - roi_size must be > 0. - val_freq must be >= 0. - offsets must be >= 0. """ if self.batch_size <= 0: raise ValueError("batch_size must be > 0") - if self.roi_size <= 0: - raise ValueError("roi_size must be > 0") if self.num_epochs <= 0: raise ValueError("num_epochs must be > 0") if self.val_freq < 0: diff --git a/core/data/transforms/load_image.py b/core/data/transforms/load_image.py index 2929131..2ff9f55 100644 --- a/core/data/transforms/load_image.py +++ b/core/data/transforms/load_image.py @@ -1,7 +1,7 @@ import numpy as np import tifffile as tif import skimage.io as io -from typing import List, Optional, Sequence, Type, Union +from typing import Final, List, Optional, Sequence, Type, Union from monai.utils.enums import PostFix from monai.utils.module import optional_import @@ -24,8 +24,11 @@ __all__ = [ "CustomLoadImaged", # Dictionary-based image loader "CustomLoadImageD", # Dictionary-based image loader "CustomLoadImageDict", # Dictionary-based image loader + "SUPPORTED_IMAGE_FORMATS" ] +SUPPORTED_IMAGE_FORMATS: Final[Sequence[str]] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"] + class CustomLoadImage(LoadImage): """ @@ -150,8 +153,7 @@ class UniversalImageReader(NumpyReader): Supported extensions: tif, tiff, png, jpg, bmp, jpeg. """ - suffixes: Sequence[str] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"] - return has_itk or is_supported_format(filename, suffixes) + return has_itk or is_supported_format(filename, SUPPORTED_IMAGE_FORMATS) def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ diff --git a/core/segmentator.py b/core/segmentator.py index a93f20b..f67ed5a 100644 --- a/core/segmentator.py +++ b/core/segmentator.py @@ -38,6 +38,7 @@ import csv import copy import time import tifffile as tiff +from itertools import chain from pprint import pformat from tabulate import tabulate @@ -56,6 +57,7 @@ from core.utils import ( compute_f1_score, compute_average_precision_score ) +from core.data.transforms.load_image import SUPPORTED_IMAGE_FORMATS from core.logger import get_logger @@ -749,8 +751,9 @@ class CellSegmentator: """ # Collect sorted list of image paths images = sorted( - glob.glob(os.path.join(images_dir, '*.tif')) + - glob.glob(os.path.join(images_dir, '*.tiff')) + chain(glob.glob( + os.path.join(images_dir, f'*.{ext}')) for ext in SUPPORTED_IMAGE_FORMATS + ) ) if not images: raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'") @@ -758,8 +761,9 @@ class CellSegmentator: if masks_dir is not None: # Collect and validate sorted list of mask paths masks = sorted( - glob.glob(os.path.join(masks_dir, '*.tif')) + - glob.glob(os.path.join(masks_dir, '*.tiff')) + chain(glob.glob( + os.path.join(masks_dir, f'*.{ext}')) for ext in SUPPORTED_IMAGE_FORMATS + ) ) if len(images) != len(masks): raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})") @@ -1031,7 +1035,7 @@ class CellSegmentator: # Use sliding window inference for non-training phases outputs = sliding_window_inference( inputs, - roi_size=512, + roi_size=self._dataset_setup.common.roi_size, sw_batch_size=4, predictor=self._model, padding_mode="constant", diff --git a/main.py b/main.py index b1415f3..5c2d091 100644 --- a/main.py +++ b/main.py @@ -60,7 +60,7 @@ def main(): segmentator = CellSegmentator(config) segmentator.create_dataloaders( train_transforms=get_train_transforms( - roi_size=config.dataset_config.training.roi_size) if mode == "train" else None, + roi_size=config.dataset_config.common.roi_size) if mode == "train" else None, valid_transforms=get_valid_transforms() if mode == "train" else None, test_transforms=get_test_transforms() if mode in ("train", "test") else None, predict_transforms=get_predict_transforms() if mode == "predict" else None