roi_size moved to common section; added multi image extentions support

master
laynholt 4 months ago
parent 662889c7a7
commit 40c21d5456

@ -52,33 +52,36 @@ Your data directory must follow this hierarchy:
``` ```
path_to_data_folder/ path_to_data_folder/
├── images/ # Input images ├── images/ # Input images (any supported format)
│ ├── img1.tif │ ├── img1.tif
│ ├── img2.tif │ ├── img2.png
│ └── ... │ └──
└── masks/ # Ground-truth instance masks └── masks/ # Ground-truth instance masks (any supported format)
├── mask1.tif ├── mask1.tif
├── mask2.tif ├── mask2.jpg
└── ... └──
``` ```
If your dataset contains multiple classes (e.g., class A and B) and you prefer not to duplicate images, you can organize masks into class-specific subdirectories: If your dataset contains multiple classes (e.g., class A and B) and you prefer not to duplicate images, you can organize masks into class-specific subdirectories:
``` ```
path_to_data_folder/ path_to_data_folder/
├── images/ ├── images/ # Input images (any supported format)
│ └── img1.tif │ └── img1.bmp
└── masks/ └── masks/
├── A/ # Masks for class A ├── A/ # Masks for class A (any supported format)
│ ├── img1_mask.tif │ ├── img1_mask.png
│ └── ... │ └──
└── B/ # Masks for class B └── B/ # Masks for class B (any supported format)
├── img1_mask.tif ├── img1_mask.jpeg
└── ... └──
``` ```
In this case, set the `masks_subdir` field in your dataset configuration to the name of the mask subdirectory (e.g., `"A"` or `"B"`). In this case, set the `masks_subdir` field in your dataset configuration to the name of the mask subdirectory (e.g., `"A"` or `"B"`).
**Supported file formats**: Image and mask files can have any of these extensions:
`tif`, `tiff`, `png`, `jpg`, `bmp`, `jpeg`.
**Mask format**: Instance masks should be provided for multi-label segmentation with channel-last ordering, i.e., each mask array must have shape `(H, W, C)`. **Mask format**: Instance masks should be provided for multi-label segmentation with channel-last ordering, i.e., each mask array must have shape `(H, W, C)`.
--- ---

@ -10,6 +10,7 @@ class DatasetCommonConfig(BaseModel):
seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations) seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda') device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP) use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
roi_size: int = 512 # The size of the square window for cropping
masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars' masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars'
predictions_dir: str = "." # Directory to save predictions predictions_dir: str = "." # Directory to save predictions
pretrained_weights: str = "" # Path to pretrained weights pretrained_weights: str = "" # Path to pretrained weights
@ -21,6 +22,8 @@ class DatasetCommonConfig(BaseModel):
""" """
if not self.device: if not self.device:
raise ValueError("device must be provided and non-empty") raise ValueError("device must be provided and non-empty")
if self.roi_size <= 0:
raise ValueError("roi_size must be > 0")
return self return self
@ -70,7 +73,6 @@ class DatasetTrainingConfig(BaseModel):
test_offset: int = 0 # Offset for testing data test_offset: int = 0 # Offset for testing data
batch_size: int = 1 # Batch size for training batch_size: int = 1 # Batch size for training
roi_size: int = 512 # The size of the square window for cropping
num_epochs: int = 100 # Number of training epochs num_epochs: int = 100 # Number of training epochs
val_freq: int = 1 # Frequency of validation during training val_freq: int = 1 # Frequency of validation during training
@ -124,14 +126,11 @@ class DatasetTrainingConfig(BaseModel):
""" """
Validates numeric fields: Validates numeric fields:
- batch_size and num_epochs must be > 0. - batch_size and num_epochs must be > 0.
- roi_size must be > 0.
- val_freq must be >= 0. - val_freq must be >= 0.
- offsets must be >= 0. - offsets must be >= 0.
""" """
if self.batch_size <= 0: if self.batch_size <= 0:
raise ValueError("batch_size must be > 0") raise ValueError("batch_size must be > 0")
if self.roi_size <= 0:
raise ValueError("roi_size must be > 0")
if self.num_epochs <= 0: if self.num_epochs <= 0:
raise ValueError("num_epochs must be > 0") raise ValueError("num_epochs must be > 0")
if self.val_freq < 0: if self.val_freq < 0:

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import tifffile as tif import tifffile as tif
import skimage.io as io import skimage.io as io
from typing import List, Optional, Sequence, Type, Union from typing import Final, List, Optional, Sequence, Type, Union
from monai.utils.enums import PostFix from monai.utils.enums import PostFix
from monai.utils.module import optional_import from monai.utils.module import optional_import
@ -24,8 +24,11 @@ __all__ = [
"CustomLoadImaged", # Dictionary-based image loader "CustomLoadImaged", # Dictionary-based image loader
"CustomLoadImageD", # Dictionary-based image loader "CustomLoadImageD", # Dictionary-based image loader
"CustomLoadImageDict", # Dictionary-based image loader "CustomLoadImageDict", # Dictionary-based image loader
"SUPPORTED_IMAGE_FORMATS"
] ]
SUPPORTED_IMAGE_FORMATS: Final[Sequence[str]] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"]
class CustomLoadImage(LoadImage): class CustomLoadImage(LoadImage):
""" """
@ -150,8 +153,7 @@ class UniversalImageReader(NumpyReader):
Supported extensions: tif, tiff, png, jpg, bmp, jpeg. Supported extensions: tif, tiff, png, jpg, bmp, jpeg.
""" """
suffixes: Sequence[str] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"] return has_itk or is_supported_format(filename, SUPPORTED_IMAGE_FORMATS)
return has_itk or is_supported_format(filename, suffixes)
def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs):
""" """

