parent
33ce003657
commit
78f97a72a2
@ -0,0 +1,14 @@
|
||||
from .transforms import (
|
||||
get_train_transforms,
|
||||
get_valid_transforms,
|
||||
get_test_transforms,
|
||||
get_pred_transforms
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_train_transforms",
|
||||
"get_valid_transforms",
|
||||
"get_test_transforms",
|
||||
"get_pred_transforms",
|
||||
]
|
@ -0,0 +1,184 @@
|
||||
from .cell_aware import IntensityDiversification
|
||||
from .load_image import CustomLoadImage, CustomLoadImaged
|
||||
from .normalize_image import CustomNormalizeImage, CustomNormalizeImaged
|
||||
|
||||
from monai.transforms import * # type: ignore
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_train_transforms",
|
||||
"get_valid_transforms",
|
||||
"get_test_transforms",
|
||||
"get_pred_transforms",
|
||||
]
|
||||
|
||||
|
||||
def get_train_transforms():
|
||||
"""
|
||||
Returns the transformation pipeline for training data.
|
||||
|
||||
The training pipeline applies a series of image and label preprocessing steps:
|
||||
1. Load image and label data.
|
||||
2. Normalize the image intensities.
|
||||
3. Ensure the image and label have channel-first format.
|
||||
4. Scale image intensities.
|
||||
5. Apply spatial transformations (zoom, padding, cropping, flipping, and rotation).
|
||||
6. Diversify intensities for selected cell regions.
|
||||
7. Apply additional intensity perturbations (noise, contrast, smoothing, histogram shift, and sharpening).
|
||||
8. Convert the data types to the desired format.
|
||||
|
||||
Returns:
|
||||
Compose: The composed transformation pipeline for training.
|
||||
"""
|
||||
train_transforms = Compose(
|
||||
[
|
||||
# Load image and label data in (H, W, C) format (image loaded as image-only).
|
||||
CustomLoadImaged(keys=["img", "label"], image_only=True),
|
||||
# Normalize the (H, W, C) image using the specified percentiles.
|
||||
CustomNormalizeImaged(
|
||||
keys=["img"],
|
||||
allow_missing_keys=True,
|
||||
channel_wise=False,
|
||||
percentiles=[0.0, 99.5],
|
||||
),
|
||||
# Ensure both image and label are in channel-first format.
|
||||
EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
|
||||
# Scale image intensities (do not scale the label).
|
||||
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
|
||||
# Apply random zoom to both image and label.
|
||||
RandZoomd(
|
||||
keys=["img", "label"],
|
||||
prob=0.5,
|
||||
min_zoom=0.25,
|
||||
max_zoom=1.5,
|
||||
mode=["area", "nearest"],
|
||||
keep_size=False,
|
||||
),
|
||||
# Pad spatial dimensions to ensure a size of 512.
|
||||
SpatialPadd(keys=["img", "label"], spatial_size=512),
|
||||
# Randomly crop a region of interest of size 512.
|
||||
RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
|
||||
# Randomly flip the image and label along an axis.
|
||||
RandAxisFlipd(keys=["img", "label"], prob=0.5),
|
||||
# Randomly rotate the image and label by 90 degrees.
|
||||
RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=(0, 1)),
|
||||
# Diversify intensities for selected cell regions.
|
||||
IntensityDiversification(keys=["img", "label"], allow_missing_keys=True),
|
||||
# Apply random Gaussian noise to the image.
|
||||
RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
|
||||
# Randomly adjust the contrast of the image.
|
||||
RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
|
||||
# Apply random Gaussian smoothing to the image.
|
||||
RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
|
||||
# Randomly shift the histogram of the image.
|
||||
RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
|
||||
# Apply random Gaussian sharpening to the image.
|
||||
RandGaussianSharpend(keys=["img"], prob=0.25),
|
||||
# Ensure that the data types are correct.
|
||||
EnsureTyped(keys=["img", "label"]),
|
||||
]
|
||||
)
|
||||
return train_transforms
|
||||
|
||||
|
||||
def get_valid_transforms():
|
||||
"""
|
||||
Returns the transformation pipeline for validation data.
|
||||
|
||||
The validation pipeline includes the following steps:
|
||||
1. Load image and label data (with missing keys allowed).
|
||||
2. Normalize the image intensities.
|
||||
3. Ensure the image and label are in channel-first format.
|
||||
4. Scale image intensities.
|
||||
5. Convert the data types to the desired format.
|
||||
|
||||
Returns:
|
||||
Compose: The composed transformation pipeline for validation.
|
||||
"""
|
||||
valid_transforms = Compose(
|
||||
[
|
||||
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
|
||||
CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
|
||||
# Normalize the (H, W, C) image using the specified percentiles.
|
||||
CustomNormalizeImaged(
|
||||
keys=["img"],
|
||||
allow_missing_keys=True,
|
||||
channel_wise=False,
|
||||
percentiles=[0.0, 99.5],
|
||||
),
|
||||
# Ensure both image and label are in channel-first format.
|
||||
EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
|
||||
# Scale image intensities.
|
||||
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
|
||||
# Ensure that the data types are correct.
|
||||
EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
|
||||
]
|
||||
)
|
||||
return valid_transforms
|
||||
|
||||
|
||||
def get_test_transforms():
|
||||
"""
|
||||
Returns the transformation pipeline for test data.
|
||||
|
||||
The test pipeline is similar to the validation pipeline and includes:
|
||||
1. Load image and label data (with missing keys allowed).
|
||||
2. Normalize the image intensities.
|
||||
3. Ensure the image and label are in channel-first format.
|
||||
4. Scale image intensities.
|
||||
5. Convert the data types to the desired format.
|
||||
|
||||
Returns:
|
||||
Compose: The composed transformation pipeline for testing.
|
||||
"""
|
||||
test_transforms = Compose(
|
||||
[
|
||||
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
|
||||
CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
|
||||
# Normalize the (H, W, C) image using the specified percentiles.
|
||||
CustomNormalizeImaged(
|
||||
keys=["img"],
|
||||
allow_missing_keys=True,
|
||||
channel_wise=False,
|
||||
percentiles=[0.0, 99.5],
|
||||
),
|
||||
# Ensure both image and label are in channel-first format.
|
||||
EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
|
||||
# Scale image intensities.
|
||||
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
|
||||
# Ensure that the data types are correct.
|
||||
EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
|
||||
]
|
||||
)
|
||||
return test_transforms
|
||||
|
||||
|
||||
def get_pred_transforms():
|
||||
"""
|
||||
Returns the transformation pipeline for prediction preprocessing.
|
||||
|
||||
The prediction pipeline includes the following steps:
|
||||
1. Load the image data.
|
||||
2. Normalize the image intensities.
|
||||
3. Ensure the image is in channel-first format.
|
||||
4. Scale image intensities.
|
||||
5. Convert the image to the required tensor type.
|
||||
|
||||
Returns:
|
||||
Compose: The composed transformation pipeline for prediction.
|
||||
"""
|
||||
pred_transforms = Compose(
|
||||
[
|
||||
# Load the image data in (H, W, C) format (image loaded as image-only).
|
||||
CustomLoadImage(image_only=True),
|
||||
# Normalize the (H, W, C) image using the specified percentiles.
|
||||
CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
|
||||
# Ensure the image is in channel-first format.
|
||||
EnsureChannelFirst(channel_dim=-1), # image shape: (C, H, W)
|
||||
# Scale image intensities.
|
||||
ScaleIntensity(),
|
||||
# Convert the image to the required tensor type.
|
||||
EnsureType(data_type="tensor"),
|
||||
]
|
||||
)
|
||||
return pred_transforms
|
@ -0,0 +1,192 @@
|
||||
import copy
|
||||
import numpy as np
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
from skimage.segmentation import find_boundaries
|
||||
from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore
|
||||
|
||||
|
||||
__all__ = ["BoundaryExclusion", "IntensityDiversification"]
|
||||
|
||||
|
||||
class BoundaryExclusion(MapTransform):
|
||||
"""
|
||||
Map the cell boundary pixel labels to the background class (0).
|
||||
|
||||
This transform processes a label image by first detecting boundaries of cell regions
|
||||
and then excluding those boundary pixels by setting them to 0. However, it retains
|
||||
the original cell label if the cell is too small (less than 14x14 pixels) or if the cell
|
||||
touches the image boundary.
|
||||
"""
|
||||
|
||||
def __init__(self, keys: Sequence[str] = ("label",), allow_missing_keys: bool = False) -> None:
|
||||
"""
|
||||
Args:
|
||||
keys (Sequence[str]): Keys in the input dictionary corresponding to the label image.
|
||||
Default is ("label",).
|
||||
allow_missing_keys (bool): If True, missing keys in the input will be ignored.
|
||||
Default is False.
|
||||
"""
|
||||
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
|
||||
|
||||
def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Apply the boundary exclusion transform to the label image.
|
||||
|
||||
The process involves:
|
||||
1. Deep-copying the original label.
|
||||
2. Finding boundaries using a thick mode with connectivity=1.
|
||||
3. Setting the boundary pixels to background (0).
|
||||
4. Restoring original labels for cells that are too small (< 14x14 pixels).
|
||||
5. Ensuring that cells touching the image boundary are not excluded.
|
||||
6. Assigning the transformed label back into the input dictionary.
|
||||
|
||||
Args:
|
||||
data (Dict[str, np.ndarray]): Dictionary containing at least the "label" key with a label image.
|
||||
|
||||
Returns:
|
||||
Dict[str, np.ndarray]: The input dictionary with the "label" key updated after boundary exclusion.
|
||||
"""
|
||||
# Retrieve the original label image.
|
||||
label_original: np.ndarray = data["label"]
|
||||
# Create a deep copy of the original label for processing.
|
||||
label: np.ndarray = copy.deepcopy(label_original)
|
||||
# Detect cell boundaries with a thick boundary.
|
||||
boundary: np.ndarray = find_boundaries(label, connectivity=1, mode="thick")
|
||||
# Exclude boundary pixels by setting them to 0.
|
||||
label[boundary] = 0
|
||||
|
||||
# Create a new label copy for selective exclusion.
|
||||
new_label: np.ndarray = copy.deepcopy(label_original)
|
||||
new_label[label == 0] = 0
|
||||
|
||||
# Obtain unique cell indices and their pixel counts.
|
||||
cell_idx, cell_counts = np.unique(label_original, return_counts=True)
|
||||
|
||||
# If a cell is too small (< 196 pixels, approx. 14x14), restore its original label.
|
||||
for k in range(len(cell_counts)):
|
||||
if cell_counts[k] < 196:
|
||||
new_label[label_original == cell_idx[k]] = cell_idx[k]
|
||||
|
||||
# Ensure that cells at the image boundaries are not excluded.
|
||||
# Get the dimensions of the label image.
|
||||
H, W, _ = label_original.shape
|
||||
# Create a binary mask with a border of 2 pixels preserved.
|
||||
bd: np.ndarray = np.zeros_like(label_original, dtype=label.dtype)
|
||||
bd[2 : H - 2, 2 : W - 2, :] = 1
|
||||
# Combine the preserved boundaries with the new label.
|
||||
new_label += label_original * bd
|
||||
|
||||
# Update the input dictionary with the transformed label.
|
||||
data["label"] = new_label
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class IntensityDiversification(MapTransform):
|
||||
"""
|
||||
Randomly rescale the intensity of cell pixels.
|
||||
|
||||
This transform selects a subset of cells (based on the change_cell_ratio) and
|
||||
applies a random intensity scaling to those cells. The intensity scaling is performed
|
||||
using the RandScaleIntensity transform from MONAI.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
keys: Sequence[str] = ("img",),
|
||||
change_cell_ratio: float = 0.4,
|
||||
scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7),
|
||||
allow_missing_keys: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
keys (Sequence[str]): Keys in the input dictionary corresponding to the image.
|
||||
Default is ("img",).
|
||||
change_cell_ratio (float): Ratio of cells to apply the intensity scaling.
|
||||
For example, 0.4 means 40% of the cells will be transformed.
|
||||
Default is 0.4.
|
||||
scale_factors (Sequence[float]): Factors used for random intensity scaling.
|
||||
Default is (0.0, 0.7).
|
||||
allow_missing_keys (bool): If True, missing keys in the input will be ignored.
|
||||
Default is False.
|
||||
"""
|
||||
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
|
||||
self.change_cell_ratio: float = change_cell_ratio
|
||||
# Compose a random intensity scaling transform with 100% probability.
|
||||
self.randscale_intensity = Compose([RandScaleIntensity(prob=1.0, factors=scale_factors)])
|
||||
|
||||
def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Apply a cell-wise intensity diversification transform to an input image.
|
||||
|
||||
This function modifies the image by randomly selecting a subset of labeled cell regions
|
||||
(per channel) and applying a random intensity scaling operation exclusively to those regions.
|
||||
The transformation is performed independently on each channel of the image.
|
||||
|
||||
The steps are as follows:
|
||||
1. Extract the label image for all channels (expected shape: (C, H, W)).
|
||||
2. For each channel, determine the unique cell IDs, excluding the background (labeled as 0).
|
||||
3. Raise a ValueError if no unique objects are found in the current label channel.
|
||||
4. Compute the number of cells to modify based on the provided change_cell_ratio.
|
||||
5. Randomly select the corresponding cell IDs for intensity modification.
|
||||
6. Create a binary mask that highlights the selected cell regions.
|
||||
7. Separate the image channel into two parts: one that remains unchanged and one that is
|
||||
subjected to random intensity scaling.
|
||||
8. Apply the random intensity scaling to the selected regions.
|
||||
9. Combine the unchanged and modified parts to update the image for that channel.
|
||||
|
||||
Args:
|
||||
data (Dict[str, np.ndarray]): A dictionary containing:
|
||||
- "img": The original image array.
|
||||
- "label": The corresponding cell label image array.
|
||||
|
||||
Returns:
|
||||
Dict[str, np.ndarray]: The updated dictionary with the "img" key modified after applying
|
||||
the intensity transformation.
|
||||
|
||||
Raises:
|
||||
ValueError: If no unique cell objects are found in a label channel.
|
||||
"""
|
||||
# Extract the label information for all channels.
|
||||
# The label array has dimensions (C, H, W), where C is the number of channels.
|
||||
label = data["label"] # shape: (C, H, W)
|
||||
|
||||
# Process each channel independently.
|
||||
for c in range(label.shape[0]):
|
||||
# Extract the label and corresponding image channel for the current channel.
|
||||
channel_label = label[c]
|
||||
img_channel = data["img"][c]
|
||||
|
||||
# Retrieve all unique cell IDs in the current channel.
|
||||
# Exclude the background (0) from these IDs.
|
||||
cell_ids = np.unique(channel_label)
|
||||
cell_ids = cell_ids[cell_ids > 0]
|
||||
|
||||
# If there are no unique cell objects in this channel, raise an exception.
|
||||
if cell_ids.size == 0:
|
||||
raise ValueError(f"No unique objects found in the label mask for channel {c}")
|
||||
|
||||
# Determine the number of cells to modify using the change_cell_ratio.
|
||||
change_count = int(len(cell_ids) * self.change_cell_ratio)
|
||||
|
||||
# Randomly select a subset of cell IDs for intensity modification.
|
||||
selected = np.random.choice(cell_ids, change_count, replace=False)
|
||||
|
||||
# Create a binary mask for the current channel:
|
||||
# - Pixels corresponding to the selected cell IDs are set to 1.
|
||||
# - All other pixels are set to 0.
|
||||
mask = np.isin(channel_label, selected).astype(np.float32)
|
||||
|
||||
# Separate the image channel into two components:
|
||||
# 1. img_orig: The portion of the image that remains unchanged.
|
||||
# 2. img_changed: The portion that will have its intensity altered.
|
||||
img_orig = (1 - mask) * img_channel
|
||||
img_changed = mask * img_channel
|
||||
|
||||
# Apply a random intensity scaling transformation to the selected regions.
|
||||
img_changed = self.randscale_intensity(img_changed)
|
||||
|
||||
# Combine the unchanged and modified parts to update the image channel.
|
||||
data["img"][c] = img_orig + img_changed
|
||||
|
||||
return data
|
@ -0,0 +1,202 @@
|
||||
import numpy as np
|
||||
import tifffile as tif
|
||||
import skimage.io as io
|
||||
from typing import List, Optional, Sequence, Type, Union
|
||||
|
||||
from monai.utils.enums import PostFix
|
||||
from monai.utils.module import optional_import
|
||||
from monai.utils.misc import ensure_tuple, ensure_tuple_rep
|
||||
from monai.data.utils import is_supported_format
|
||||
from monai.data.image_reader import ImageReader, NumpyReader
|
||||
from monai.transforms import LoadImage, LoadImaged # type: ignore
|
||||
from monai.config.type_definitions import DtypeLike, PathLike, KeysCollection
|
||||
|
||||
|
||||
# Default value for metadata postfix
|
||||
DEFAULT_POST_FIX = PostFix.meta()
|
||||
|
||||
# Try to import ITK library; if not available, has_itk will be False
|
||||
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CustomLoadImage", # Basic image loader
|
||||
"CustomLoadImaged", # Dictionary-based image loader
|
||||
"CustomLoadImageD", # Dictionary-based image loader
|
||||
"CustomLoadImageDict", # Dictionary-based image loader
|
||||
]
|
||||
|
||||
|
||||
class CustomLoadImage(LoadImage):
|
||||
"""
|
||||
Class for loading one or multiple images from a given path.
|
||||
|
||||
If a reader is not specified, the appropriate file reading method is automatically chosen
|
||||
based on the file extension. Priority:
|
||||
- Reader passed by the user at runtime.
|
||||
- Reader specified in the constructor.
|
||||
- Registered readers (from last to first).
|
||||
- Standard readers for different formats (e.g., NibabelReader for nii, PILReader for png/jpg, etc.).
|
||||
|
||||
[Note] Here, the original ITKReader is replaced by the universal reader UniversalImageReader.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
reader: Optional[Union[ImageReader, Type[ImageReader], str]] = None,
|
||||
image_only: bool = False,
|
||||
dtype: DtypeLike = np.float32,
|
||||
ensure_channel_first: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
reader=reader,
|
||||
image_only=image_only,
|
||||
dtype=dtype,
|
||||
ensure_channel_first=ensure_channel_first,
|
||||
*args, **kwargs
|
||||
)
|
||||
# Clear the list of registered readers
|
||||
self.readers = []
|
||||
# Register the universal reader that handles TIFF, PNG, JPG, BMP, etc.
|
||||
self.register(UniversalImageReader(*args, **kwargs))
|
||||
|
||||
|
||||
class CustomLoadImaged(LoadImaged):
|
||||
"""
|
||||
Dictionary-based image loader.
|
||||
|
||||
Wraps image loading with CustomLoadImage and allows processing of data represented as a dictionary,
|
||||
where keys point to file paths.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
keys: KeysCollection,
|
||||
reader: Optional[Union[Type[ImageReader], str]] = None,
|
||||
dtype: DtypeLike = np.float32,
|
||||
meta_keys: Optional[KeysCollection] = None,
|
||||
meta_key_postfix: str = DEFAULT_POST_FIX,
|
||||
overwriting: bool = False,
|
||||
image_only: bool = False,
|
||||
ensure_channel_first: bool = False,
|
||||
simple_keys: bool = False,
|
||||
allow_missing_keys: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
keys=keys,
|
||||
reader=reader,
|
||||
dtype=dtype,
|
||||
meta_keys=meta_keys,
|
||||
meta_key_postfix=meta_key_postfix,
|
||||
overwriting=overwriting,
|
||||
image_only=image_only,
|
||||
ensure_channel_first=ensure_channel_first,
|
||||
simple_keys=simple_keys,
|
||||
allow_missing_keys=allow_missing_keys,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
# Assign the custom image loader
|
||||
self._loader = CustomLoadImage(
|
||||
reader=reader,
|
||||
image_only=image_only,
|
||||
dtype=dtype,
|
||||
ensure_channel_first=ensure_channel_first,
|
||||
*args, **kwargs
|
||||
)
|
||||
# Ensure that meta_key_postfix is a string
|
||||
if not isinstance(meta_key_postfix, str):
|
||||
raise TypeError(
|
||||
f"meta_key_postfix must be a string, but got {type(meta_key_postfix).__name__}."
|
||||
)
|
||||
# If meta_keys are not provided, create a tuple of None for each key
|
||||
self.meta_keys = (
|
||||
ensure_tuple_rep(None, len(self.keys))
|
||||
if meta_keys is None
|
||||
else ensure_tuple(meta_keys)
|
||||
)
|
||||
# Check that the number of meta_keys matches the number of keys
|
||||
if len(self.keys) != len(self.meta_keys):
|
||||
raise ValueError("meta_keys must have the same length as keys.")
|
||||
# Assign each key its corresponding metadata postfix
|
||||
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
|
||||
self.overwriting = overwriting
|
||||
|
||||
|
||||
class UniversalImageReader(NumpyReader):
|
||||
"""
|
||||
Universal image reader for TIFF, PNG, JPG, BMP, etc.
|
||||
|
||||
Uses:
|
||||
- tifffile for reading TIFF files.
|
||||
- ITK (if available) for reading other formats.
|
||||
- skimage.io for reading if the previous methods fail.
|
||||
|
||||
The image is loaded with its original number of channels (layers) without forced modifications
|
||||
(e.g., repeating or cropping channels).
|
||||
"""
|
||||
def __init__(
|
||||
self, channel_dim: Optional[int] = None, **kwargs,
|
||||
):
|
||||
super().__init__(channel_dim=channel_dim, **kwargs)
|
||||
self.kwargs = kwargs
|
||||
self.channel_dim = channel_dim
|
||||
|
||||
def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
|
||||
"""
|
||||
Check if the file format is supported for reading.
|
||||
|
||||
Supported extensions: tif, tiff, png, jpg, bmp, jpeg.
|
||||
"""
|
||||
suffixes: Sequence[str] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"]
|
||||
return has_itk or is_supported_format(filename, suffixes)
|
||||
|
||||
def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs):
|
||||
"""
|
||||
Read image(s) from the given path.
|
||||
|
||||
Arguments:
|
||||
data: A file path or a sequence of file paths.
|
||||
kwargs: Additional parameters for reading.
|
||||
|
||||
Returns:
|
||||
A single image or a list of images depending on the number of paths provided.
|
||||
"""
|
||||
images: List[np.ndarray] = [] # List to store the loaded images
|
||||
|
||||
# Convert data to a tuple to support multiple files
|
||||
filenames: Sequence[PathLike] = ensure_tuple(data)
|
||||
# Merge parameters provided in the constructor and the read() method
|
||||
kwargs_ = self.kwargs.copy()
|
||||
kwargs_.update(kwargs)
|
||||
|
||||
for name in filenames:
|
||||
# Convert file name to string
|
||||
name = f"{name}"
|
||||
# If the file has a .tif or .tiff extension (case-insensitive), use tifffile for reading
|
||||
if name.lower().endswith((".tif", ".tiff")):
|
||||
img_array = tif.imread(name)
|
||||
else:
|
||||
# Attempt to read the image using ITK (if available)
|
||||
try:
|
||||
img_itk = itk.imread(name, **kwargs_)
|
||||
img_array = itk.array_view_from_image(img_itk, keep_axes=False)
|
||||
except Exception:
|
||||
# If ITK fails, use skimage.io for reading
|
||||
img_array = io.imread(name)
|
||||
|
||||
# Check the number of dimensions (axes) of the loaded image
|
||||
if img_array.ndim == 2:
|
||||
# If the image is 2D (height, width), add a new axis at the end to represent the channel
|
||||
img_array = np.expand_dims(img_array, axis=-1)
|
||||
|
||||
images.append(img_array)
|
||||
|
||||
# Return a single image if only one file was provided, otherwise return a list of images
|
||||
return images if len(filenames) > 1 else images[0]
|
||||
|
||||
|
||||
|
||||
CustomLoadImageD = CustomLoadImageDict = CustomLoadImaged
|
@ -0,0 +1,139 @@
|
||||
import numpy as np
|
||||
from skimage import exposure
|
||||
from monai.config.type_definitions import KeysCollection
|
||||
from monai.transforms.transform import Transform, MapTransform
|
||||
from typing import Dict, Hashable, Mapping, Sequence
|
||||
|
||||
__all__ = [
|
||||
"CustomNormalizeImage",
|
||||
"CustomNormalizeImaged",
|
||||
"CustomNormalizeImageD",
|
||||
"CustomNormalizeImageDict",
|
||||
]
|
||||
|
||||
|
||||
class CustomNormalizeImage(Transform):
|
||||
"""
|
||||
Normalize the image by rescaling intensity values based on specified percentiles.
|
||||
|
||||
The normalization can be applied either on the entire image or channel-wise.
|
||||
If the image is 2D (only height and width), a channel dimension is added for consistency.
|
||||
"""
|
||||
|
||||
def __init__(self, percentiles: Sequence[float] = (0, 99), channel_wise: bool = False) -> None:
|
||||
"""
|
||||
Args:
|
||||
percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
|
||||
Default is (0, 99).
|
||||
channel_wise (bool): Whether to apply normalization on each channel individually.
|
||||
Default is False.
|
||||
"""
|
||||
self.lower, self.upper = percentiles # Unpack the lower and upper percentile values.
|
||||
self.channel_wise = channel_wise # Flag for channel-wise normalization.
|
||||
|
||||
def _normalize(self, img: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Rescale image intensity using non-zero values for percentile calculation.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): A numpy array representing a single-channel image.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A uint8 numpy array with rescaled intensity values.
|
||||
"""
|
||||
# Extract non-zero values to avoid background influence.
|
||||
non_zero_vals = img[np.nonzero(img)]
|
||||
# Calculate the specified percentiles from the non-zero values.
|
||||
computed_percentiles: np.ndarray = np.percentile(non_zero_vals, [self.lower, self.upper])
|
||||
# Rescale the intensity values to the full uint8 range.
|
||||
img_norm = exposure.rescale_intensity(
|
||||
img, in_range=(computed_percentiles[0], computed_percentiles[1]), out_range="uint8" # type: ignore
|
||||
)
|
||||
return img_norm.astype(np.uint8)
|
||||
|
||||
def __call__(self, img: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Apply normalization to the input image.
|
||||
|
||||
If the input image is 2D (height, width), a channel dimension is added.
|
||||
Depending on the 'channel_wise' flag, normalization is applied either to each channel individually or to the entire image.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): Input image as a numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized image as a numpy array.
|
||||
"""
|
||||
# Check if the image is 2D (grayscale). If so, add a new axis for the channel.
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=-1) # Added channel dimension for consistency.
|
||||
|
||||
if self.channel_wise:
|
||||
# Initialize an empty array with the same shape as the input image to store normalized channels.
|
||||
normalized_img = np.zeros(img.shape, dtype=np.uint8)
|
||||
|
||||
# Process each channel individually.
|
||||
for i in range(img.shape[-1]):
|
||||
channel_img: np.ndarray = img[:, :, i]
|
||||
|
||||
# Only normalize the channel if there are non-zero values present.
|
||||
if np.count_nonzero(channel_img) > 0:
|
||||
normalized_img[:, :, i] = self._normalize(channel_img)
|
||||
|
||||
img = normalized_img
|
||||
else:
|
||||
# Apply normalization to the entire image.
|
||||
img = self._normalize(img)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class CustomNormalizeImaged(MapTransform):
|
||||
"""
|
||||
Dictionary-based wrapper for CustomNormalizeImage.
|
||||
|
||||
This transform applies normalization to one or more images contained in a dictionary,
|
||||
where the keys point to the image data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
keys: KeysCollection,
|
||||
percentiles: Sequence[float] = (1, 99),
|
||||
channel_wise: bool = False,
|
||||
allow_missing_keys: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
keys (KeysCollection): Keys identifying the image entries in the dictionary.
|
||||
percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
|
||||
Default is (1, 99).
|
||||
channel_wise (bool): Whether to apply normalization on each channel individually.
|
||||
Default is False.
|
||||
allow_missing_keys (bool): If True, missing keys in the dictionary will be ignored.
|
||||
Default is False.
|
||||
"""
|
||||
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
|
||||
# Create an instance of the normalization transform with specified parameters.
|
||||
self.normalizer: CustomNormalizeImage = CustomNormalizeImage(percentiles, channel_wise)
|
||||
|
||||
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
|
||||
"""
|
||||
Apply the normalization transform to each image in the input dictionary.
|
||||
|
||||
Args:
|
||||
data (Mapping[Hashable, np.ndarray]): A dictionary mapping keys to numpy arrays representing images.
|
||||
|
||||
Returns:
|
||||
Dict[Hashable, np.ndarray]: A dictionary with the same keys where the images have been normalized.
|
||||
"""
|
||||
# Copy the input dictionary to avoid modifying the original data.
|
||||
d: Dict[Hashable, np.ndarray] = dict(data)
|
||||
# Iterate over each key specified in the transform and normalize the corresponding image.
|
||||
for key in self.keys:
|
||||
d[key] = self.normalizer(d[key])
|
||||
return d
|
||||
|
||||
|
||||
# Create aliases for the dictionary-based normalization transform.
|
||||
CustomNormalizeImageD = CustomNormalizeImageDict = CustomNormalizeImaged
|
Loading…
Reference in new issue