From 78f97a72a27ddf61f5c6b9d1593b4b6a0d913e8d Mon Sep 17 00:00:00 2001
From: laynholt
Date: Fri, 28 Mar 2025 20:21:47 +0000
Subject: [PATCH] simple project restruct, added transforms

---
Review notes (text in this section is ignored by `git am`):
  * cell_aware.py, BoundaryExclusion.__call__: the image-frame step used
    `new_label += label_original * bd` with `bd == 1` in the interior, which
    doubled interior label IDs and re-added the excluded boundaries. It now
    restores the original labels only inside the 2-pixel image frame.
  * normalize_image.py, CustomNormalizeImage._normalize: an all-zero image no
    longer reaches np.percentile with an empty array.
  * normalize_image.py, CustomNormalizeImaged.__call__: iterates with
    key_iterator() so that allow_missing_keys=True is honoured.
  * The `index` blob ids of the two amended files predate these amendments.

 core/__init__.py                          |   2 +-
 core/data/__init__.py                     |  14 ++
 core/data/transforms/__init__.py          | 184 ++++++++++++++++++++
 core/data/transforms/cell_aware.py        | 192 ++++++++++++++++++++
 core/data/transforms/load_image.py        | 202 ++++++++++++++++++++++
 core/data/transforms/normalize_image.py   | 142 +++++++++++++++
 core/{criteria => losses}/__init__.py     |   0
 core/{criteria => losses}/base.py         |   0
 core/{criteria => losses}/bce.py          |   0
 core/{criteria => losses}/ce.py           |   0
 core/{criteria => losses}/mse.py          |   0
 core/{criteria => losses}/mse_with_bce.py |   0
 core/utils/__init__.py                    |   0
 generate_config.py                        |   2 +-
 train.py                                  |   6 +-
 15 files changed, 739 insertions(+), 5 deletions(-)
 create mode 100644 core/data/__init__.py
 create mode 100644 core/data/transforms/__init__.py
 create mode 100644 core/data/transforms/cell_aware.py
 create mode 100644 core/data/transforms/load_image.py
 create mode 100644 core/data/transforms/normalize_image.py
 rename core/{criteria => losses}/__init__.py (100%)
 rename core/{criteria => losses}/base.py (100%)
 rename core/{criteria => losses}/bce.py (100%)
 rename core/{criteria => losses}/ce.py (100%)
 rename core/{criteria => losses}/mse.py (100%)
 rename core/{criteria => losses}/mse_with_bce.py (100%)
 create mode 100644 core/utils/__init__.py

diff --git a/core/__init__.py b/core/__init__.py
index 1c28630..4d1ec1f 100644
--- a/core/__init__.py
+++ b/core/__init__.py
@@ -1,5 +1,5 @@
 from .models import ModelRegistry
-from .criteria import CriterionRegistry
+from .losses import CriterionRegistry
 from .optimizers import OptimizerRegistry
 from .schedulers import SchedulerRegistry
 
