simple project restruct, added transforms

master
laynholt 1 month ago
parent 33ce003657
commit 78f97a72a2

@ -1,5 +1,5 @@
from .models import ModelRegistry
from .criteria import CriterionRegistry
from .losses import CriterionRegistry
from .optimizers import OptimizerRegistry
from .schedulers import SchedulerRegistry

@ -0,0 +1,14 @@
from .transforms import (
get_train_transforms,
get_valid_transforms,
get_test_transforms,
get_pred_transforms
)
__all__ = [
"get_train_transforms",
"get_valid_transforms",
"get_test_transforms",
"get_pred_transforms",
]

@ -0,0 +1,184 @@
from .cell_aware import IntensityDiversification
from .load_image import CustomLoadImage, CustomLoadImaged
from .normalize_image import CustomNormalizeImage, CustomNormalizeImaged
from monai.transforms import * # type: ignore
__all__ = [
"get_train_transforms",
"get_valid_transforms",
"get_test_transforms",
"get_pred_transforms",
]
def get_train_transforms():
"""
Returns the transformation pipeline for training data.
The training pipeline applies a series of image and label preprocessing steps:
1. Load image and label data.
2. Normalize the image intensities.
3. Ensure the image and label have channel-first format.
4. Scale image intensities.
5. Apply spatial transformations (zoom, padding, cropping, flipping, and rotation).
6. Diversify intensities for selected cell regions.
7. Apply additional intensity perturbations (noise, contrast, smoothing, histogram shift, and sharpening).
8. Convert the data types to the desired format.
Returns:
Compose: The composed transformation pipeline for training.
"""
train_transforms = Compose(
[
# Load image and label data in (H, W, C) format (image loaded as image-only).
CustomLoadImaged(keys=["img", "label"], image_only=True),
# Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImaged(
keys=["img"],
allow_missing_keys=True,
channel_wise=False,
percentiles=[0.0, 99.5],
),
# Ensure both image and label are in channel-first format.
EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
# Scale image intensities (do not scale the label).
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
# Apply random zoom to both image and label.
RandZoomd(
keys=["img", "label"],
prob=0.5,
min_zoom=0.25,
max_zoom=1.5,
mode=["area", "nearest"],
keep_size=False,
),
# Pad spatial dimensions to ensure a size of 512.
SpatialPadd(keys=["img", "label"], spatial_size=512),
# Randomly crop a region of interest of size 512.
RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
# Randomly flip the image and label along an axis.
RandAxisFlipd(keys=["img", "label"], prob=0.5),
# Randomly rotate the image and label by 90 degrees.
RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=(0, 1)),
# Diversify intensities for selected cell regions.
IntensityDiversification(keys=["img", "label"], allow_missing_keys=True),
# Apply random Gaussian noise to the image.
RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
# Randomly adjust the contrast of the image.
RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
# Apply random Gaussian smoothing to the image.
RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
# Randomly shift the histogram of the image.
RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
# Apply random Gaussian sharpening to the image.
RandGaussianSharpend(keys=["img"], prob=0.25),
# Ensure that the data types are correct.
EnsureTyped(keys=["img", "label"]),
]
)
return train_transforms
def get_valid_transforms():
"""
Returns the transformation pipeline for validation data.
The validation pipeline includes the following steps:
1. Load image and label data (with missing keys allowed).
2. Normalize the image intensities.
3. Ensure the image and label are in channel-first format.
4. Scale image intensities.
5. Convert the data types to the desired format.
Returns:
Compose: The composed transformation pipeline for validation.
"""
valid_transforms = Compose(
[
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
# Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImaged(
keys=["img"],
allow_missing_keys=True,
channel_wise=False,
percentiles=[0.0, 99.5],
),
# Ensure both image and label are in channel-first format.
EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
# Scale image intensities.
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
# Ensure that the data types are correct.
EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
]
)
return valid_transforms
def get_test_transforms():
"""
Returns the transformation pipeline for test data.
The test pipeline is similar to the validation pipeline and includes:
1. Load image and label data (with missing keys allowed).
2. Normalize the image intensities.
3. Ensure the image and label are in channel-first format.
4. Scale image intensities.
5. Convert the data types to the desired format.
Returns:
Compose: The composed transformation pipeline for testing.
"""
test_transforms = Compose(
[
# Load image and label data in (H, W, C) format (image loaded as image-only; allow missing keys).
CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
# Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImaged(
keys=["img"],
allow_missing_keys=True,
channel_wise=False,
percentiles=[0.0, 99.5],
),
# Ensure both image and label are in channel-first format.
EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
# Scale image intensities.
ScaleIntensityd(keys=["img"], allow_missing_keys=True),
# Ensure that the data types are correct.
EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
]
)
return test_transforms
def get_pred_transforms():
"""
Returns the transformation pipeline for prediction preprocessing.
The prediction pipeline includes the following steps:
1. Load the image data.
2. Normalize the image intensities.
3. Ensure the image is in channel-first format.
4. Scale image intensities.
5. Convert the image to the required tensor type.
Returns:
Compose: The composed transformation pipeline for prediction.
"""
pred_transforms = Compose(
[
# Load the image data in (H, W, C) format (image loaded as image-only).
CustomLoadImage(image_only=True),
# Normalize the (H, W, C) image using the specified percentiles.
CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
# Ensure the image is in channel-first format.
EnsureChannelFirst(channel_dim=-1), # image shape: (C, H, W)
# Scale image intensities.
ScaleIntensity(),
# Convert the image to the required tensor type.
EnsureType(data_type="tensor"),
]
)
return pred_transforms

