add roi_size parameter to training

master
laynholt 2 months ago
parent ca6465296a
commit 662889c7a7

@ -70,6 +70,7 @@ class DatasetTrainingConfig(BaseModel):
test_offset: int = 0 # Offset for testing data
batch_size: int = 1 # Batch size for training
roi_size: int = 512 # The size of the square window for cropping
num_epochs: int = 100 # Number of training epochs
val_freq: int = 1 # Frequency of validation during training
@ -123,11 +124,14 @@ class DatasetTrainingConfig(BaseModel):
"""
Validates numeric fields:
- batch_size and num_epochs must be > 0.
- roi_size must be > 0.
- val_freq must be >= 0.
- offsets must be >= 0.
"""
if self.batch_size <= 0:
raise ValueError("batch_size must be > 0")
if self.roi_size <= 0:
raise ValueError("roi_size must be > 0")
if self.num_epochs <= 0:
raise ValueError("num_epochs must be > 0")
if self.val_freq < 0:

@ -13,7 +13,7 @@ __all__ = [
]
def get_train_transforms():
def get_train_transforms(roi_size: int = 512):
"""
Returns the transformation pipeline for training data.
@ -27,6 +27,9 @@ def get_train_transforms():
7. Apply additional intensity perturbations (noise, contrast, smoothing, histogram shift, and sharpening).
8. Convert the data types to the desired format.
Args:
roi_size (int):
The size of the square window for cropping. (Default 512).
Returns:
Compose: The composed transformation pipeline for training.
"""
@ -55,9 +58,9 @@ def get_train_transforms():
keep_size=False,
),
# Pad spatial dimensions to ensure a size of 512.
SpatialPadd(keys=["image", "mask"], spatial_size=512),
SpatialPadd(keys=["image", "mask"], spatial_size=roi_size),
# Randomly crop a region of interest of size 512.
RandSpatialCropd(keys=["image", "mask"], roi_size=512, random_size=False),
RandSpatialCropd(keys=["image", "mask"], roi_size=roi_size, random_size=False),
# Randomly flip the image and label along an axis.
RandAxisFlipd(keys=["image", "mask"], prob=0.5),
# Randomly rotate the image and label by 90 degrees.

@ -59,7 +59,8 @@ def main():
segmentator = CellSegmentator(config)
segmentator.create_dataloaders(
train_transforms=get_train_transforms() if mode == "train" else None,
train_transforms=get_train_transforms(
roi_size=config.dataset_config.training.roi_size) if mode == "train" else None,
valid_transforms=get_valid_transforms() if mode == "train" else None,
test_transforms=get_test_transforms() if mode in ("train", "test") else None,
predict_transforms=get_predict_transforms() if mode == "predict" else None

@ -607,7 +607,7 @@ wheels = [
[[package]]
name = "model-v"
version = "0.1.0"
version = "1.0.0"
source = { virtual = "." }
dependencies = [
{ name = "colorlog" },

Loading…
Cancel
Save