@ -38,6 +38,7 @@ import csv
import copy import copy
import time import time
import tifffile as tiff import tifffile as tiff
from itertools import chain
from pprint import pformat from pprint import pformat
from tabulate import tabulate from tabulate import tabulate
@ -56,6 +57,7 @@ from core.utils import (
compute_f1_score, compute_f1_score,
compute_average_precision_score compute_average_precision_score
) )
from core.data.transforms.load_image import SUPPORTED_IMAGE_FORMATS
from core.logger import get_logger from core.logger import get_logger
@ -749,8 +751,9 @@ class CellSegmentator:
""" """
# Collect sorted list of image paths # Collect sorted list of image paths
images = sorted( images = sorted(
glob.glob(os.path.join(images_dir, '*.tif')) + chain(glob.glob(
glob.glob(os.path.join(images_dir, '*.tiff')) os.path.join(images_dir, f'*.{ext}')) for ext in SUPPORTED_IMAGE_FORMATS
)
) )
if not images: if not images:
raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'") raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'")
@ -758,8 +761,9 @@ class CellSegmentator:
if masks_dir is not None: if masks_dir is not None:
# Collect and validate sorted list of mask paths # Collect and validate sorted list of mask paths
masks = sorted( masks = sorted(
glob.glob(os.path.join(masks_dir, '*.tif')) + chain(glob.glob(
glob.glob(os.path.join(masks_dir, '*.tiff')) os.path.join(masks_dir, f'*.{ext}')) for ext in SUPPORTED_IMAGE_FORMATS
)
) )
if len(images) != len(masks): if len(images) != len(masks):
raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})") raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})")
@ -1031,7 +1035,7 @@ class CellSegmentator:
# Use sliding window inference for non-training phases # Use sliding window inference for non-training phases
outputs = sliding_window_inference( outputs = sliding_window_inference(
inputs, inputs,
roi_size=512, roi_size=self._dataset_setup.common.roi_size,
sw_batch_size=4, sw_batch_size=4,
predictor=self._model, predictor=self._model,
padding_mode="constant", padding_mode="constant",

@ -60,7 +60,7 @@ def main():
segmentator = CellSegmentator(config) segmentator = CellSegmentator(config)
segmentator.create_dataloaders( segmentator.create_dataloaders(
train_transforms=get_train_transforms( train_transforms=get_train_transforms(
roi_size=config.dataset_config.training.roi_size) if mode == "train" else None, roi_size=config.dataset_config.common.roi_size) if mode == "train" else None,
valid_transforms=get_valid_transforms() if mode == "train" else None, valid_transforms=get_valid_transforms() if mode == "train" else None,
test_transforms=get_test_transforms() if mode in ("train", "test") else None, test_transforms=get_test_transforms() if mode in ("train", "test") else None,
predict_transforms=get_predict_transforms() if mode == "predict" else None predict_transforms=get_predict_transforms() if mode == "predict" else None

Loading…
Cancel
Save