@ -0,0 +1,192 @@
import copy
import numpy as np
from typing import Dict, Sequence, Tuple, Union
from skimage.segmentation import find_boundaries
from monai.transforms import RandScaleIntensity, Compose, MapTransform # type: ignore
__all__ = ["BoundaryExclusion", "IntensityDiversification"]
class BoundaryExclusion(MapTransform):
"""
Map the cell boundary pixel labels to the background class (0).
This transform processes a label image by first detecting boundaries of cell regions
and then excluding those boundary pixels by setting them to 0. However, it retains
the original cell label if the cell is too small (less than 14x14 pixels) or if the cell
touches the image boundary.
"""
def __init__(self, keys: Sequence[str] = ("label",), allow_missing_keys: bool = False) -> None:
"""
Args:
keys (Sequence[str]): Keys in the input dictionary corresponding to the label image.
Default is ("label",).
allow_missing_keys (bool): If True, missing keys in the input will be ignored.
Default is False.
"""
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""
Apply the boundary exclusion transform to the label image.
The process involves:
1. Deep-copying the original label.
2. Finding boundaries using a thick mode with connectivity=1.
3. Setting the boundary pixels to background (0).
4. Restoring original labels for cells that are too small (< 14x14 pixels).
5. Ensuring that cells touching the image boundary are not excluded.
6. Assigning the transformed label back into the input dictionary.
Args:
data (Dict[str, np.ndarray]): Dictionary containing at least the "label" key with a label image.
Returns:
Dict[str, np.ndarray]: The input dictionary with the "label" key updated after boundary exclusion.
"""
# Retrieve the original label image.
label_original: np.ndarray = data["label"]
# Create a deep copy of the original label for processing.
label: np.ndarray = copy.deepcopy(label_original)
# Detect cell boundaries with a thick boundary.
boundary: np.ndarray = find_boundaries(label, connectivity=1, mode="thick")
# Exclude boundary pixels by setting them to 0.
label[boundary] = 0
# Create a new label copy for selective exclusion.
new_label: np.ndarray = copy.deepcopy(label_original)
new_label[label == 0] = 0
# Obtain unique cell indices and their pixel counts.
cell_idx, cell_counts = np.unique(label_original, return_counts=True)
# If a cell is too small (< 196 pixels, approx. 14x14), restore its original label.
for k in range(len(cell_counts)):
if cell_counts[k] < 196:
new_label[label_original == cell_idx[k]] = cell_idx[k]
# Ensure that cells at the image boundaries are not excluded.
# Get the dimensions of the label image.
H, W, _ = label_original.shape
# Create a binary mask with a border of 2 pixels preserved.
bd: np.ndarray = np.zeros_like(label_original, dtype=label.dtype)
bd[2 : H - 2, 2 : W - 2, :] = 1
# Combine the preserved boundaries with the new label.
new_label += label_original * bd
# Update the input dictionary with the transformed label.
data["label"] = new_label
return data
class IntensityDiversification(MapTransform):
"""
Randomly rescale the intensity of cell pixels.
This transform selects a subset of cells (based on the change_cell_ratio) and
applies a random intensity scaling to those cells. The intensity scaling is performed
using the RandScaleIntensity transform from MONAI.
"""
def __init__(
self,
keys: Sequence[str] = ("img",),
change_cell_ratio: float = 0.4,
scale_factors: Union[Tuple[float, float], float] = (0.0, 0.7),
allow_missing_keys: bool = False,
) -> None:
"""
Args:
keys (Sequence[str]): Keys in the input dictionary corresponding to the image.
Default is ("img",).
change_cell_ratio (float): Ratio of cells to apply the intensity scaling.
For example, 0.4 means 40% of the cells will be transformed.
Default is 0.4.
scale_factors (Sequence[float]): Factors used for random intensity scaling.
Default is (0.0, 0.7).
allow_missing_keys (bool): If True, missing keys in the input will be ignored.
Default is False.
"""
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
self.change_cell_ratio: float = change_cell_ratio
# Compose a random intensity scaling transform with 100% probability.
self.randscale_intensity = Compose([RandScaleIntensity(prob=1.0, factors=scale_factors)])
def __call__(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""
Apply a cell-wise intensity diversification transform to an input image.
This function modifies the image by randomly selecting a subset of labeled cell regions
(per channel) and applying a random intensity scaling operation exclusively to those regions.
The transformation is performed independently on each channel of the image.
The steps are as follows:
1. Extract the label image for all channels (expected shape: (C, H, W)).
2. For each channel, determine the unique cell IDs, excluding the background (labeled as 0).
3. Raise a ValueError if no unique objects are found in the current label channel.
4. Compute the number of cells to modify based on the provided change_cell_ratio.
5. Randomly select the corresponding cell IDs for intensity modification.
6. Create a binary mask that highlights the selected cell regions.
7. Separate the image channel into two parts: one that remains unchanged and one that is
subjected to random intensity scaling.
8. Apply the random intensity scaling to the selected regions.
9. Combine the unchanged and modified parts to update the image for that channel.
Args:
data (Dict[str, np.ndarray]): A dictionary containing:
- "img": The original image array.
- "label": The corresponding cell label image array.
Returns:
Dict[str, np.ndarray]: The updated dictionary with the "img" key modified after applying
the intensity transformation.
Raises:
ValueError: If no unique cell objects are found in a label channel.
"""
# Extract the label information for all channels.
# The label array has dimensions (C, H, W), where C is the number of channels.
label = data["label"] # shape: (C, H, W)
# Process each channel independently.
for c in range(label.shape[0]):
# Extract the label and corresponding image channel for the current channel.
channel_label = label[c]
img_channel = data["img"][c]
# Retrieve all unique cell IDs in the current channel.
# Exclude the background (0) from these IDs.
cell_ids = np.unique(channel_label)
cell_ids = cell_ids[cell_ids > 0]
# If there are no unique cell objects in this channel, raise an exception.
if cell_ids.size == 0:
raise ValueError(f"No unique objects found in the label mask for channel {c}")
# Determine the number of cells to modify using the change_cell_ratio.
change_count = int(len(cell_ids) * self.change_cell_ratio)
# Randomly select a subset of cell IDs for intensity modification.
selected = np.random.choice(cell_ids, change_count, replace=False)
# Create a binary mask for the current channel:
# - Pixels corresponding to the selected cell IDs are set to 1.
# - All other pixels are set to 0.
mask = np.isin(channel_label, selected).astype(np.float32)
# Separate the image channel into two components:
# 1. img_orig: The portion of the image that remains unchanged.
# 2. img_changed: The portion that will have its intensity altered.
img_orig = (1 - mask) * img_channel
img_changed = mask * img_channel
# Apply a random intensity scaling transformation to the selected regions.
img_changed = self.randscale_intensity(img_changed)
# Combine the unchanged and modified parts to update the image channel.
data["img"][c] = img_orig + img_changed
return data