diff --git a/core/data/__init__.py b/core/data/__init__.py
new file mode 100644
index 0000000..3840b10
--- /dev/null
+++ b/core/data/__init__.py
@@ -0,0 +1,14 @@
+from .transforms import (
+    get_train_transforms,
+    get_valid_transforms,
+    get_test_transforms,
+    get_pred_transforms
+)
+
+
+__all__ = [
+    "get_train_transforms",
+    "get_valid_transforms",
+    "get_test_transforms",
+    "get_pred_transforms",
+]
\ No newline at end of file
diff --git a/core/data/transforms/__init__.py b/core/data/transforms/__init__.py
new file mode 100644
index 0000000..f70acdd
--- /dev/null
+++ b/core/data/transforms/__init__.py
@@ -0,0 +1,184 @@
+from .cell_aware import IntensityDiversification
+from .load_image import CustomLoadImage, CustomLoadImaged
+from .normalize_image import CustomNormalizeImage, CustomNormalizeImaged
+
+from monai.transforms import *  # type: ignore
+
+
+__all__ = [
+    "get_train_transforms",
+    "get_valid_transforms",
+    "get_test_transforms",
+    "get_pred_transforms",
+]
+
+
+def get_train_transforms():
+    """
+    Returns the transformation pipeline for training data.
+
+    The training pipeline applies a series of image and label preprocessing steps:
+      1. Load image and label data.
+      2. Normalize the image intensities.
+      3. Ensure the image and label have channel-first format.
+      4. Scale image intensities.
+      5. Apply spatial transformations (zoom, padding, cropping, flipping, and rotation).
+      6. Diversify intensities for selected cell regions.
+      7. Apply additional intensity perturbations (noise, contrast, smoothing, histogram shift, and sharpening).
+      8. Convert the data types to the desired format.
+
+    Returns:
+        Compose: The composed transformation pipeline for training.
+    """
+    train_transforms = Compose(
+        [
+            # Load image and label data in (H, W, C) format (image loaded as image-only).
+            CustomLoadImaged(keys=["img", "label"], image_only=True),
+            # Normalize the (H, W, C) image using the specified percentiles.
+            CustomNormalizeImaged(
+                keys=["img"],
+                allow_missing_keys=True,
+                channel_wise=False,
+                percentiles=[0.0, 99.5],
+            ),
+            # Ensure both image and label are in channel-first format.
+            EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
+            # Scale image intensities (do not scale the label).
+            ScaleIntensityd(keys=["img"], allow_missing_keys=True),
+            # Apply random zoom to both image and label.
+            RandZoomd(
+                keys=["img", "label"],
+                prob=0.5,
+                min_zoom=0.25,
+                max_zoom=1.5,
+                mode=["area", "nearest"],
+                keep_size=False,
+            ),
+            # Pad spatial dimensions to ensure a size of 512.
+            SpatialPadd(keys=["img", "label"], spatial_size=512),
+            # Randomly crop a region of interest of size 512.
+            RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
+            # Randomly flip the image and label along an axis.
+            RandAxisFlipd(keys=["img", "label"], prob=0.5),
+            # Randomly rotate the image and label by 90 degrees.
+            RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=(0, 1)),
+            # Diversify intensities for selected cell regions.
+            IntensityDiversification(keys=["img", "label"], allow_missing_keys=True),
+            # Apply random Gaussian noise to the image.
+            RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
+            # Randomly adjust the contrast of the image.
+            RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
+            # Apply random Gaussian smoothing to the image.
+            RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
+            # Randomly shift the histogram of the image.
+            RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
+            # Apply random Gaussian sharpening to the image.
+            RandGaussianSharpend(keys=["img"], prob=0.25),
+            # Ensure that the data types are correct.
+            EnsureTyped(keys=["img", "label"]),
+        ]
+    )
+    return train_transforms
+
+
+def get_valid_transforms():
+    """
+    Returns the transformation pipeline for validation data.
+
+    The validation pipeline includes the following steps:
+      1. Load image and label data (with missing keys allowed).
+      2. Normalize the image intensities.
+      3. Ensure the image and label are in channel-first format.
+      4. Scale image intensities.
+      5. Convert the data types to the desired format.
+
+    Returns:
+        Compose: The composed transformation pipeline for validation.
+    """
+    valid_transforms = Compose(
+        [
+            # Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
+            CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
+            # Normalize the (H, W, C) image using the specified percentiles.
+            CustomNormalizeImaged(
+                keys=["img"],
+                allow_missing_keys=True,
+                channel_wise=False,
+                percentiles=[0.0, 99.5],
+            ),
+            # Ensure both image and label are in channel-first format.
+            EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
+            # Scale image intensities.
+            ScaleIntensityd(keys=["img"], allow_missing_keys=True),
+            # Ensure that the data types are correct.
+            EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
+        ]
+    )
+    return valid_transforms
+
+
+def get_test_transforms():
+    """
+    Returns the transformation pipeline for test data.
+
+    The test pipeline is similar to the validation pipeline and includes:
+      1. Load image and label data (with missing keys allowed).
+      2. Normalize the image intensities.
+      3. Ensure the image and label are in channel-first format.
+      4. Scale image intensities.
+      5. Convert the data types to the desired format.
+
+    Returns:
+        Compose: The composed transformation pipeline for testing.
+    """
+    test_transforms = Compose(
+        [
+            # Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
+            CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
+            # Normalize the (H, W, C) image using the specified percentiles.
+            CustomNormalizeImaged(
+                keys=["img"],
+                allow_missing_keys=True,
+                channel_wise=False,
+                percentiles=[0.0, 99.5],
+            ),
+            # Ensure both image and label are in channel-first format.
+            EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
+            # Scale image intensities.
+            ScaleIntensityd(keys=["img"], allow_missing_keys=True),
+            # Ensure that the data types are correct.
+            EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
+        ]
+    )
+    return test_transforms
+
+
+def get_pred_transforms():
+    """
+    Returns the transformation pipeline for prediction preprocessing.
+
+    The prediction pipeline includes the following steps:
+      1. Load the image data.
+      2. Normalize the image intensities.
+      3. Ensure the image is in channel-first format.
+      4. Scale image intensities.
+      5. Convert the image to the required tensor type.
+
+    Returns:
+        Compose: The composed transformation pipeline for prediction.
+    """
+    pred_transforms = Compose(
+        [
+            # Load the image data in (H, W, C) format (image loaded as image-only).
+            CustomLoadImage(image_only=True),
+            # Normalize the (H, W, C) image using the specified percentiles.
+            CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
+            # Ensure the image is in channel-first format.
+            EnsureChannelFirst(channel_dim=-1),  # image shape: (C, H, W)
+            # Scale image intensities.
+            ScaleIntensity(),
+            # Convert the image to the required tensor type.
+            EnsureType(data_type="tensor"),
+        ]
+    )
+    return pred_transforms
diff --git a/core/data/transforms/cell_aware.py b/core/data/transforms/cell_aware.py
new file mode 100644
index 0000000..92eca9d
--- /dev/null
+++ b/core/data/transforms/cell_aware.py
@@ -0,0 +1,192 @@
+import copy
+import numpy as np
+from typing import Dict, Sequence, Tuple, Union
+from skimage.segmentation import find_boundaries
+from monai.transforms import RandScaleIntensity, Compose, MapTransform  # type: ignore
+
+
+__all__ = ["BoundaryExclusion", "IntensityDiversification"]
+
+
+class BoundaryExclusion(MapTransform):
+    """
+    Map the cell boundary pixel labels to the background class (0).
+
+    This transform processes a label image by first detecting boundaries of cell regions
+    and then excluding those boundary pixels by setting them to 0. However, it retains
+    the original cell label if the cell is too small (less than 14x14 pixels) or if the cell
+    touches the image boundary.
+    """
+
+    def __init__(self, keys: Sequence[str] = ("label",), allow_missing_keys: bool = False) -> None:
+        """
+        Args:
+            keys (Sequence[str]): Keys in the input dictionary corresponding to the label image.
+                                  Default is ("label",).
+            allow_missing_keys (bool): If True, missing keys in the input will be ignored.
+                                       Default is False.
+        """
+        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
+
+    def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+        """
+        Apply the boundary exclusion transform to the label image.
+
+        The process involves:
+          1. Deep-copying the original label.
+          2. Finding boundaries using a thick mode with connectivity=1.
+          3. Setting the boundary pixels to background (0).
+          4. Restoring original labels for cells that are too small (< 14x14 pixels).
+          5. Ensuring that cells touching the image boundary are not excluded.
+          6. Assigning the transformed label back into the input dictionary.
+
+        Args:
+            data (Dict[str, np.ndarray]): Dictionary containing at least the "label" key with a label image.
+
+        Returns:
+            Dict[str, np.ndarray]: The input dictionary with the "label" key updated after boundary exclusion.
+        """
+        # Retrieve the original label image.
+        label_original: np.ndarray = data["label"]
+        # Create a deep copy of the original label for processing.
+        label: np.ndarray = copy.deepcopy(label_original)
+        # Detect cell boundaries with a thick boundary.
+        boundary: np.ndarray = find_boundaries(label, connectivity=1, mode="thick")
+        # Exclude boundary pixels by setting them to 0.
+        label[boundary] = 0
+
+        # Create a new label copy for selective exclusion.
+        new_label: np.ndarray = copy.deepcopy(label_original)
+        new_label[label == 0] = 0
+
+        # Obtain unique cell indices and their pixel counts.
+        cell_idx, cell_counts = np.unique(label_original, return_counts=True)
+
+        # If a cell is too small (< 196 pixels, approx. 14x14), restore its original label.
+        for k in range(len(cell_counts)):
+            if cell_counts[k] < 196:
+                new_label[label_original == cell_idx[k]] = cell_idx[k]
+
+        # Ensure that cells at the image boundaries are not excluded.
+        # Get the dimensions of the label image (this transform expects an (H, W, C) label).
+        H, W, _ = label_original.shape
+        # Build a mask that is True only inside the 2-pixel frame along the image border.
+        bd: np.ndarray = np.ones(label_original.shape, dtype=bool)
+        bd[2 : H - 2, 2 : W - 2, :] = False
+        # Restore the original labels inside that frame (assignment, not addition, so IDs are never doubled).
+        new_label[bd] = label_original[bd]
+
+        # Update the input dictionary with the transformed label.
+        data["label"] = new_label
+
+        return data
+
+
+class IntensityDiversification(MapTransform):
+    """
+    Randomly rescale the intensity of cell pixels.
+
+    This transform selects a subset of cells (based on the change_cell_ratio) and
+    applies a random intensity scaling to those cells. The intensity scaling is performed
+    using the RandScaleIntensity transform from MONAI.
+    """
+
+    def __init__(
+        self,
+        keys: Sequence[str] = ("img",),
+        change_cell_ratio: float = 0.4,
+        scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7),
+        allow_missing_keys: bool = False,
+    ) -> None:
+        """
+        Args:
+            keys (Sequence[str]): Keys in the input dictionary corresponding to the image.
+                                  Default is ("img",).
+            change_cell_ratio (float): Ratio of cells to apply the intensity scaling.
+                                       For example, 0.4 means 40% of the cells will be transformed.
+                                       Default is 0.4.
+            scale_factors (Sequence[float]): Factors used for random intensity scaling.
+                                             Default is (0.0, 0.7).
+            allow_missing_keys (bool): If True, missing keys in the input will be ignored.
+                                       Default is False.
+        """
+        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
+        self.change_cell_ratio: float = change_cell_ratio
+        # Compose a random intensity scaling transform with 100% probability.
+        self.randscale_intensity = Compose([RandScaleIntensity(prob=1.0, factors=scale_factors)])
+
+    def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+        """
+        Apply a cell-wise intensity diversification transform to an input image.
+
+        This function modifies the image by randomly selecting a subset of labeled cell regions
+        (per channel) and applying a random intensity scaling operation exclusively to those regions.
+        The transformation is performed independently on each channel of the image.
+
+        The steps are as follows:
+          1. Extract the label image for all channels (expected shape: (C, H, W)).
+          2. For each channel, determine the unique cell IDs, excluding the background (labeled as 0).
+          3. Raise a ValueError if no unique objects are found in the current label channel.
+          4. Compute the number of cells to modify based on the provided change_cell_ratio.
+          5. Randomly select the corresponding cell IDs for intensity modification.
+          6. Create a binary mask that highlights the selected cell regions.
+          7. Separate the image channel into two parts: one that remains unchanged and one that is
+             subjected to random intensity scaling.
+          8. Apply the random intensity scaling to the selected regions.
+          9. Combine the unchanged and modified parts to update the image for that channel.
+
+        Args:
+            data (Dict[str, np.ndarray]): A dictionary containing:
+                - "img": The original image array.
+                - "label": The corresponding cell label image array.
+
+        Returns:
+            Dict[str, np.ndarray]: The updated dictionary with the "img" key modified after applying
+                                   the intensity transformation.
+
+        Raises:
+            ValueError: If no unique cell objects are found in a label channel.
+        """
+        # Extract the label information for all channels.
+        # The label array has dimensions (C, H, W), where C is the number of channels.
+        label = data["label"]  # shape: (C, H, W)
+
+        # Process each channel independently.
+        for c in range(label.shape[0]):
+            # Extract the label and corresponding image channel for the current channel.
+            channel_label = label[c]
+            img_channel = data["img"][c]
+
+            # Retrieve all unique cell IDs in the current channel.
+            # Exclude the background (0) from these IDs.
+            cell_ids = np.unique(channel_label)
+            cell_ids = cell_ids[cell_ids > 0]
+
+            # If there are no unique cell objects in this channel, raise an exception.
+            if cell_ids.size == 0:
+                raise ValueError(f"No unique objects found in the label mask for channel {c}")
+
+            # Determine the number of cells to modify using the change_cell_ratio.
+            change_count = int(len(cell_ids) * self.change_cell_ratio)
+
+            # Randomly select a subset of cell IDs for intensity modification.
+            selected = np.random.choice(cell_ids, change_count, replace=False)
+
+            # Create a binary mask for the current channel:
+            # - Pixels corresponding to the selected cell IDs are set to 1.
+            # - All other pixels are set to 0.
+            mask = np.isin(channel_label, selected).astype(np.float32)
+
+            # Separate the image channel into two components:
+            # 1. img_orig: The portion of the image that remains unchanged.
+            # 2. img_changed: The portion that will have its intensity altered.
+            img_orig = (1 - mask) * img_channel
+            img_changed = mask * img_channel
+
+            # Apply a random intensity scaling transformation to the selected regions.
+            img_changed = self.randscale_intensity(img_changed)
+
+            # Combine the unchanged and modified parts to update the image channel.
+            data["img"][c] = img_orig + img_changed
+
+        return data
diff --git a/core/data/transforms/load_image.py b/core/data/transforms/load_image.py
new file mode 100644
index 0000000..2929131
--- /dev/null
+++ b/core/data/transforms/load_image.py
@@ -0,0 +1,202 @@
+import numpy as np
+import tifffile as tif
+import skimage.io as io
+from typing import List, Optional, Sequence, Type, Union
+
+from monai.utils.enums import PostFix
+from monai.utils.module import optional_import
+from monai.utils.misc import ensure_tuple, ensure_tuple_rep
+from monai.data.utils import is_supported_format
+from monai.data.image_reader import ImageReader, NumpyReader
+from monai.transforms import LoadImage, LoadImaged  # type: ignore
+from monai.config.type_definitions import DtypeLike, PathLike, KeysCollection
+
+
+# Default value for metadata postfix
+DEFAULT_POST_FIX = PostFix.meta()
+
+# Try to import ITK library; if not available, has_itk will be False
+itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
+
+
+__all__ = [
+    "CustomLoadImage",  # Basic image loader
+    "CustomLoadImaged",  # Dictionary-based image loader
+    "CustomLoadImageD",  # Dictionary-based image loader
+    "CustomLoadImageDict",  # Dictionary-based image loader
+]
+
+
+class CustomLoadImage(LoadImage):
+    """
+    Class for loading one or multiple images from a given path.
+
+    If a reader is not specified, the appropriate file reading method is automatically chosen
+    based on the file extension. Priority:
+      - Reader passed by the user at runtime.
+      - Reader specified in the constructor.
+      - Registered readers (from last to first).
+      - Standard readers for different formats (e.g., NibabelReader for nii, PILReader for png/jpg, etc.).
+
+    [Note] Here, the original ITKReader is replaced by the universal reader UniversalImageReader.
+    """
+    def __init__(
+        self,
+        reader: Optional[Union[ImageReader, Type[ImageReader], str]] = None,
+        image_only: bool = False,
+        dtype: DtypeLike = np.float32,
+        ensure_channel_first: bool = False,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            reader=reader,
+            image_only=image_only,
+            dtype=dtype,
+            ensure_channel_first=ensure_channel_first,
+            *args, **kwargs
+        )
+        # Clear the list of registered readers
+        self.readers = []
+        # Register the universal reader that handles TIFF, PNG, JPG, BMP, etc.
+        self.register(UniversalImageReader(*args, **kwargs))
+
+
+class CustomLoadImaged(LoadImaged):
+    """
+    Dictionary-based image loader.
+
+    Wraps image loading with CustomLoadImage and allows processing of data represented as a dictionary,
+    where keys point to file paths.
+    """
+    def __init__(
+        self,
+        keys: KeysCollection,
+        reader: Optional[Union[Type[ImageReader], str]] = None,
+        dtype: DtypeLike = np.float32,
+        meta_keys: Optional[KeysCollection] = None,
+        meta_key_postfix: str = DEFAULT_POST_FIX,
+        overwriting: bool = False,
+        image_only: bool = False,
+        ensure_channel_first: bool = False,
+        simple_keys: bool = False,
+        allow_missing_keys: bool = False,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            keys=keys,
+            reader=reader,
+            dtype=dtype,
+            meta_keys=meta_keys,
+            meta_key_postfix=meta_key_postfix,
+            overwriting=overwriting,
+            image_only=image_only,
+            ensure_channel_first=ensure_channel_first,
+            simple_keys=simple_keys,
+            allow_missing_keys=allow_missing_keys,
+            *args,
+            **kwargs,
+        )
+        # Assign the custom image loader
+        self._loader = CustomLoadImage(
+            reader=reader,
+            image_only=image_only,
+            dtype=dtype,
+            ensure_channel_first=ensure_channel_first,
+            *args, **kwargs
+        )
+        # Ensure that meta_key_postfix is a string
+        if not isinstance(meta_key_postfix, str):
+            raise TypeError(
+                f"meta_key_postfix must be a string, but got {type(meta_key_postfix).__name__}."
+            )
+        # If meta_keys are not provided, create a tuple of None for each key
+        self.meta_keys = (
+            ensure_tuple_rep(None, len(self.keys))
+            if meta_keys is None
+            else ensure_tuple(meta_keys)
+        )
+        # Check that the number of meta_keys matches the number of keys
+        if len(self.keys) != len(self.meta_keys):
+            raise ValueError("meta_keys must have the same length as keys.")
+        # Assign each key its corresponding metadata postfix
+        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
+        self.overwriting = overwriting
+
+
+class UniversalImageReader(NumpyReader):
+    """
+    Universal image reader for TIFF, PNG, JPG, BMP, etc.
+
+    Uses:
+      - tifffile for reading TIFF files.
+      - ITK (if available) for reading other formats.
+      - skimage.io for reading if the previous methods fail.
+
+    The image is loaded with its original number of channels (layers) without forced modifications
+    (e.g., repeating or cropping channels).
+    """
+    def __init__(
+        self, channel_dim: Optional[int] = None, **kwargs,
+    ):
+        super().__init__(channel_dim=channel_dim, **kwargs)
+        self.kwargs = kwargs
+        self.channel_dim = channel_dim
+
+    def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
+        """
+        Check if the file format is supported for reading.
+
+        Supported extensions: tif, tiff, png, jpg, bmp, jpeg.
+        """
+        suffixes: Sequence[str] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"]
+        return has_itk or is_supported_format(filename, suffixes)
+
+    def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs):
+        """
+        Read image(s) from the given path.
+
+        Arguments:
+            data: A file path or a sequence of file paths.
+            kwargs: Additional parameters for reading.
+
+        Returns:
+            A single image or a list of images depending on the number of paths provided.
+        """
+        images: List[np.ndarray] = []  # List to store the loaded images
+
+        # Convert data to a tuple to support multiple files
+        filenames: Sequence[PathLike] = ensure_tuple(data)
+        # Merge parameters provided in the constructor and the read() method
+        kwargs_ = self.kwargs.copy()
+        kwargs_.update(kwargs)
+
+        for name in filenames:
+            # Convert file name to string
+            name = f"{name}"
+            # If the file has a .tif or .tiff extension (case-insensitive), use tifffile for reading
+            if name.lower().endswith((".tif", ".tiff")):
+                img_array = tif.imread(name)
+            else:
+                # Attempt to read the image using ITK (if available)
+                try:
+                    img_itk = itk.imread(name, **kwargs_)
+                    img_array = itk.array_view_from_image(img_itk, keep_axes=False)
+                except Exception:
+                    # If ITK fails, use skimage.io for reading
+                    img_array = io.imread(name)
+
+            # Check the number of dimensions (axes) of the loaded image
+            if img_array.ndim == 2:
+                # If the image is 2D (height, width), add a new axis at the end to represent the channel
+                img_array = np.expand_dims(img_array, axis=-1)
+
+            images.append(img_array)
+
+        # Return a single image if only one file was provided, otherwise return a list of images
+        return images if len(filenames) > 1 else images[0]
+
+
+
+CustomLoadImageD = CustomLoadImageDict = CustomLoadImaged
\ No newline at end of file
diff --git a/core/data/transforms/normalize_image.py b/core/data/transforms/normalize_image.py
new file mode 100644
index 0000000..42d0d1c
--- /dev/null
+++ b/core/data/transforms/normalize_image.py
@@ -0,0 +1,142 @@
+import numpy as np
+from skimage import exposure
+from monai.config.type_definitions import KeysCollection
+from monai.transforms.transform import Transform, MapTransform
+from typing import Dict, Hashable, Mapping, Sequence
+
+__all__ = [
+    "CustomNormalizeImage",
+    "CustomNormalizeImaged",
+    "CustomNormalizeImageD",
+    "CustomNormalizeImageDict",
+]
+
+
+class CustomNormalizeImage(Transform):
+    """
+    Normalize the image by rescaling intensity values based on specified percentiles.
+
+    The normalization can be applied either on the entire image or channel-wise.
+    If the image is 2D (only height and width), a channel dimension is added for consistency.
+    """
+
+    def __init__(self, percentiles: Sequence[float] = (0, 99), channel_wise: bool = False) -> None:
+        """
+        Args:
+            percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
+                                           Default is (0, 99).
+            channel_wise (bool): Whether to apply normalization on each channel individually.
+                                 Default is False.
+        """
+        self.lower, self.upper = percentiles  # Unpack the lower and upper percentile values.
+        self.channel_wise = channel_wise  # Flag for channel-wise normalization.
+
+    def _normalize(self, img: np.ndarray) -> np.ndarray:
+        """
+        Rescale image intensity using non-zero values for percentile calculation.
+
+        Args:
+            img (np.ndarray): A numpy array representing a single-channel image.
+
+        Returns:
+            np.ndarray: A uint8 numpy array with rescaled intensity values.
+        """
+        # Extract non-zero values to avoid background influence.
+        non_zero_vals = img[np.nonzero(img)]
+        # An all-zero image has no foreground to derive percentiles from; return zeros instead of failing.
+        if non_zero_vals.size == 0:
+            return np.zeros(img.shape, dtype=np.uint8)
+        # Calculate the specified percentiles from the non-zero values.
+        computed_percentiles: np.ndarray = np.percentile(non_zero_vals, [self.lower, self.upper])
+        # Rescale the intensity values to the full uint8 range.
+        img_norm = exposure.rescale_intensity(
+            img, in_range=(computed_percentiles[0], computed_percentiles[1]), out_range="uint8"  # type: ignore
+        )
+        return img_norm.astype(np.uint8)
+
+    def __call__(self, img: np.ndarray) -> np.ndarray:
+        """
+        Apply normalization to the input image.
+
+        If the input image is 2D (height, width), a channel dimension is added.
+        Depending on the 'channel_wise' flag, normalization is applied either to each channel individually or to the entire image.
+
+        Args:
+            img (np.ndarray): Input image as a numpy array.
+
+        Returns:
+            np.ndarray: Normalized image as a numpy array.
+        """
+        # Check if the image is 2D (grayscale). If so, add a new axis for the channel.
+        if img.ndim == 2:
+            img = np.expand_dims(img, axis=-1)  # Added channel dimension for consistency.
+
+        if self.channel_wise:
+            # Initialize an empty array with the same shape as the input image to store normalized channels.
+            normalized_img = np.zeros(img.shape, dtype=np.uint8)
+
+            # Process each channel individually.
+            for i in range(img.shape[-1]):
+                channel_img: np.ndarray = img[:, :, i]
+
+                # Only normalize the channel if there are non-zero values present.
+                if np.count_nonzero(channel_img) > 0:
+                    normalized_img[:, :, i] = self._normalize(channel_img)
+
+            img = normalized_img
+        else:
+            # Apply normalization to the entire image.
+            img = self._normalize(img)
+
+        return img
+
+
+class CustomNormalizeImaged(MapTransform):
+    """
+    Dictionary-based wrapper for CustomNormalizeImage.
+
+    This transform applies normalization to one or more images contained in a dictionary,
+    where the keys point to the image data.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        percentiles: Sequence[float] = (1, 99),
+        channel_wise: bool = False,
+        allow_missing_keys: bool = False,
+    ) -> None:
+        """
+        Args:
+            keys (KeysCollection): Keys identifying the image entries in the dictionary.
+            percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
+                                           Default is (1, 99).
+            channel_wise (bool): Whether to apply normalization on each channel individually.
+                                 Default is False.
+            allow_missing_keys (bool): If True, missing keys in the dictionary will be ignored.
+                                       Default is False.
+        """
+        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
+        # Create an instance of the normalization transform with specified parameters.
+        self.normalizer: CustomNormalizeImage = CustomNormalizeImage(percentiles, channel_wise)
+
+    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+        """
+        Apply the normalization transform to each image in the input dictionary.
+
+        Args:
+            data (Mapping[Hashable, np.ndarray]): A dictionary mapping keys to numpy arrays representing images.
+
+        Returns:
+            Dict[Hashable, np.ndarray]: A dictionary with the same keys where the images have been normalized.
+        """
+        # Copy the input dictionary to avoid modifying the original data.
+        d: Dict[Hashable, np.ndarray] = dict(data)
+        # Normalize each requested image; key_iterator skips absent keys when allow_missing_keys=True.
+        for key in self.key_iterator(d):
+            d[key] = self.normalizer(d[key])
+        return d
+
+
+# Create aliases for the dictionary-based normalization transform.
+CustomNormalizeImageD = CustomNormalizeImageDict = CustomNormalizeImaged
diff --git a/core/criteria/__init__.py b/core/losses/__init__.py
similarity index 100%
rename from core/criteria/__init__.py
rename to core/losses/__init__.py
diff --git a/core/criteria/base.py b/core/losses/base.py
similarity index 100%
rename from core/criteria/base.py
rename to core/losses/base.py
diff --git a/core/criteria/bce.py b/core/losses/bce.py
similarity index 100%
rename from core/criteria/bce.py
rename to core/losses/bce.py
diff --git a/core/criteria/ce.py b/core/losses/ce.py
similarity index 100%
rename from core/criteria/ce.py
rename to core/losses/ce.py
diff --git a/core/criteria/mse.py b/core/losses/mse.py
similarity index 100%
rename from core/criteria/mse.py
rename to core/losses/mse.py
diff --git a/core/criteria/mse_with_bce.py b/core/losses/mse_with_bce.py
similarity index 100%
rename from core/criteria/mse_with_bce.py
rename to core/losses/mse_with_bce.py
diff --git a/core/utils/__init__.py b/core/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/generate_config.py b/generate_config.py
index 054eaaa..9fa783d 100644
--- a/generate_config.py
+++ b/generate_config.py
@@ -98,7 +98,7 @@ def main():
     base_filename = f"{chosen_model}_{chosen_criterion}_{chosen_optimizer}_{chosen_scheduler}"
 
     # Determine the output directory relative to this script.
-    base_dir = os.path.join(script_path, "config/jsons", "train" if is_training else "predict")
+    base_dir = os.path.join(script_path, "config/templates", "train" if is_training else "predict")
     os.makedirs(base_dir, exist_ok=True)
 
     filename = f"{base_filename}.json"
diff --git a/train.py b/train.py
index f03ef41..1cd24f2 100644
--- a/train.py
+++ b/train.py
@@ -2,13 +2,13 @@ from config.config import Config
 from pprint import pprint
 
 
-config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
+config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
 pprint(config, indent=4)
 
 print('\n\n')
-config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV.json')
+config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV.json')
 pprint(config, indent=4)
 
 print('\n\n')
-config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV_1.json')
+config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV_1.json')
 pprint(config, indent=4)
\ No newline at end of file