From 8885950b9e14ef386b315171bbf3d1fd0e2559b1 Mon Sep 17 00:00:00 2001 From: laynholt Date: Wed, 16 Apr 2025 16:49:17 +0000 Subject: [PATCH] renamed img->image; label->mask --- core/data/__init__.py | 4 +-- core/data/transforms/__init__.py | 56 +++++++++++++++--------------- core/data/transforms/cell_aware.py | 28 +++++++-------- 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/core/data/__init__.py b/core/data/__init__.py index 3840b10..c28c183 100644 --- a/core/data/__init__.py +++ b/core/data/__init__.py @@ -2,7 +2,7 @@ from .transforms import ( get_train_transforms, get_valid_transforms, get_test_transforms, - get_pred_transforms + get_predict_transforms ) @@ -10,5 +10,5 @@ __all__ = [ "get_train_transforms", "get_valid_transforms", "get_test_transforms", - "get_pred_transforms", + "get_predict_transforms", ] \ No newline at end of file diff --git a/core/data/transforms/__init__.py b/core/data/transforms/__init__.py index f70acdd..9cd2b06 100644 --- a/core/data/transforms/__init__.py +++ b/core/data/transforms/__init__.py @@ -9,7 +9,7 @@ __all__ = [ "get_train_transforms", "get_valid_transforms", "get_test_transforms", - "get_pred_transforms", + "get_predict_transforms", ] @@ -33,21 +33,21 @@ def get_train_transforms(): train_transforms = Compose( [ # Load image and label data in (H, W, C) format (image loaded as image-only). - CustomLoadImaged(keys=["img", "label"], image_only=True), + CustomLoadImaged(keys=["image", "mask"], image_only=True), # Normalize the (H, W, C) image using the specified percentiles. CustomNormalizeImaged( - keys=["img"], + keys=["image"], allow_missing_keys=True, channel_wise=False, percentiles=[0.0, 99.5], ), # Ensure both image and label are in channel-first format. - EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1), + EnsureChannelFirstd(keys=["image", "mask"], channel_dim=-1), # Scale image intensities (do not scale the label). - ScaleIntensityd(keys=["img"], allow_missing_keys=True), + ScaleIntensityd(keys=["image"], allow_missing_keys=True), # Apply random zoom to both image and label. RandZoomd( - keys=["img", "label"], + keys=["image", "mask"], prob=0.5, min_zoom=0.25, max_zoom=1.5, @@ -55,27 +55,27 @@ def get_train_transforms(): keep_size=False, ), # Pad spatial dimensions to ensure a size of 512. - SpatialPadd(keys=["img", "label"], spatial_size=512), + SpatialPadd(keys=["image", "mask"], spatial_size=512), # Randomly crop a region of interest of size 512. - RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False), + RandSpatialCropd(keys=["image", "mask"], roi_size=512, random_size=False), # Randomly flip the image and label along an axis. - RandAxisFlipd(keys=["img", "label"], prob=0.5), + RandAxisFlipd(keys=["image", "mask"], prob=0.5), # Randomly rotate the image and label by 90 degrees. - RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=(0, 1)), + RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=(0, 1)), # Diversify intensities for selected cell regions. - IntensityDiversification(keys=["img", "label"], allow_missing_keys=True), + IntensityDiversification(keys=["image", "mask"], allow_missing_keys=True), # Apply random Gaussian noise to the image. - RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1), + RandGaussianNoised(keys=["image"], prob=0.25, mean=0, std=0.1), # Randomly adjust the contrast of the image. - RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)), + RandAdjustContrastd(keys=["image"], prob=0.25, gamma=(1, 2)), # Apply random Gaussian smoothing to the image. - RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)), + RandGaussianSmoothd(keys=["image"], prob=0.25, sigma_x=(1, 2)), # Randomly shift the histogram of the image. - RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3), + RandHistogramShiftd(keys=["image"], prob=0.25, num_control_points=3), # Apply random Gaussian sharpening to the image. - RandGaussianSharpend(keys=["img"], prob=0.25), + RandGaussianSharpend(keys=["image"], prob=0.25), # Ensure that the data types are correct. - EnsureTyped(keys=["img", "label"]), + EnsureTyped(keys=["image", "mask"]), ] ) return train_transforms @@ -98,20 +98,20 @@ def get_valid_transforms(): valid_transforms = Compose( [ # Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys). - CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True), + CustomLoadImaged(keys=["image", "mask"], allow_missing_keys=True, image_only=True), # Normalize the (H, W, C) image using the specified percentiles. CustomNormalizeImaged( - keys=["img"], + keys=["image"], allow_missing_keys=True, channel_wise=False, percentiles=[0.0, 99.5], ), # Ensure both image and label are in channel-first format. - EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1), + EnsureChannelFirstd(keys=["image", "mask"], allow_missing_keys=True, channel_dim=-1), # Scale image intensities. - ScaleIntensityd(keys=["img"], allow_missing_keys=True), + ScaleIntensityd(keys=["image"], allow_missing_keys=True), # Ensure that the data types are correct. - EnsureTyped(keys=["img", "label"], allow_missing_keys=True), + EnsureTyped(keys=["image", "mask"], allow_missing_keys=True), ] ) return valid_transforms @@ -134,26 +134,26 @@ def get_test_transforms(): test_transforms = Compose( [ # Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys). - CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True), + CustomLoadImaged(keys=["image", "mask"], allow_missing_keys=True, image_only=True), # Normalize the (H, W, C) image using the specified percentiles. CustomNormalizeImaged( - keys=["img"], + keys=["image"], allow_missing_keys=True, channel_wise=False, percentiles=[0.0, 99.5], ), # Ensure both image and label are in channel-first format. - EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1), + EnsureChannelFirstd(keys=["image", "mask"], allow_missing_keys=True, channel_dim=-1), # Scale image intensities. - ScaleIntensityd(keys=["img"], allow_missing_keys=True), + ScaleIntensityd(keys=["image"], allow_missing_keys=True), # Ensure that the data types are correct. - EnsureTyped(keys=["img", "label"], allow_missing_keys=True), + EnsureTyped(keys=["image", "mask"], allow_missing_keys=True), ] ) return test_transforms -def get_pred_transforms(): +def get_predict_transforms(): """ Returns the transformation pipeline for prediction preprocessing. diff --git a/core/data/transforms/cell_aware.py b/core/data/transforms/cell_aware.py index 92eca9d..a19e030 100644 --- a/core/data/transforms/cell_aware.py +++ b/core/data/transforms/cell_aware.py @@ -18,11 +18,11 @@ class BoundaryExclusion(MapTransform): touches the image boundary. """ - def __init__(self, keys: Sequence[str] = ("label",), allow_missing_keys: bool = False) -> None: + def __init__(self, keys: Sequence[str] = ("mask",), allow_missing_keys: bool = False) -> None: """ Args: keys (Sequence[str]): Keys in the input dictionary corresponding to the label image. - Default is ("label",). + Default is ("mask",). allow_missing_keys (bool): If True, missing keys in the input will be ignored. Default is False. """ @@ -41,13 +41,13 @@ class BoundaryExclusion(MapTransform): 6. Assigning the transformed label back into the input dictionary. Args: - data (Dict[str, np.ndarray]): Dictionary containing at least the "label" key with a label image. + data (Dict[str, np.ndarray]): Dictionary containing at least the "mask" key with a label image. Returns: - Dict[str, np.ndarray]: The input dictionary with the "label" key updated after boundary exclusion. + Dict[str, np.ndarray]: The input dictionary with the "mask" key updated after boundary exclusion. """ # Retrieve the original label image. - label_original: np.ndarray = data["label"] + label_original: np.ndarray = data["mask"] # Create a deep copy of the original label for processing. label: np.ndarray = copy.deepcopy(label_original) # Detect cell boundaries with a thick boundary. @@ -77,7 +77,7 @@ class BoundaryExclusion(MapTransform): new_label += label_original * bd # Update the input dictionary with the transformed label. - data["label"] = new_label + data["mask"] = new_label return data @@ -93,7 +93,7 @@ class IntensityDiversification(MapTransform): def __init__( self, - keys: Sequence[str] = ("img",), + keys: Sequence[str] = ("image",), change_cell_ratio: float = 0.4, scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7), allow_missing_keys: bool = False, @@ -101,7 +101,7 @@ class IntensityDiversification(MapTransform): """ Args: keys (Sequence[str]): Keys in the input dictionary corresponding to the image. - Default is ("img",). + Default is ("image",). change_cell_ratio (float): Ratio of cells to apply the intensity scaling. For example, 0.4 means 40% of the cells will be transformed. Default is 0.4. @@ -137,11 +137,11 @@ class IntensityDiversification(MapTransform): Args: data (Dict[str, np.ndarray]): A dictionary containing: - - "img": The original image array. - - "label": The corresponding cell label image array. + - "image": The original image array. + - "mask": The corresponding cell label image array. Returns: - Dict[str, np.ndarray]: The updated dictionary with the "img" key modified after applying + Dict[str, np.ndarray]: The updated dictionary with the "image" key modified after applying the intensity transformation. Raises: @@ -149,13 +149,13 @@ class IntensityDiversification(MapTransform): """ # Extract the label information for all channels. # The label array has dimensions (C, H, W), where C is the number of channels. - label = data["label"] # shape: (C, H, W) + label = data["mask"] # shape: (C, H, W) # Process each channel independently. for c in range(label.shape[0]): # Extract the label and corresponding image channel for the current channel. channel_label = label[c] - img_channel = data["img"][c] + img_channel = data["image"][c] # Retrieve all unique cell IDs in the current channel. # Exclude the background (0) from these IDs. @@ -187,6 +187,6 @@ class IntensityDiversification(MapTransform): img_changed = self.randscale_intensity(img_changed) # Combine the unchanged and modified parts to update the image channel. - data["img"][c] = img_orig + img_changed + data["image"][c] = img_orig + img_changed return data