@ -0,0 +1,202 @@
import numpy as np
import tifffile as tif
import skimage.io as io
from typing import List, Optional, Sequence, Type, Union
from monai.utils.enums import PostFix
from monai.utils.module import optional_import
from monai.utils.misc import ensure_tuple, ensure_tuple_rep
from monai.data.utils import is_supported_format
from monai.data.image_reader import ImageReader, NumpyReader
from monai.transforms import LoadImage, LoadImaged # type: ignore
from monai.config.type_definitions import DtypeLike, PathLike, KeysCollection
# Default value for metadata postfix
DEFAULT_POST_FIX = PostFix.meta()
# Try to import ITK library; if not available, has_itk will be False
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
__all__ = [
"CustomLoadImage", # Basic image loader
"CustomLoadImaged", # Dictionary-based image loader
"CustomLoadImageD", # Dictionary-based image loader
"CustomLoadImageDict", # Dictionary-based image loader
]
class CustomLoadImage(LoadImage):
"""
Class for loading one or multiple images from a given path.
If a reader is not specified, the appropriate file reading method is automatically chosen
based on the file extension. Priority:
- Reader passed by the user at runtime.
- Reader specified in the constructor.
- Registered readers (from last to first).
- Standard readers for different formats (e.g., NibabelReader for nii, PILReader for png/jpg, etc.).
[Note] Here, the original ITKReader is replaced by the universal reader UniversalImageReader.
"""
def __init__(
self,
reader: Optional[Union[ImageReader, Type[ImageReader], str]] = None,
image_only: bool = False,
dtype: DtypeLike = np.float32,
ensure_channel_first: bool = False,
*args,
**kwargs,
) -> None:
super().__init__(
reader=reader,
image_only=image_only,
dtype=dtype,
ensure_channel_first=ensure_channel_first,
*args, **kwargs
)
# Clear the list of registered readers
self.readers = []
# Register the universal reader that handles TIFF, PNG, JPG, BMP, etc.
self.register(UniversalImageReader(*args, **kwargs))
class CustomLoadImaged(LoadImaged):
"""
Dictionary-based image loader.
Wraps image loading with CustomLoadImage and allows processing of data represented as a dictionary,
where keys point to file paths.
"""
def __init__(
self,
keys: KeysCollection,
reader: Optional[Union[Type[ImageReader], str]] = None,
dtype: DtypeLike = np.float32,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = DEFAULT_POST_FIX,
overwriting: bool = False,
image_only: bool = False,
ensure_channel_first: bool = False,
simple_keys: bool = False,
allow_missing_keys: bool = False,
*args,
**kwargs,
) -> None:
super().__init__(
keys=keys,
reader=reader,
dtype=dtype,
meta_keys=meta_keys,
meta_key_postfix=meta_key_postfix,
overwriting=overwriting,
image_only=image_only,
ensure_channel_first=ensure_channel_first,
simple_keys=simple_keys,
allow_missing_keys=allow_missing_keys,
*args,
**kwargs,
)
# Assign the custom image loader
self._loader = CustomLoadImage(
reader=reader,
image_only=image_only,
dtype=dtype,
ensure_channel_first=ensure_channel_first,
*args, **kwargs
)
# Ensure that meta_key_postfix is a string
if not isinstance(meta_key_postfix, str):
raise TypeError(
f"meta_key_postfix must be a string, but got {type(meta_key_postfix).__name__}."
)
# If meta_keys are not provided, create a tuple of None for each key
self.meta_keys = (
ensure_tuple_rep(None, len(self.keys))
if meta_keys is None
else ensure_tuple(meta_keys)
)
# Check that the number of meta_keys matches the number of keys
if len(self.keys) != len(self.meta_keys):
raise ValueError("meta_keys must have the same length as keys.")
# Assign each key its corresponding metadata postfix
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.overwriting = overwriting
class UniversalImageReader(NumpyReader):
"""
Universal image reader for TIFF, PNG, JPG, BMP, etc.
Uses:
- tifffile for reading TIFF files.
- ITK (if available) for reading other formats.
- skimage.io for reading if the previous methods fail.
The image is loaded with its original number of channels (layers) without forced modifications
(e.g., repeating or cropping channels).
"""
def __init__(
self, channel_dim: Optional[int] = None, **kwargs,
):
super().__init__(channel_dim=channel_dim, **kwargs)
self.kwargs = kwargs
self.channel_dim = channel_dim
def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
"""
Check if the file format is supported for reading.
Supported extensions: tif, tiff, png, jpg, bmp, jpeg.
"""
suffixes: Sequence[str] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"]
return has_itk or is_supported_format(filename, suffixes)
def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs):
"""
Read image(s) from the given path.
Arguments:
data: A file path or a sequence of file paths.
kwargs: Additional parameters for reading.
Returns:
A single image or a list of images depending on the number of paths provided.
"""
images: List[np.ndarray] = [] # List to store the loaded images
# Convert data to a tuple to support multiple files
filenames: Sequence[PathLike] = ensure_tuple(data)
# Merge parameters provided in the constructor and the read() method
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
# Convert file name to string
name = f"{name}"
# If the file has a .tif or .tiff extension (case-insensitive), use tifffile for reading
if name.lower().endswith((".tif", ".tiff")):
img_array = tif.imread(name)
else:
# Attempt to read the image using ITK (if available)
try:
img_itk = itk.imread(name, **kwargs_)
img_array = itk.array_view_from_image(img_itk, keep_axes=False)
except Exception:
# If ITK fails, use skimage.io for reading
img_array = io.imread(name)
# Check the number of dimensions (axes) of the loaded image
if img_array.ndim == 2:
# If the image is 2D (height, width), add a new axis at the end to represent the channel
img_array = np.expand_dims(img_array, axis=-1)
images.append(img_array)
# Return a single image if only one file was provided, otherwise return a list of images
return images if len(filenames) > 1 else images[0]
CustomLoadImageD = CustomLoadImageDict = CustomLoadImaged

