fix: correct auto split sampling, fix non-TIF memory leak, update README

**Bug fixes:**

1. Fixed an error in dataset sampling during automatic splitting.
2. Fixed a memory leak when loading images in formats other than `.tif`.

**Changes:**

1. `shuffle` can now be used with both `pre_split` and `split` modes.
2. `offsets` can be specified as floating-point values and are resolved dynamically as a fraction of the dataset size.
3. `size` and `offsets` now support mixed formats (i.e., both `int` and `float`); see the config sketch below.
4. F1 and mAP metrics are now computed by default in `micro`, `macro`, and `per_class` variations.
5. Classes `BoundaryExclusion` and `IntensityDiversification` have been renamed to follow the MONAI naming style.
6. Minor updates to the README for clarity and consistency.
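For instance, the dataset section of a training config can now mix absolute (`int`) and fractional (`float`) values for sizes and offsets, and `shuffle` applies in either mode. A minimal sketch using the field names from the diff below, written as the Python-dict equivalent of the JSON fragment (the data path is hypothetical and unrelated keys are omitted):

```python
# Sketch of the training split options; "data/all" is a hypothetical path.
dataset_training = {
    "is_split": False,                      # split a single directory automatically
    "split": {"all_data_dir": "data/all"},  # location of the unsplit data
    "train_size": 0.7,                      # float -> fraction of all images
    "valid_size": 0.1,
    "test_size": 30,                        # int -> absolute number of images (mixed formats)
    "train_offset": 0,                      # int offset: skip this many samples
    "valid_offset": 0.05,                   # float offset: resolved as int(0.05 * num_images)
    "test_offset": 0,
    "shuffle": True,                        # now honored for both split and pre_split data
}
```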
master · laynholt · 3 weeks ago · parent f728464a52 · commit bf90288a95

@ -160,6 +160,8 @@ A brief overview of the key parameters you can adjust in your JSON config:
* `is_split` (bool): Whether your data is already split (`true`) or needs splitting (`false`, default).
* `split` / `pre_split`: Directories for data when pre-split or unsplit.
* `train_size`, `valid_size`, `test_size` (int/float): Size or ratio of your splits (e.g., `0.7`, `0.1`, `0.2`).
* `train_offset`, `valid_offset`, `test_offset` (int/float): Number (int) or fraction (float) of samples to skip before forming each split. When the data is not pre-split, the splits are taken in the order `train`, `valid`, `test` (default: `0`, `0`, `0`).
* `shuffle` (bool): Flag for shuffling data when creating samples (default: `false`).
* `batch_size` (int): Number of samples per training batch (default: `1`).
* `num_epochs` (int): Total training epochs (default: `100`).
* `val_freq` (int): Frequency (in epochs) to run validation (default: `1`).
@ -168,6 +170,7 @@ A brief overview of the key parameters you can adjust in your JSON config:
* `test_dir` (str): Directory containing test data (default: `"."`).
* `test_size` (int/float): Portion or count of data for testing (default: `1.0`).
* `test_offset` (int/float): Number (int) or fraction (float) of samples to skip before forming the test set (default: `0`).
* `shuffle` (bool): Shuffle test data before evaluation (default: `true`).
> **Batch size note:** Validation, testing, and prediction runs always use a batch size of `1`, regardless of the `batch_size` setting in the training configuration.
@ -190,12 +193,22 @@ python generate_config.py
python main.py -c config/templates/train/YourConfig.json -m train
```
> After training, testing runs automatically if a test data directory was specified in the configuration file.
### Test a model
```bash
python main.py -c config/templates/predict/YourConfig.json -m test
```
### Predict on new data
```bash
python main.py -c config/templates/predict/YourConfig.json -m predict
```
> Unlike testing, prediction does not require the specified test directory to contain a folder with ground-truth masks.
---
## Acknowledgments

@ -9,7 +9,7 @@ class DatasetCommonConfig(BaseModel):
"""
seed: int | None = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
use_amp: bool = True # Flag to use Automatic Mixed Precision (AMP)
roi_size: int = 512 # The size of the square window for cropping
remove_boundary_objects: bool = True # Flag to remove boundary objects when testing
masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars'
@ -34,7 +34,6 @@ class TrainingSplitInfo(BaseModel):
Contains:
- all_data_dir: Directory containing all data.
"""
shuffle: bool = True # Shuffle data before splitting
all_data_dir: str = "." # Directory containing all data if not pre-split
class TrainingPreSplitInfo(BaseModel):
@ -69,30 +68,31 @@ class DatasetTrainingConfig(BaseModel):
train_size: int | float = 0.7 # Training data size (int for static, float in (0,1] for dynamic)
valid_size: int | float = 0.1 # Validation data size (int for static, float in (0,1] for dynamic)
test_size: int | float = 0.2 # Testing data size (int for static, float in (0,1] for dynamic)
train_offset: int = 0 # Offset for training data
valid_offset: int = 0 # Offset for validation data
test_offset: int = 0 # Offset for testing data
train_offset: int | float = 0 # Offset for training data (int for static, float in (0,1] for dynamic)
valid_offset: int | float = 0 # Offset for validation data (int for static, float in (0,1] for dynamic)
test_offset: int | float = 0 # Offset for testing data (int for static, float in (0,1] for dynamic)
shuffle: bool = False # Shuffle data
batch_size: int = 1 # Batch size for training
num_epochs: int = 100 # Number of training epochs
val_freq: int = 1 # Frequency of validation during training
@field_validator("train_size", "valid_size", "test_size", mode="before")
def validate_sizes(cls, v: int | float) -> int | float:
@field_validator("train_size", "valid_size", "test_size", "train_offset", "valid_offset", "test_offset", mode="before")
def validate_sizes_and_offsets(cls, v: int | float) -> int | float:
"""
Validates size values:
Validates size and offset values:
- If provided as a float, must be in the range [0, 1].
- If provided as an int, must be non-negative.
"""
if isinstance(v, float):
if not (0 <= v <= 1):
raise ValueError("When provided as a float, size must be in the range (0, 1]")
raise ValueError("When provided as a float, size and offset must be in the range [0, 1]")
elif isinstance(v, int):
if v < 0:
raise ValueError("When provided as an int, size must be non-negative")
raise ValueError("When provided as an int, size and offset must be non-negative")
else:
raise ValueError("Size must be either an int or a float")
raise ValueError("Size and offset must be either an int or a float")
return v
@model_validator(mode="after")
@ -102,10 +102,6 @@ class DatasetTrainingConfig(BaseModel):
- If is_split is True, validates pre_split (train_dir must be non-empty and exist; if provided, valid_dir and test_dir must exist).
- If is_split is False, validates split (all_data_dir must be non-empty and exist).
"""
if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)):
if (self.train_size + self.valid_size + self.test_size) > 1 and not self.is_split:
raise ValueError("The total sample size with dynamically defined sizes must be <= 1")
if not self.is_split:
if not self.split.all_data_dir:
raise ValueError("When is_split is False, all_data_dir must be provided and non-empty in pre_split")
@ -128,7 +124,6 @@ class DatasetTrainingConfig(BaseModel):
Validates numeric fields:
- batch_size and num_epochs must be > 0.
- val_freq must be >= 0.
- offsets must be >= 0.
"""
if self.batch_size <= 0:
raise ValueError("batch_size must be > 0")
@ -136,8 +131,6 @@ class DatasetTrainingConfig(BaseModel):
raise ValueError("num_epochs must be > 0")
if self.val_freq < 0:
raise ValueError("val_freq must be >= 0")
if self.train_offset < 0 or self.valid_offset < 0 or self.test_offset < 0:
raise ValueError("offsets must be >= 0")
return self
@ -145,36 +138,26 @@ class DatasetTestingConfig(BaseModel):
"""
Configuration fields used only in testing mode.
"""
test_dir: str = "." # Test data directory; must be non-empty
test_size: int | float = 1.0 # Testing data size (int for static, float in (0,1] for dynamic)
test_offset: int = 0 # Offset for testing data
test_dir: str = "." # Test data directory; must be non-empty
test_size: int | float = 1.0 # Testing data size (int for static, float in (0,1] for dynamic)
test_offset: int | float = 0 # Offset for testing data
shuffle: bool = True # Shuffle data
@field_validator("test_size", mode="before")
def validate_test_size(cls, v: int | float) -> int | float:
@field_validator("test_size", "test_offset", mode="before")
def validate_test_size_and_offset(cls, v: int | float) -> int | float:
"""
Validates the test_size and test_offset values.
"""
if isinstance(v, float):
if not (0 < v <= 1):
raise ValueError("When provided as a float, test_size must be in the range (0, 1]")
raise ValueError("When provided as a float, test_size and test_offset must be in the range (0, 1]")
elif isinstance(v, int):
if v < 0:
raise ValueError("When provided as an int, test_size must be non-negative")
raise ValueError("When provided as an int, test_size and test_offset must be non-negative")
else:
raise ValueError("test_size must be either an int or a float")
raise ValueError("test_size and test_offset must be either an int or a float")
return v
@model_validator(mode="after")
def validate_numeric_fields(self) -> "DatasetTestingConfig":
"""
Validates numeric fields:
- test_offset must be >= 0.
"""
if self.test_offset < 0:
raise ValueError("test_offset must be >= 0")
return self
@model_validator(mode="after")
def validate_testing(self) -> "DatasetTestingConfig":
"""
@ -185,8 +168,6 @@ class DatasetTestingConfig(BaseModel):
raise ValueError("In testing configuration, test_dir must be provided and non-empty")
if not os.path.exists(self.test_dir):
raise ValueError(f"Path for test_dir does not exist: {self.test_dir}")
if self.test_offset < 0:
raise ValueError("test_offset must be >= 0")
return self

@ -1,4 +1,4 @@
from .cell_aware import IntensityDiversification
from .cell_aware import IntensityDiversificationd
from .load_image import CustomLoadImaged
from .normalize_image import CustomNormalizeImaged
@ -66,7 +66,7 @@ def get_train_transforms(roi_size: int = 512):
# Randomly rotate the image and label by 90 degrees.
RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=(0, 1)),
# Diversify intensities for selected cell regions.
IntensityDiversification(keys=["image", "mask"], allow_missing_keys=True),
IntensityDiversificationd(keys=["image", "mask"], allow_missing_keys=True),
# Apply random Gaussian noise to the image.
RandGaussianNoised(keys=["image"], prob=0.25, mean=0, std=0.1),
# Randomly adjust the contrast of the image.

@ -7,13 +7,20 @@ from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: i
from core.logger import get_logger
__all__ = ["BoundaryExclusion", "IntensityDiversification"]
__all__ = [
"BoundaryExclusiond",
"BoundaryExclusionD",
"BoundaryExclusionDict",
"IntensityDiversificationd",
"IntensityDiversificationD",
"IntensityDiversificationDict"
]
logger = get_logger(__name__)
class BoundaryExclusion(MapTransform):
class BoundaryExclusiond(MapTransform):
"""
Map the cell boundary pixel labels to the background class (0).
@ -87,7 +94,7 @@ class BoundaryExclusion(MapTransform):
return data
class IntensityDiversification(MapTransform):
class IntensityDiversificationd(MapTransform):
"""
Randomly rescale the intensity of cell pixels.
@ -202,3 +209,7 @@ class IntensityDiversification(MapTransform):
data["image"][c] = img_orig + img_changed
return data
BoundaryExclusionD = BoundaryExclusionDict = BoundaryExclusiond
IntensityDiversificationD = IntensityDiversificationDict = IntensityDiversificationd
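Since the renamed transforms follow MONAI's dictionary-transform convention, the capital-`D` and `Dict` spellings are plain aliases of the lowercase-`d` classes. A small usage sketch, reusing the call already shown in `get_train_transforms` (the alias identities follow directly from the assignments above):

```python
# Either spelling refers to the same class object:
assert IntensityDiversificationD is IntensityDiversificationd
assert BoundaryExclusionDict is BoundaryExclusiond

# Same call as in get_train_transforms, written via the MONAI-style alias:
diversify = IntensityDiversificationD(keys=["image", "mask"], allow_missing_keys=True)
```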

@ -1,10 +1,8 @@
import numpy as np
import tifffile as tif
import skimage.io as io
import imageio.v3 as iio
from typing import Final, Sequence, Type
from monai.utils.enums import PostFix
from monai.utils.module import optional_import
from monai.utils.misc import ensure_tuple, ensure_tuple_rep
from monai.data.utils import is_supported_format
from monai.data.image_reader import ImageReader, NumpyReader
@ -15,9 +13,6 @@ from monai.config.type_definitions import DtypeLike, PathLike, KeysCollection
# Default value for metadata postfix
DEFAULT_POST_FIX = PostFix.meta()
# Try to import ITK library; if not available, has_itk will be False
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
__all__ = [
"CustomLoadImage", # Basic image loader
@ -142,9 +137,7 @@ class UniversalImageReader(NumpyReader):
Universal image reader for TIFF, PNG, JPG, BMP, etc.
Uses:
- tifffile for reading TIFF files.
- ITK (if available) for reading other formats.
- skimage.io for reading if the previous methods fail.
- imageio.v3 for reading files.
The image is loaded with its original number of channels (layers) without forced modifications
(e.g., repeating or cropping channels).
@ -162,7 +155,7 @@ class UniversalImageReader(NumpyReader):
Supported extensions: tif, tiff, png, jpg, bmp, jpeg.
"""
return has_itk or is_supported_format(filename, SUPPORTED_IMAGE_FORMATS)
return is_supported_format(filename, SUPPORTED_IMAGE_FORMATS)
def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
"""
@ -186,17 +179,17 @@ class UniversalImageReader(NumpyReader):
for name in filenames:
# Convert file name to string
name = f"{name}"
# If the file has a .tif or .tiff extension (case-insensitive), use tifffile for reading
if name.lower().endswith((".tif", ".tiff")):
img_array = tif.imread(name)
else:
# Attempt to read the image using ITK (if available)
try:
img_itk = itk.imread(name, **kwargs_)
img_array = itk.array_view_from_image(img_itk, keep_axes=False)
except Exception:
# If ITK fails, use skimage.io for reading
img_array = io.imread(name)
img_array = iio.imread(name, **kwargs)
# copy only if needed: not contiguous, not writeable, or is a view
needs_copy = (
not img_array.flags["C_CONTIGUOUS"]
or not img_array.flags["WRITEABLE"]
or (img_array.base is not None) # likely a view on another object/buffer
)
if needs_copy:
# Ensure NumPy owns the buffer to avoid backend-held memory.
img_array = np.array(img_array, copy=True, order="C")
# Check the number of dimensions (axes) of the loaded image
if img_array.ndim == 2:

@ -127,11 +127,27 @@ class CellSegmentator:
valid_dir = self._dataset_setup.training.pre_split.valid_dir
test_dir = self._dataset_setup.training.pre_split.test_dir
train_offset = self._dataset_setup.training.train_offset
valid_offset = self._dataset_setup.training.valid_offset
test_offset = self._dataset_setup.training.test_offset
train_number_of_images = len(os.listdir(os.path.join(train_dir, 'images')))
valid_number_of_images = len(os.listdir(os.path.join(valid_dir, 'images')))
test_number_of_images = len(os.listdir(os.path.join(test_dir, 'images')))
train_offset = (
self._dataset_setup.training.train_offset
if isinstance(self._dataset_setup.training.train_offset, int)
else int(train_number_of_images * self._dataset_setup.training.train_offset)
)
valid_offset = (
self._dataset_setup.training.valid_offset
if isinstance(self._dataset_setup.training.valid_offset, int)
else int(valid_number_of_images * self._dataset_setup.training.valid_offset)
)
test_offset = (
self._dataset_setup.training.test_offset
if isinstance(self._dataset_setup.training.test_offset, int)
else int(test_number_of_images * self._dataset_setup.training.test_offset)
)
shuffle = False
shuffle = self._dataset_setup.training.shuffle
else:
# Same validation for split mode with full data directory
if (
@ -160,11 +176,23 @@ class CellSegmentator:
else int(number_of_images * self._dataset_setup.training.valid_size)
)
train_offset = self._dataset_setup.training.train_offset
valid_offset = self._dataset_setup.training.valid_offset + train_size
test_offset = self._dataset_setup.training.test_offset + train_size + valid_size
shuffle = True
train_offset = (
self._dataset_setup.training.train_offset
if isinstance(self._dataset_setup.training.train_offset, int)
else int(number_of_images * self._dataset_setup.training.train_offset)
)
valid_offset = (
self._dataset_setup.training.valid_offset
if isinstance(self._dataset_setup.training.valid_offset, int)
else int(number_of_images * self._dataset_setup.training.valid_offset)
) + train_size + train_offset
test_offset = (
self._dataset_setup.training.test_offset
if isinstance(self._dataset_setup.training.test_offset, int)
else int(number_of_images * self._dataset_setup.training.test_offset)
) + valid_offset + valid_size
shuffle = self._dataset_setup.training.shuffle
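# Worked example (sketch, assuming 100 images in all_data_dir, sizes 0.7/0.1/0.2
# and all offsets 0.0): train_size=70, valid_size=10, train_offset=0,
# valid_offset=0+70+0=70, test_offset=0+70+10=80, so the splits cover indices
# [0, 70), [70, 80) and [80, 100) of the (optionally shuffled) file list.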
# Train dataloader
train_dataset = self.__get_dataset(
@ -370,7 +398,7 @@ class CellSegmentator:
self.__print_with_logging(valid_metrics, epoch)
# Update best model on improved F1
f1 = valid_metrics.get("valid_f1_score", 0.0)
f1 = valid_metrics.get("valid_f1_score_micro", 0.0)
if f1 > best_f1_score:
best_f1_score = f1
# Deep copy weights to avoid reference issues
@ -426,7 +454,7 @@ class CellSegmentator:
):
# Disable gradient computation for inference
with torch.no_grad():
# Run the models forward pass in predict mode
# Run the model's forward pass in 'predict' mode
raw_output = self.__run_inference(inputs, mode="predict")
# Convert logits/probabilities to discrete instance masks
@ -451,7 +479,7 @@ class CellSegmentator:
- If training is enabled in the dataset setup, start training.
- Otherwise, if a test DataLoader is provided, run evaluation.
- Else if a prediction DataLoader is provided, run inference/prediction.
- If neither loader is available in nontraining mode, raise an error.
- If neither loader is available in non-training mode, raise an error.
Args:
save_results (bool): If True, the predicted masks and test metrics will be saved.
@ -481,7 +509,7 @@ class CellSegmentator:
else:
# 3) ERROR: no appropriate loader found
raise RuntimeError(
"Neither test nor predict DataLoader is set for nontraining mode."
"Neither test nor predict DataLoader is set for non-training mode."
)
elapsed = time.time() - start_time
@ -652,6 +680,7 @@ class CellSegmentator:
if config.dataset_config.is_training:
training = config.dataset_config.training
logger.info("[MODE] Training")
logger.info(f"├─ Shuffle: {'yes' if training.shuffle else 'no'}")
logger.info(f"├─ Batch size: {training.batch_size}")
logger.info(f"├─ Epochs: {training.num_epochs}")
logger.info(f"├─ Validation frequency: {training.val_freq}")
@ -664,7 +693,6 @@ class CellSegmentator:
else:
logger.info( "├─ Using unified dataset with splits:")
logger.info( "│ ├─ All data dir: {training.split.all_data_dir}")
logger.info(f"│ └─ Shuffle: {'yes' if training.split.shuffle else 'no'}")
logger.info( "└─ Dataset split:")
logger.info(f" ├─ Train size: {training.train_size}, offset: {training.train_offset}")
@ -780,6 +808,7 @@ class CellSegmentator:
# Shuffle image-mask pairs if requested
if shuffle:
self.__set_seed(self._dataset_setup.common.seed)
if masks_dir is not None:
combined = list(zip(images, masks)) # type: ignore
random.shuffle(combined)
@ -1001,15 +1030,27 @@ class CellSegmentator:
fp_array = np.vstack(all_fp)
fn_array = np.vstack(all_fn)
epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(
epoch_metrics[f"{mode}_f1_score_micro"] = self.__compute_f1_metric(
tp_array, fp_array, fn_array, reduction="micro"
)
epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric(
tp_array, fp_array, fn_array, reduction="imagewise"
)
epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(
epoch_metrics[f"{mode}_f1_score_pc"] = self.__compute_f1_metric(
tp_array, fp_array, fn_array, reduction="per_class"
)
epoch_metrics[f"{mode}_f1_score_macro"] = self.__compute_f1_metric(
tp_array, fp_array, fn_array, reduction="macro"
)
epoch_metrics[f"{mode}_mAP_micro"] = self.__compute_average_precision_metric(
tp_array, fp_array, fn_array, reduction="micro"
)
epoch_metrics[f"{mode}_mAP_macro"] = self.__compute_average_precision_metric(
tp_array, fp_array, fn_array, reduction="macro"
)
epoch_metrics[f"{mode}_mAP_pc"] = self.__compute_average_precision_metric(
tp_array, fp_array, fn_array, reduction="per_class"
)
return epoch_metrics
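For reference, the new reduction modes correspond to the standard aggregation schemes. A minimal sketch based on those standard definitions, not on the project's private `__compute_f1_metric` (whose `imagewise` variant and edge-case handling are not shown here):

```python
import numpy as np

def f1_from_counts(tp: np.ndarray, fp: np.ndarray, fn: np.ndarray,
                   reduction: str = "micro") -> float | np.ndarray:
    """F1 from per-image, per-class TP/FP/FN counts of shape (num_images, num_classes)."""
    if reduction == "micro":
        # Pool counts over all images and classes, then compute a single score.
        tp_s, fp_s, fn_s = tp.sum(), fp.sum(), fn.sum()
        return float(2 * tp_s / max(2 * tp_s + fp_s + fn_s, 1e-12))
    # Pool counts over images only, keeping one column per class.
    tp_c, fp_c, fn_c = tp.sum(axis=0), fp.sum(axis=0), fn.sum(axis=0)
    per_class = 2 * tp_c / np.maximum(2 * tp_c + fp_c + fn_c, 1e-12)
    if reduction == "per_class":
        return per_class
    if reduction == "macro":
        # Unweighted mean of the per-class scores.
        return float(per_class.mean())
    raise ValueError(f"Unknown reduction: {reduction}")
```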

@ -9,10 +9,12 @@ dependencies = [
"fastremap>=1.16.1",
"fill-voids>=2.0.8",
"imagecodecs>=2025.3.30",
"imageio>=2.37.0",
"matplotlib>=3.10.3",
"monai>=1.4.0",
"numba>=0.61.2",
"numpy>=1.26.4",
"pillow>=11.2.1",
"pydantic>=2.11.4",
"scikit-image>=0.25.2",
"scipy>=1.15.3",

uv.lock (1468): file diff suppressed because it is too large.