renamed img->image; label->mask

master
laynholt 3 weeks ago
parent 7a0b8ffa19
commit 8885950b9e

@ -2,7 +2,7 @@ from .transforms import (
get_train_transforms, get_train_transforms,
get_valid_transforms, get_valid_transforms,
get_test_transforms, get_test_transforms,
get_pred_transforms get_predict_transforms
) )
@ -10,5 +10,5 @@ __all__ = [
"get_train_transforms", "get_train_transforms",
"get_valid_transforms", "get_valid_transforms",
"get_test_transforms", "get_test_transforms",
"get_pred_transforms", "get_predict_transforms",
] ]

@ -9,7 +9,7 @@ __all__ = [
"get_train_transforms", "get_train_transforms",
"get_valid_transforms", "get_valid_transforms",
"get_test_transforms", "get_test_transforms",
"get_pred_transforms", "get_predict_transforms",
] ]
@ -33,21 +33,21 @@ def get_train_transforms():
train_transforms = Compose( train_transforms = Compose(
[ [
# Load image and label data in (H, W, C) format (image loaded as image-only). # Load image and label data in (H, W, C) format (image loaded as image-only).
CustomLoadImaged(keys=["img", "label"], image_only=True), CustomLoadImaged(keys=["image", "mask"], image_only=True),
# Normalize the (H, W, C) image using the specified percentiles. # Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImaged( CustomNormalizeImaged(
keys=["img"], keys=["image"],
allow_missing_keys=True, allow_missing_keys=True,
channel_wise=False, channel_wise=False,
percentiles=[0.0, 99.5], percentiles=[0.0, 99.5],
), ),
# Ensure both image and label are in channel-first format. # Ensure both image and label are in channel-first format.
EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1), EnsureChannelFirstd(keys=["image", "mask"], channel_dim=-1),
# Scale image intensities (do not scale the label). # Scale image intensities (do not scale the label).
ScaleIntensityd(keys=["img"], allow_missing_keys=True), ScaleIntensityd(keys=["image"], allow_missing_keys=True),
# Apply random zoom to both image and label. # Apply random zoom to both image and label.
RandZoomd( RandZoomd(
keys=["img", "label"], keys=["image", "mask"],
prob=0.5, prob=0.5,
min_zoom=0.25, min_zoom=0.25,
max_zoom=1.5, max_zoom=1.5,
@ -55,27 +55,27 @@ def get_train_transforms():
keep_size=False, keep_size=False,
), ),
# Pad spatial dimensions to ensure a size of 512. # Pad spatial dimensions to ensure a size of 512.
SpatialPadd(keys=["img", "label"], spatial_size=512), SpatialPadd(keys=["image", "mask"], spatial_size=512),
# Randomly crop a region of interest of size 512. # Randomly crop a region of interest of size 512.
RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False), RandSpatialCropd(keys=["image", "mask"], roi_size=512, random_size=False),
# Randomly flip the image and label along an axis. # Randomly flip the image and label along an axis.
RandAxisFlipd(keys=["img", "label"], prob=0.5), RandAxisFlipd(keys=["image", "mask"], prob=0.5),
# Randomly rotate the image and label by 90 degrees. # Randomly rotate the image and label by 90 degrees.
RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=(0, 1)), RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=(0, 1)),
# Diversify intensities for selected cell regions. # Diversify intensities for selected cell regions.
IntensityDiversification(keys=["img", "label"], allow_missing_keys=True), IntensityDiversification(keys=["image", "mask"], allow_missing_keys=True),
# Apply random Gaussian noise to the image. # Apply random Gaussian noise to the image.
RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1), RandGaussianNoised(keys=["image"], prob=0.25, mean=0, std=0.1),
# Randomly adjust the contrast of the image. # Randomly adjust the contrast of the image.
RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)), RandAdjustContrastd(keys=["image"], prob=0.25, gamma=(1, 2)),
# Apply random Gaussian smoothing to the image. # Apply random Gaussian smoothing to the image.
RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)), RandGaussianSmoothd(keys=["image"], prob=0.25, sigma_x=(1, 2)),
# Randomly shift the histogram of the image. # Randomly shift the histogram of the image.
RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3), RandHistogramShiftd(keys=["image"], prob=0.25, num_control_points=3),
# Apply random Gaussian sharpening to the image. # Apply random Gaussian sharpening to the image.
RandGaussianSharpend(keys=["img"], prob=0.25), RandGaussianSharpend(keys=["image"], prob=0.25),
# Ensure that the data types are correct. # Ensure that the data types are correct.
EnsureTyped(keys=["img", "label"]), EnsureTyped(keys=["image", "mask"]),
] ]
) )
return train_transforms return train_transforms
@ -98,20 +98,20 @@ def get_valid_transforms():
valid_transforms = Compose( valid_transforms = Compose(
[ [
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys). # Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True), CustomLoadImaged(keys=["image", "mask"], allow_missing_keys=True, image_only=True),
# Normalize the (H, W, C) image using the specified percentiles. # Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImaged( CustomNormalizeImaged(
keys=["img"], keys=["image"],
allow_missing_keys=True, allow_missing_keys=True,
channel_wise=False, channel_wise=False,
percentiles=[0.0, 99.5], percentiles=[0.0, 99.5],
), ),
# Ensure both image and label are in channel-first format. # Ensure both image and label are in channel-first format.
EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1), EnsureChannelFirstd(keys=["image", "mask"], allow_missing_keys=True, channel_dim=-1),
# Scale image intensities. # Scale image intensities.
ScaleIntensityd(keys=["img"], allow_missing_keys=True), ScaleIntensityd(keys=["image"], allow_missing_keys=True),
# Ensure that the data types are correct. # Ensure that the data types are correct.
EnsureTyped(keys=["img", "label"], allow_missing_keys=True), EnsureTyped(keys=["image", "mask"], allow_missing_keys=True),
] ]
) )
return valid_transforms return valid_transforms
@ -134,26 +134,26 @@ def get_test_transforms():
test_transforms = Compose( test_transforms = Compose(
[ [
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys). # Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True), CustomLoadImaged(keys=["image", "mask"], allow_missing_keys=True, image_only=True),
# Normalize the (H, W, C) image using the specified percentiles. # Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImaged( CustomNormalizeImaged(
keys=["img"], keys=["image"],
allow_missing_keys=True, allow_missing_keys=True,
channel_wise=False, channel_wise=False,
percentiles=[0.0, 99.5], percentiles=[0.0, 99.5],
), ),
# Ensure both image and label are in channel-first format. # Ensure both image and label are in channel-first format.
EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1), EnsureChannelFirstd(keys=["image", "mask"], allow_missing_keys=True, channel_dim=-1),
# Scale image intensities. # Scale image intensities.
ScaleIntensityd(keys=["img"], allow_missing_keys=True), ScaleIntensityd(keys=["image"], allow_missing_keys=True),
# Ensure that the data types are correct. # Ensure that the data types are correct.
EnsureTyped(keys=["img", "label"], allow_missing_keys=True), EnsureTyped(keys=["image", "mask"], allow_missing_keys=True),
] ]
) )
return test_transforms return test_transforms
def get_pred_transforms(): def get_predict_transforms():
""" """
Returns the transformation pipeline for prediction preprocessing. Returns the transformation pipeline for prediction preprocessing.

