parent
33ce003657
commit
78f97a72a2
@ -0,0 +1,14 @@
|
|||||||
|
from .transforms import (
|
||||||
|
get_train_transforms,
|
||||||
|
get_valid_transforms,
|
||||||
|
get_test_transforms,
|
||||||
|
get_pred_transforms
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_train_transforms",
|
||||||
|
"get_valid_transforms",
|
||||||
|
"get_test_transforms",
|
||||||
|
"get_pred_transforms",
|
||||||
|
]
|
@ -0,0 +1,184 @@
|
|||||||
|
from .cell_aware import IntensityDiversification
|
||||||
|
from .load_image import CustomLoadImage, CustomLoadImaged
|
||||||
|
from .normalize_image import CustomNormalizeImage, CustomNormalizeImaged
|
||||||
|
|
||||||
|
from monai.transforms import * # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_train_transforms",
|
||||||
|
"get_valid_transforms",
|
||||||
|
"get_test_transforms",
|
||||||
|
"get_pred_transforms",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_transforms():
|
||||||
|
"""
|
||||||
|
Returns the transformation pipeline for training data.
|
||||||
|
|
||||||
|
The training pipeline applies a series of image and label preprocessing steps:
|
||||||
|
1. Load image and label data.
|
||||||
|
2. Normalize the image intensities.
|
||||||
|
3. Ensure the image and label have channel-first format.
|
||||||
|
4. Scale image intensities.
|
||||||
|
5. Apply spatial transformations (zoom, padding, cropping, flipping, and rotation).
|
||||||
|
6. Diversify intensities for selected cell regions.
|
||||||
|
7. Apply additional intensity perturbations (noise, contrast, smoothing, histogram shift, and sharpening).
|
||||||
|
8. Convert the data types to the desired format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compose: The composed transformation pipeline for training.
|
||||||
|
"""
|
||||||
|
train_transforms = Compose(
|
||||||
|
[
|
||||||
|
# Load image and label data in (H, W, C) format (image loaded as image-only).
|
||||||
|
CustomLoadImaged(keys=["img", "label"], image_only=True),
|
||||||
|
# Normalize the (H, W, C) image using the specified percentiles.
|
||||||
|
CustomNormalizeImaged(
|
||||||
|
keys=["img"],
|
||||||
|
allow_missing_keys=True,
|
||||||
|
channel_wise=False,
|
||||||
|
percentiles=[0.0, 99.5],
|
||||||
|
),
|
||||||
|
# Ensure both image and label are in channel-first format.
|
||||||
|
EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
|
||||||
|
# Scale image intensities (do not scale the label).
|
||||||
|
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
|
||||||
|
# Apply random zoom to both image and label.
|
||||||
|
RandZoomd(
|
||||||
|
keys=["img", "label"],
|
||||||
|
prob=0.5,
|
||||||
|
min_zoom=0.25,
|
||||||
|
max_zoom=1.5,
|
||||||
|
mode=["area", "nearest"],
|
||||||
|
keep_size=False,
|
||||||
|
),
|
||||||
|
# Pad spatial dimensions to ensure a size of 512.
|
||||||
|
SpatialPadd(keys=["img", "label"], spatial_size=512),
|
||||||
|
# Randomly crop a region of interest of size 512.
|
||||||
|
RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
|
||||||
|
# Randomly flip the image and label along an axis.
|
||||||
|
RandAxisFlipd(keys=["img", "label"], prob=0.5),
|
||||||
|
# Randomly rotate the image and label by 90 degrees.
|
||||||
|
RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=(0, 1)),
|
||||||
|
# Diversify intensities for selected cell regions.
|
||||||
|
IntensityDiversification(keys=["img", "label"], allow_missing_keys=True),
|
||||||
|
# Apply random Gaussian noise to the image.
|
||||||
|
RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
|
||||||
|
# Randomly adjust the contrast of the image.
|
||||||
|
RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
|
||||||
|
# Apply random Gaussian smoothing to the image.
|
||||||
|
RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
|
||||||
|
# Randomly shift the histogram of the image.
|
||||||
|
RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
|
||||||
|
# Apply random Gaussian sharpening to the image.
|
||||||
|
RandGaussianSharpend(keys=["img"], prob=0.25),
|
||||||
|
# Ensure that the data types are correct.
|
||||||
|
EnsureTyped(keys=["img", "label"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return train_transforms
|
||||||
|
|
||||||
|
|
||||||
|
def get_valid_transforms():
|
||||||
|
"""
|
||||||
|
Returns the transformation pipeline for validation data.
|
||||||
|
|
||||||
|
The validation pipeline includes the following steps:
|
||||||
|
1. Load image and label data (with missing keys allowed).
|
||||||
|
2. Normalize the image intensities.
|
||||||
|
3. Ensure the image and label are in channel-first format.
|
||||||
|
4. Scale image intensities.
|
||||||
|
5. Convert the data types to the desired format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compose: The composed transformation pipeline for validation.
|
||||||
|
"""
|
||||||
|
valid_transforms = Compose(
|
||||||
|
[
|
||||||
|
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
|
||||||
|
CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
|
||||||
|
# Normalize the (H, W, C) image using the specified percentiles.
|
||||||
|
CustomNormalizeImaged(
|
||||||
|
keys=["img"],
|
||||||
|
allow_missing_keys=True,
|
||||||
|
channel_wise=False,
|
||||||
|
percentiles=[0.0, 99.5],
|
||||||
|
),
|
||||||
|
# Ensure both image and label are in channel-first format.
|
||||||
|
EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
|
||||||
|
# Scale image intensities.
|
||||||
|
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
|
||||||
|
# Ensure that the data types are correct.
|
||||||
|
EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return valid_transforms
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_transforms():
|
||||||
|
"""
|
||||||
|
Returns the transformation pipeline for test data.
|
||||||
|
|
||||||
|
The test pipeline is similar to the validation pipeline and includes:
|
||||||
|
1. Load image and label data (with missing keys allowed).
|
||||||
|
2. Normalize the image intensities.
|
||||||
|
3. Ensure the image and label are in channel-first format.
|
||||||
|
4. Scale image intensities.
|
||||||
|
5. Convert the data types to the desired format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compose: The composed transformation pipeline for testing.
|
||||||
|
"""
|
||||||
|
test_transforms = Compose(
|
||||||
|
[
|
||||||
|
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
|
||||||
|
CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
|
||||||
|
# Normalize the (H, W, C) image using the specified percentiles.
|
||||||
|
CustomNormalizeImaged(
|
||||||
|
keys=["img"],
|
||||||
|
allow_missing_keys=True,
|
||||||
|
channel_wise=False,
|
||||||
|
percentiles=[0.0, 99.5],
|
||||||
|
),
|
||||||
|
# Ensure both image and label are in channel-first format.
|
||||||
|
EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
|
||||||
|
# Scale image intensities.
|
||||||
|
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
|
||||||
|
# Ensure that the data types are correct.
|
||||||
|
EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return test_transforms
|
||||||
|
|
||||||
|
|
||||||
|
def get_pred_transforms():
|
||||||
|
"""
|
||||||
|
Returns the transformation pipeline for prediction preprocessing.
|
||||||
|
|
||||||
|
The prediction pipeline includes the following steps:
|
||||||
|
1. Load the image data.
|
||||||
|
2. Normalize the image intensities.
|
||||||
|
3. Ensure the image is in channel-first format.
|
||||||
|
4. Scale image intensities.
|
||||||
|
5. Convert the image to the required tensor type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compose: The composed transformation pipeline for prediction.
|
||||||
|
"""
|
||||||
|
pred_transforms = Compose(
|
||||||
|
[
|
||||||
|
# Load the image data in (H, W, C) format (image loaded as image-only).
|
||||||
|
CustomLoadImage(image_only=True),
|
||||||
|
# Normalize the (H, W, C) image using the specified percentiles.
|
||||||
|
CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
|
||||||
|
# Ensure the image is in channel-first format.
|
||||||
|
EnsureChannelFirst(channel_dim=-1), # image shape: (C, H, W)
|
||||||
|
# Scale image intensities.
|
||||||
|
ScaleIntensity(),
|
||||||
|
# Convert the image to the required tensor type.
|
||||||
|
EnsureType(data_type="tensor"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return pred_transforms
|
@ -0,0 +1,192 @@
|
|||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, Sequence, Tuple, Union
|
||||||
|
from skimage.segmentation import find_boundaries
|
||||||
|
from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["BoundaryExclusion", "IntensityDiversification"]
|
||||||
|
|
||||||
|
|
||||||
|
class BoundaryExclusion(MapTransform):
|
||||||
|
"""
|
||||||
|
Map the cell boundary pixel labels to the background class (0).
|
||||||
|
|
||||||
|
This transform processes a label image by first detecting boundaries of cell regions
|
||||||
|
and then excluding those boundary pixels by setting them to 0. However, it retains
|
||||||
|
the original cell label if the cell is too small (less than 14x14 pixels) or if the cell
|
||||||
|
touches the image boundary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, keys: Sequence[str] = ("label",), allow_missing_keys: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
keys (Sequence[str]): Keys in the input dictionary corresponding to the label image.
|
||||||
|
Default is ("label",).
|
||||||
|
allow_missing_keys (bool): If True, missing keys in the input will be ignored.
|
||||||
|
Default is False.
|
||||||
|
"""
|
||||||
|
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
|
||||||
|
|
||||||
|
def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Apply the boundary exclusion transform to the label image.
|
||||||
|
|
||||||
|
The process involves:
|
||||||
|
1. Deep-copying the original label.
|
||||||
|
2. Finding boundaries using a thick mode with connectivity=1.
|
||||||
|
3. Setting the boundary pixels to background (0).
|
||||||
|
4. Restoring original labels for cells that are too small (< 14x14 pixels).
|
||||||
|
5. Ensuring that cells touching the image boundary are not excluded.
|
||||||
|
6. Assigning the transformed label back into the input dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (Dict[str, np.ndarray]): Dictionary containing at least the "label" key with a label image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, np.ndarray]: The input dictionary with the "label" key updated after boundary exclusion.
|
||||||
|
"""
|
||||||
|
# Retrieve the original label image.
|
||||||
|
label_original: np.ndarray = data["label"]
|
||||||
|
# Create a deep copy of the original label for processing.
|
||||||
|
label: np.ndarray = copy.deepcopy(label_original)
|
||||||
|
# Detect cell boundaries with a thick boundary.
|
||||||
|
boundary: np.ndarray = find_boundaries(label, connectivity=1, mode="thick")
|
||||||
|
# Exclude boundary pixels by setting them to 0.
|
||||||
|
label[boundary] = 0
|
||||||
|
|
||||||
|
# Create a new label copy for selective exclusion.
|
||||||
|
new_label: np.ndarray = copy.deepcopy(label_original)
|
||||||
|
new_label[label == 0] = 0
|
||||||
|
|
||||||
|
# Obtain unique cell indices and their pixel counts.
|
||||||
|
cell_idx, cell_counts = np.unique(label_original, return_counts=True)
|
||||||
|
|
||||||
|
# If a cell is too small (< 196 pixels, approx. 14x14), restore its original label.
|
||||||
|
for k in range(len(cell_counts)):
|
||||||
|
if cell_counts[k] < 196:
|
||||||
|
new_label[label_original == cell_idx[k]] = cell_idx[k]
|
||||||
|
|
||||||
|
# Ensure that cells at the image boundaries are not excluded.
|
||||||
|
# Get the dimensions of the label image.
|
||||||
|
H, W, _ = label_original.shape
|
||||||
|
# Create a binary mask with a border of 2 pixels preserved.
|
||||||
|
bd: np.ndarray = np.zeros_like(label_original, dtype=label.dtype)
|
||||||
|
bd[2 : H - 2, 2 : W - 2, :] = 1
|
||||||
|
# Combine the preserved boundaries with the new label.
|
||||||
|
new_label += label_original * bd
|
||||||
|
|
||||||
|
# Update the input dictionary with the transformed label.
|
||||||
|
data["label"] = new_label
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class IntensityDiversification(MapTransform):
|
||||||
|
"""
|
||||||
|
Randomly rescale the intensity of cell pixels.
|
||||||
|
|
||||||
|
This transform selects a subset of cells (based on the change_cell_ratio) and
|
||||||
|
applies a random intensity scaling to those cells. The intensity scaling is performed
|
||||||
|
using the RandScaleIntensity transform from MONAI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
keys: Sequence[str] = ("img",),
|
||||||
|
change_cell_ratio: float = 0.4,
|
||||||
|
scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7),
|
||||||
|
allow_missing_keys: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
keys (Sequence[str]): Keys in the input dictionary corresponding to the image.
|
||||||
|
Default is ("img",).
|
||||||
|
change_cell_ratio (float): Ratio of cells to apply the intensity scaling.
|
||||||
|
For example, 0.4 means 40% of the cells will be transformed.
|
||||||
|
Default is 0.4.
|
||||||
|
scale_factors (Sequence[float]): Factors used for random intensity scaling.
|
||||||
|
Default is (0.0, 0.7).
|
||||||
|
allow_missing_keys (bool): If True, missing keys in the input will be ignored.
|
||||||
|
Default is False.
|
||||||
|
"""
|
||||||
|
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
|
||||||
|
self.change_cell_ratio: float = change_cell_ratio
|
||||||
|
# Compose a random intensity scaling transform with 100% probability.
|
||||||
|
self.randscale_intensity = Compose([RandScaleIntensity(prob=1.0, factors=scale_factors)])
|
||||||
|
|
||||||
|
def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Apply a cell-wise intensity diversification transform to an input image.
|
||||||
|
|
||||||
|
This function modifies the image by randomly selecting a subset of labeled cell regions
|
||||||
|
(per channel) and applying a random intensity scaling operation exclusively to those regions.
|
||||||
|
The transformation is performed independently on each channel of the image.
|
||||||
|
|
||||||
|
The steps are as follows:
|
||||||
|
1. Extract the label image for all channels (expected shape: (C, H, W)).
|
||||||
|
2. For each channel, determine the unique cell IDs, excluding the background (labeled as 0).
|
||||||
|
3. Raise a ValueError if no unique objects are found in the current label channel.
|
||||||
|
4. Compute the number of cells to modify based on the provided change_cell_ratio.
|
||||||
|
5. Randomly select the corresponding cell IDs for intensity modification.
|
||||||
|
6. Create a binary mask that highlights the selected cell regions.
|
||||||
|
7. Separate the image channel into two parts: one that remains unchanged and one that is
|
||||||
|
subjected to random intensity scaling.
|
||||||
|
8. Apply the random intensity scaling to the selected regions.
|
||||||
|
9. Combine the unchanged and modified parts to update the image for that channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (Dict[str, np.ndarray]): A dictionary containing:
|
||||||
|
- "img": The original image array.
|
||||||
|
- "label": The corresponding cell label image array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, np.ndarray]: The updated dictionary with the "img" key modified after applying
|
||||||
|
the intensity transformation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no unique cell objects are found in a label channel.
|
||||||
|
"""
|
||||||
|
# Extract the label information for all channels.
|
||||||
|
# The label array has dimensions (C, H, W), where C is the number of channels.
|
||||||
|
label = data["label"] # shape: (C, H, W)
|
||||||
|
|
||||||
|
# Process each channel independently.
|
||||||
|
for c in range(label.shape[0]):
|
||||||
|
# Extract the label and corresponding image channel for the current channel.
|
||||||
|
channel_label = label[c]
|
||||||
|
img_channel = data["img"][c]
|
||||||
|
|
||||||
|
# Retrieve all unique cell IDs in the current channel.
|
||||||
|
# Exclude the background (0) from these IDs.
|
||||||
|
cell_ids = np.unique(channel_label)
|
||||||
|
cell_ids = cell_ids[cell_ids > 0]
|
||||||
|
|
||||||
|
# If there are no unique cell objects in this channel, raise an exception.
|
||||||
|
if cell_ids.size == 0:
|
||||||
|
raise ValueError(f"No unique objects found in the label mask for channel {c}")
|
||||||
|
|
||||||
|
# Determine the number of cells to modify using the change_cell_ratio.
|
||||||
|
change_count = int(len(cell_ids) * self.change_cell_ratio)
|
||||||
|
|
||||||
|
# Randomly select a subset of cell IDs for intensity modification.
|
||||||
|
selected = np.random.choice(cell_ids, change_count, replace=False)
|
||||||
|
|
||||||
|
# Create a binary mask for the current channel:
|
||||||
|
# - Pixels corresponding to the selected cell IDs are set to 1.
|
||||||
|
# - All other pixels are set to 0.
|
||||||
|
mask = np.isin(channel_label, selected).astype(np.float32)
|
||||||
|
|
||||||
|
# Separate the image channel into two components:
|
||||||
|
# 1. img_orig: The portion of the image that remains unchanged.
|
||||||
|
# 2. img_changed: The portion that will have its intensity altered.
|
||||||
|
img_orig = (1 - mask) * img_channel
|
||||||
|
img_changed = mask * img_channel
|
||||||
|
|
||||||
|
# Apply a random intensity scaling transformation to the selected regions.
|
||||||
|
img_changed = self.randscale_intensity(img_changed)
|
||||||
|
|
||||||
|
# Combine the unchanged and modified parts to update the image channel.
|
||||||
|
data["img"][c] = img_orig + img_changed
|
||||||
|
|
||||||
|
return data
|
@ -0,0 +1,202 @@
|
|||||||
|
import numpy as np
|
||||||
|
import tifffile as tif
|
||||||
|
import skimage.io as io
|
||||||
|
from typing import List, Optional, Sequence, Type, Union
|
||||||
|
|
||||||
|
from monai.utils.enums import PostFix
|
||||||
|
from monai.utils.module import optional_import
|
||||||
|
from monai.utils.misc import ensure_tuple, ensure_tuple_rep
|
||||||
|
from monai.data.utils import is_supported_format
|
||||||
|
from monai.data.image_reader import ImageReader, NumpyReader
|
||||||
|
from monai.transforms import LoadImage, LoadImaged # type: ignore
|
||||||
|
from monai.config.type_definitions import DtypeLike, PathLike, KeysCollection
|
||||||
|
|
||||||
|
|
||||||
|
# Default value for metadata postfix
|
||||||
|
DEFAULT_POST_FIX = PostFix.meta()
|
||||||
|
|
||||||
|
# Try to import ITK library; if not available, has_itk will be False
|
||||||
|
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CustomLoadImage", # Basic image loader
|
||||||
|
"CustomLoadImaged", # Dictionary-based image loader
|
||||||
|
"CustomLoadImageD", # Dictionary-based image loader
|
||||||
|
"CustomLoadImageDict", # Dictionary-based image loader
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLoadImage(LoadImage):
|
||||||
|
"""
|
||||||
|
Class for loading one or multiple images from a given path.
|
||||||
|
|
||||||
|
If a reader is not specified, the appropriate file reading method is automatically chosen
|
||||||
|
based on the file extension. Priority:
|
||||||
|
- Reader passed by the user at runtime.
|
||||||
|
- Reader specified in the constructor.
|
||||||
|
- Registered readers (from last to first).
|
||||||
|
- Standard readers for different formats (e.g., NibabelReader for nii, PILReader for png/jpg, etc.).
|
||||||
|
|
||||||
|
[Note] Here, the original ITKReader is replaced by the universal reader UniversalImageReader.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reader: Optional[Union[ImageReader, Type[ImageReader], str]] = None,
|
||||||
|
image_only: bool = False,
|
||||||
|
dtype: DtypeLike = np.float32,
|
||||||
|
ensure_channel_first: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
reader=reader,
|
||||||
|
image_only=image_only,
|
||||||
|
dtype=dtype,
|
||||||
|
ensure_channel_first=ensure_channel_first,
|
||||||
|
*args, **kwargs
|
||||||
|
)
|
||||||
|
# Clear the list of registered readers
|
||||||
|
self.readers = []
|
||||||
|
# Register the universal reader that handles TIFF, PNG, JPG, BMP, etc.
|
||||||
|
self.register(UniversalImageReader(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLoadImaged(LoadImaged):
|
||||||
|
"""
|
||||||
|
Dictionary-based image loader.
|
||||||
|
|
||||||
|
Wraps image loading with CustomLoadImage and allows processing of data represented as a dictionary,
|
||||||
|
where keys point to file paths.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
keys: KeysCollection,
|
||||||
|
reader: Optional[Union[Type[ImageReader], str]] = None,
|
||||||
|
dtype: DtypeLike = np.float32,
|
||||||
|
meta_keys: Optional[KeysCollection] = None,
|
||||||
|
meta_key_postfix: str = DEFAULT_POST_FIX,
|
||||||
|
overwriting: bool = False,
|
||||||
|
image_only: bool = False,
|
||||||
|
ensure_channel_first: bool = False,
|
||||||
|
simple_keys: bool = False,
|
||||||
|
allow_missing_keys: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
keys=keys,
|
||||||
|
reader=reader,
|
||||||
|
dtype=dtype,
|
||||||
|
meta_keys=meta_keys,
|
||||||
|
meta_key_postfix=meta_key_postfix,
|
||||||
|
overwriting=overwriting,
|
||||||
|
image_only=image_only,
|
||||||
|
ensure_channel_first=ensure_channel_first,
|
||||||
|
simple_keys=simple_keys,
|
||||||
|
allow_missing_keys=allow_missing_keys,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# Assign the custom image loader
|
||||||
|
self._loader = CustomLoadImage(
|
||||||
|
reader=reader,
|
||||||
|
image_only=image_only,
|
||||||
|
dtype=dtype,
|
||||||
|
ensure_channel_first=ensure_channel_first,
|
||||||
|
*args, **kwargs
|
||||||
|
)
|
||||||
|
# Ensure that meta_key_postfix is a string
|
||||||
|
if not isinstance(meta_key_postfix, str):
|
||||||
|
raise TypeError(
|
||||||
|
f"meta_key_postfix must be a string, but got {type(meta_key_postfix).__name__}."
|
||||||
|
)
|
||||||
|
# If meta_keys are not provided, create a tuple of None for each key
|
||||||
|
self.meta_keys = (
|
||||||
|
ensure_tuple_rep(None, len(self.keys))
|
||||||
|
if meta_keys is None
|
||||||
|
else ensure_tuple(meta_keys)
|
||||||
|
)
|
||||||
|
# Check that the number of meta_keys matches the number of keys
|
||||||
|
if len(self.keys) != len(self.meta_keys):
|
||||||
|
raise ValueError("meta_keys must have the same length as keys.")
|
||||||
|
# Assign each key its corresponding metadata postfix
|
||||||
|
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
|
||||||
|
self.overwriting = overwriting
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalImageReader(NumpyReader):
|
||||||
|
"""
|
||||||
|
Universal image reader for TIFF, PNG, JPG, BMP, etc.
|
||||||
|
|
||||||
|
Uses:
|
||||||
|
- tifffile for reading TIFF files.
|
||||||
|
- ITK (if available) for reading other formats.
|
||||||
|
- skimage.io for reading if the previous methods fail.
|
||||||
|
|
||||||
|
The image is loaded with its original number of channels (layers) without forced modifications
|
||||||
|
(e.g., repeating or cropping channels).
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self, channel_dim: Optional[int] = None, **kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(channel_dim=channel_dim, **kwargs)
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.channel_dim = channel_dim
|
||||||
|
|
||||||
|
def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the file format is supported for reading.
|
||||||
|
|
||||||
|
Supported extensions: tif, tiff, png, jpg, bmp, jpeg.
|
||||||
|
"""
|
||||||
|
suffixes: Sequence[str] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"]
|
||||||
|
return has_itk or is_supported_format(filename, suffixes)
|
||||||
|
|
||||||
|
def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs):
|
||||||
|
"""
|
||||||
|
Read image(s) from the given path.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
data: A file path or a sequence of file paths.
|
||||||
|
kwargs: Additional parameters for reading.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A single image or a list of images depending on the number of paths provided.
|
||||||
|
"""
|
||||||
|
images: List[np.ndarray] = [] # List to store the loaded images
|
||||||
|
|
||||||
|
# Convert data to a tuple to support multiple files
|
||||||
|
filenames: Sequence[PathLike] = ensure_tuple(data)
|
||||||
|
# Merge parameters provided in the constructor and the read() method
|
||||||
|
kwargs_ = self.kwargs.copy()
|
||||||
|
kwargs_.update(kwargs)
|
||||||
|
|
||||||
|
for name in filenames:
|
||||||
|
# Convert file name to string
|
||||||
|
name = f"{name}"
|
||||||
|
# If the file has a .tif or .tiff extension (case-insensitive), use tifffile for reading
|
||||||
|
if name.lower().endswith((".tif", ".tiff")):
|
||||||
|
img_array = tif.imread(name)
|
||||||
|
else:
|
||||||
|
# Attempt to read the image using ITK (if available)
|
||||||
|
try:
|
||||||
|
img_itk = itk.imread(name, **kwargs_)
|
||||||
|
img_array = itk.array_view_from_image(img_itk, keep_axes=False)
|
||||||
|
except Exception:
|
||||||
|
# If ITK fails, use skimage.io for reading
|
||||||
|
img_array = io.imread(name)
|
||||||
|
|
||||||
|
# Check the number of dimensions (axes) of the loaded image
|
||||||
|
if img_array.ndim == 2:
|
||||||
|
# If the image is 2D (height, width), add a new axis at the end to represent the channel
|
||||||
|
img_array = np.expand_dims(img_array, axis=-1)
|
||||||
|
|
||||||
|
images.append(img_array)
|
||||||
|
|
||||||
|
# Return a single image if only one file was provided, otherwise return a list of images
|
||||||
|
return images if len(filenames) > 1 else images[0]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
CustomLoadImageD = CustomLoadImageDict = CustomLoadImaged
|
@ -0,0 +1,139 @@
|
|||||||
|
import numpy as np
|
||||||
|
from skimage import exposure
|
||||||
|
from monai.config.type_definitions import KeysCollection
|
||||||
|
from monai.transforms.transform import Transform, MapTransform
|
||||||
|
from typing import Dict, Hashable, Mapping, Sequence
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CustomNormalizeImage",
|
||||||
|
"CustomNormalizeImaged",
|
||||||
|
"CustomNormalizeImageD",
|
||||||
|
"CustomNormalizeImageDict",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomNormalizeImage(Transform):
|
||||||
|
"""
|
||||||
|
Normalize the image by rescaling intensity values based on specified percentiles.
|
||||||
|
|
||||||
|
The normalization can be applied either on the entire image or channel-wise.
|
||||||
|
If the image is 2D (only height and width), a channel dimension is added for consistency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, percentiles: Sequence[float] = (0, 99), channel_wise: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
|
||||||
|
Default is (0, 99).
|
||||||
|
channel_wise (bool): Whether to apply normalization on each channel individually.
|
||||||
|
Default is False.
|
||||||
|
"""
|
||||||
|
self.lower, self.upper = percentiles # Unpack the lower and upper percentile values.
|
||||||
|
self.channel_wise = channel_wise # Flag for channel-wise normalization.
|
||||||
|
|
||||||
|
def _normalize(self, img: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Rescale image intensity using non-zero values for percentile calculation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (np.ndarray): A numpy array representing a single-channel image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: A uint8 numpy array with rescaled intensity values.
|
||||||
|
"""
|
||||||
|
# Extract non-zero values to avoid background influence.
|
||||||
|
non_zero_vals = img[np.nonzero(img)]
|
||||||
|
# Calculate the specified percentiles from the non-zero values.
|
||||||
|
computed_percentiles: np.ndarray = np.percentile(non_zero_vals, [self.lower, self.upper])
|
||||||
|
# Rescale the intensity values to the full uint8 range.
|
||||||
|
img_norm = exposure.rescale_intensity(
|
||||||
|
img, in_range=(computed_percentiles[0], computed_percentiles[1]), out_range="uint8" # type: ignore
|
||||||
|
)
|
||||||
|
return img_norm.astype(np.uint8)
|
||||||
|
|
||||||
|
def __call__(self, img: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Apply normalization to the input image.
|
||||||
|
|
||||||
|
If the input image is 2D (height, width), a channel dimension is added.
|
||||||
|
Depending on the 'channel_wise' flag, normalization is applied either to each channel individually or to the entire image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (np.ndarray): Input image as a numpy array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Normalized image as a numpy array.
|
||||||
|
"""
|
||||||
|
# Check if the image is 2D (grayscale). If so, add a new axis for the channel.
|
||||||
|
if img.ndim == 2:
|
||||||
|
img = np.expand_dims(img, axis=-1) # Added channel dimension for consistency.
|
||||||
|
|
||||||
|
if self.channel_wise:
|
||||||
|
# Initialize an empty array with the same shape as the input image to store normalized channels.
|
||||||
|
normalized_img = np.zeros(img.shape, dtype=np.uint8)
|
||||||
|
|
||||||
|
# Process each channel individually.
|
||||||
|
for i in range(img.shape[-1]):
|
||||||
|
channel_img: np.ndarray = img[:, :, i]
|
||||||
|
|
||||||
|
# Only normalize the channel if there are non-zero values present.
|
||||||
|
if np.count_nonzero(channel_img) > 0:
|
||||||
|
normalized_img[:, :, i] = self._normalize(channel_img)
|
||||||
|
|
||||||
|
img = normalized_img
|
||||||
|
else:
|
||||||
|
# Apply normalization to the entire image.
|
||||||
|
img = self._normalize(img)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class CustomNormalizeImaged(MapTransform):
|
||||||
|
"""
|
||||||
|
Dictionary-based wrapper for CustomNormalizeImage.
|
||||||
|
|
||||||
|
This transform applies normalization to one or more images contained in a dictionary,
|
||||||
|
where the keys point to the image data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
keys: KeysCollection,
|
||||||
|
percentiles: Sequence[float] = (1, 99),
|
||||||
|
channel_wise: bool = False,
|
||||||
|
allow_missing_keys: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
keys (KeysCollection): Keys identifying the image entries in the dictionary.
|
||||||
|
percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
|
||||||
|
Default is (1, 99).
|
||||||
|
channel_wise (bool): Whether to apply normalization on each channel individually.
|
||||||
|
Default is False.
|
||||||
|
allow_missing_keys (bool): If True, missing keys in the dictionary will be ignored.
|
||||||
|
Default is False.
|
||||||
|
"""
|
||||||
|
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
|
||||||
|
# Create an instance of the normalization transform with specified parameters.
|
||||||
|
self.normalizer: CustomNormalizeImage = CustomNormalizeImage(percentiles, channel_wise)
|
||||||
|
|
||||||
|
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Apply the normalization transform to each image in the input dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (Mapping[Hashable, np.ndarray]): A dictionary mapping keys to numpy arrays representing images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[Hashable, np.ndarray]: A dictionary with the same keys where the images have been normalized.
|
||||||
|
"""
|
||||||
|
# Copy the input dictionary to avoid modifying the original data.
|
||||||
|
d: Dict[Hashable, np.ndarray] = dict(data)
|
||||||
|
# Iterate over each key specified in the transform and normalize the corresponding image.
|
||||||
|
for key in self.keys:
|
||||||
|
d[key] = self.normalizer(d[key])
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
# Create aliases for the dictionary-based normalization transform.
|
||||||
|
CustomNormalizeImageD = CustomNormalizeImageDict = CustomNormalizeImaged
|
Loading…
Reference in new issue