@ -0,0 +1,139 @@
import numpy as np
from skimage import exposure
from monai.config.type_definitions import KeysCollection
from monai.transforms.transform import Transform, MapTransform
from typing import Dict, Hashable, Mapping, Sequence
__all__ = [
"CustomNormalizeImage",
"CustomNormalizeImaged",
"CustomNormalizeImageD",
"CustomNormalizeImageDict",
]
class CustomNormalizeImage(Transform):
"""
Normalize the image by rescaling intensity values based on specified percentiles.
The normalization can be applied either on the entire image or channel-wise.
If the image is 2D (only height and width), a channel dimension is added for consistency.
"""
def __init__(self, percentiles: Sequence[float] = (0, 99), channel_wise: bool = False) -> None:
"""
Args:
percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
Default is (0, 99).
channel_wise (bool): Whether to apply normalization on each channel individually.
Default is False.
"""
self.lower, self.upper = percentiles # Unpack the lower and upper percentile values.
self.channel_wise = channel_wise # Flag for channel-wise normalization.
def _normalize(self, img: np.ndarray) -> np.ndarray:
"""
Rescale image intensity using non-zero values for percentile calculation.
Args:
img (np.ndarray): A numpy array representing a single-channel image.
Returns:
np.ndarray: A uint8 numpy array with rescaled intensity values.
"""
# Extract non-zero values to avoid background influence.
non_zero_vals = img[np.nonzero(img)]
# Calculate the specified percentiles from the non-zero values.
computed_percentiles: np.ndarray = np.percentile(non_zero_vals, [self.lower, self.upper])
# Rescale the intensity values to the full uint8 range.
img_norm = exposure.rescale_intensity(
img, in_range=(computed_percentiles[0], computed_percentiles[1]), out_range="uint8" # type: ignore
)
return img_norm.astype(np.uint8)
def __call__(self, img: np.ndarray) -> np.ndarray:
"""
Apply normalization to the input image.
If the input image is 2D (height, width), a channel dimension is added.
Depending on the 'channel_wise' flag, normalization is applied either to each channel individually or to the entire image.
Args:
img (np.ndarray): Input image as a numpy array.
Returns:
np.ndarray: Normalized image as a numpy array.
"""
# Check if the image is 2D (grayscale). If so, add a new axis for the channel.
if img.ndim == 2:
img = np.expand_dims(img, axis=-1) # Added channel dimension for consistency.
if self.channel_wise:
# Initialize an empty array with the same shape as the input image to store normalized channels.
normalized_img = np.zeros(img.shape, dtype=np.uint8)
# Process each channel individually.
for i in range(img.shape[-1]):
channel_img: np.ndarray = img[:, :, i]
# Only normalize the channel if there are non-zero values present.
if np.count_nonzero(channel_img) > 0:
normalized_img[:, :, i] = self._normalize(channel_img)
img = normalized_img
else:
# Apply normalization to the entire image.
img = self._normalize(img)
return img
class CustomNormalizeImaged(MapTransform):
"""
Dictionary-based wrapper for CustomNormalizeImage.
This transform applies normalization to one or more images contained in a dictionary,
where the keys point to the image data.
"""
def __init__(
self,
keys: KeysCollection,
percentiles: Sequence[float] = (1, 99),
channel_wise: bool = False,
allow_missing_keys: bool = False,
) -> None:
"""
Args:
keys (KeysCollection): Keys identifying the image entries in the dictionary.
percentiles (Sequence[float]): Lower and upper percentiles used for intensity scaling.
Default is (1, 99).
channel_wise (bool): Whether to apply normalization on each channel individually.
Default is False.
allow_missing_keys (bool): If True, missing keys in the dictionary will be ignored.
Default is False.
"""
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
# Create an instance of the normalization transform with specified parameters.
self.normalizer: CustomNormalizeImage = CustomNormalizeImage(percentiles, channel_wise)
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
"""
Apply the normalization transform to each image in the input dictionary.
Args:
data (Mapping[Hashable, np.ndarray]): A dictionary mapping keys to numpy arrays representing images.
Returns:
Dict[Hashable, np.ndarray]: A dictionary with the same keys where the images have been normalized.
"""
# Copy the input dictionary to avoid modifying the original data.
d: Dict[Hashable, np.ndarray] = dict(data)
# Iterate over each key specified in the transform and normalize the corresponding image.
for key in self.keys:
d[key] = self.normalizer(d[key])
return d
# Create aliases for the dictionary-based normalization transform.
CustomNormalizeImageD = CustomNormalizeImageDict = CustomNormalizeImaged

@ -98,7 +98,7 @@ def main():
base_filename = f"{chosen_model}_{chosen_criterion}_{chosen_optimizer}_{chosen_scheduler}"
# Determine the output directory relative to this script.
base_dir = os.path.join(script_path, "config/jsons", "train" if is_training else "predict")
base_dir = os.path.join(script_path, "config/templates", "train" if is_training else "predict")
os.makedirs(base_dir, exist_ok=True)
filename = f"{base_filename}.json"

@ -2,13 +2,13 @@ from config.config import Config
from pprint import pprint
config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json')
pprint(config, indent=4)
print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV.json')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV.json')
pprint(config, indent=4)
print('\n\n')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/jsons/predict/ModelV_1.json')
config = Config.load_json('/workspace/ext_data/projects/model-v/config/templates/predict/ModelV_1.json')
pprint(config, indent=4)
Loading…
Cancel
Save