@ -18,11 +18,11 @@ class BoundaryExclusion(MapTransform):
touches the image boundary. touches the image boundary.
""" """
def __init__(self, keys: Sequence[str] = ("label",), allow_missing_keys: bool = False) -> None: def __init__(self, keys: Sequence[str] = ("mask",), allow_missing_keys: bool = False) -> None:
""" """
Args: Args:
keys (Sequence[str]): Keys in the input dictionary corresponding to the label image. keys (Sequence[str]): Keys in the input dictionary corresponding to the label image.
Default is ("label",). Default is ("mask",).
allow_missing_keys (bool): If True, missing keys in the input will be ignored. allow_missing_keys (bool): If True, missing keys in the input will be ignored.
Default is False. Default is False.
""" """
@ -41,13 +41,13 @@ class BoundaryExclusion(MapTransform):
6. Assigning the transformed label back into the input dictionary. 6. Assigning the transformed label back into the input dictionary.
Args: Args:
data (Dict[str, np.ndarray]): Dictionary containing at least the "label" key with a label image. data (Dict[str, np.ndarray]): Dictionary containing at least the "mask" key with a label image.
Returns: Returns:
Dict[str, np.ndarray]: The input dictionary with the "label" key updated after boundary exclusion. Dict[str, np.ndarray]: The input dictionary with the "mask" key updated after boundary exclusion.
""" """
# Retrieve the original label image. # Retrieve the original label image.
label_original: np.ndarray = data["label"] label_original: np.ndarray = data["mask"]
# Create a deep copy of the original label for processing. # Create a deep copy of the original label for processing.
label: np.ndarray = copy.deepcopy(label_original) label: np.ndarray = copy.deepcopy(label_original)
# Detect cell boundaries with a thick boundary. # Detect cell boundaries with a thick boundary.
@ -77,7 +77,7 @@ class BoundaryExclusion(MapTransform):
new_label += label_original * bd new_label += label_original * bd
# Update the input dictionary with the transformed label. # Update the input dictionary with the transformed label.
data["label"] = new_label data["mask"] = new_label
return data return data
@ -93,7 +93,7 @@ class IntensityDiversification(MapTransform):
def __init__( def __init__(
self, self,
keys: Sequence[str] = ("img",), keys: Sequence[str] = ("image",),
change_cell_ratio: float = 0.4, change_cell_ratio: float = 0.4,
scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7), scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7),
allow_missing_keys: bool = False, allow_missing_keys: bool = False,
@ -101,7 +101,7 @@ class IntensityDiversification(MapTransform):
""" """
Args: Args:
keys (Sequence[str]): Keys in the input dictionary corresponding to the image. keys (Sequence[str]): Keys in the input dictionary corresponding to the image.
Default is ("img",). Default is ("image",).
change_cell_ratio (float): Ratio of cells to apply the intensity scaling. change_cell_ratio (float): Ratio of cells to apply the intensity scaling.
For example, 0.4 means 40% of the cells will be transformed. For example, 0.4 means 40% of the cells will be transformed.
Default is 0.4. Default is 0.4.
@ -137,11 +137,11 @@ class IntensityDiversification(MapTransform):
Args: Args:
data (Dict[str, np.ndarray]): A dictionary containing: data (Dict[str, np.ndarray]): A dictionary containing:
- "img": The original image array. - "image": The original image array.
- "label": The corresponding cell label image array. - "mask": The corresponding cell label image array.
Returns: Returns:
Dict[str, np.ndarray]: The updated dictionary with the "img" key modified after applying Dict[str, np.ndarray]: The updated dictionary with the "image" key modified after applying
the intensity transformation. the intensity transformation.
Raises: Raises:
@ -149,13 +149,13 @@ class IntensityDiversification(MapTransform):
""" """
# Extract the label information for all channels. # Extract the label information for all channels.
# The label array has dimensions (C, H, W), where C is the number of channels. # The label array has dimensions (C, H, W), where C is the number of channels.
label = data["label"] # shape: (C, H, W) label = data["mask"] # shape: (C, H, W)
# Process each channel independently. # Process each channel independently.
for c in range(label.shape[0]): for c in range(label.shape[0]):
# Extract the label and corresponding image channel for the current channel. # Extract the label and corresponding image channel for the current channel.
channel_label = label[c] channel_label = label[c]
img_channel = data["img"][c] img_channel = data["image"][c]
# Retrieve all unique cell IDs in the current channel. # Retrieve all unique cell IDs in the current channel.
# Exclude the background (0) from these IDs. # Exclude the background (0) from these IDs.
@ -187,6 +187,6 @@ class IntensityDiversification(MapTransform):
img_changed = self.randscale_intensity(img_changed) img_changed = self.randscale_intensity(img_changed)
# Combine the unchanged and modified parts to update the image channel. # Combine the unchanged and modified parts to update the image channel.
data["img"][c] = img_orig + img_changed data["image"][c] = img_orig + img_changed
return data return data

Loading…
Cancel
Save