add roi_size parameter to training

master
laynholt 2 months ago
parent ca6465296a
commit 662889c7a7

@ -70,6 +70,7 @@ class DatasetTrainingConfig(BaseModel):
test_offset: int = 0 # Offset for testing data test_offset: int = 0 # Offset for testing data
batch_size: int = 1 # Batch size for training batch_size: int = 1 # Batch size for training
roi_size: int = 512 # The size of the square window for cropping
num_epochs: int = 100 # Number of training epochs num_epochs: int = 100 # Number of training epochs
val_freq: int = 1 # Frequency of validation during training val_freq: int = 1 # Frequency of validation during training
@ -123,11 +124,14 @@ class DatasetTrainingConfig(BaseModel):
""" """
Validates numeric fields: Validates numeric fields:
- batch_size and num_epochs must be > 0. - batch_size and num_epochs must be > 0.
- roi_size must be > 0.
- val_freq must be >= 0. - val_freq must be >= 0.
- offsets must be >= 0. - offsets must be >= 0.
""" """
if self.batch_size <= 0: if self.batch_size <= 0:
raise ValueError("batch_size must be > 0") raise ValueError("batch_size must be > 0")
if self.roi_size <= 0:
raise ValueError("roi_size must be > 0")
if self.num_epochs <= 0: if self.num_epochs <= 0:
raise ValueError("num_epochs must be > 0") raise ValueError("num_epochs must be > 0")
if self.val_freq < 0: if self.val_freq < 0:

@ -13,7 +13,7 @@ __all__ = [
] ]
def get_train_transforms(): def get_train_transforms(roi_size: int = 512):
""" """
Returns the transformation pipeline for training data. Returns the transformation pipeline for training data.
@ -27,6 +27,9 @@ def get_train_transforms():
7. Apply additional intensity perturbations (noise, contrast, smoothing, histogram shift, and sharpening). 7. Apply additional intensity perturbations (noise, contrast, smoothing, histogram shift, and sharpening).
8. Convert the data types to the desired format. 8. Convert the data types to the desired format.
Args:
roi_size (int):
The size of the square window for cropping. (Default 512).
Returns: Returns:
Compose: The composed transformation pipeline for training. Compose: The composed transformation pipeline for training.
""" """
@ -55,9 +58,9 @@ def get_train_transforms():
keep_size=False, keep_size=False,
), ),
# Pad spatial dimensions to ensure a size of 512. # Pad spatial dimensions to ensure a size of 512.
SpatialPadd(keys=["image", "mask"], spatial_size=512), SpatialPadd(keys=["image", "mask"], spatial_size=roi_size),
# Randomly crop a region of interest of size 512. # Randomly crop a region of interest of size 512.
RandSpatialCropd(keys=["image", "mask"], roi_size=512, random_size=False), RandSpatialCropd(keys=["image", "mask"], roi_size=roi_size, random_size=False),
# Randomly flip the image and label along an axis. # Randomly flip the image and label along an axis.
RandAxisFlipd(keys=["image", "mask"], prob=0.5), RandAxisFlipd(keys=["image", "mask"], prob=0.5),
# Randomly rotate the image and label by 90 degrees. # Randomly rotate the image and label by 90 degrees.

@ -59,7 +59,8 @@ def main():
segmentator = CellSegmentator(config) segmentator = CellSegmentator(config)
segmentator.create_dataloaders( segmentator.create_dataloaders(
train_transforms=get_train_transforms() if mode == "train" else None, train_transforms=get_train_transforms(
roi_size=config.dataset_config.training.roi_size) if mode == "train" else None,
valid_transforms=get_valid_transforms() if mode == "train" else None, valid_transforms=get_valid_transforms() if mode == "train" else None,
test_transforms=get_test_transforms() if mode in ("train", "test") else None, test_transforms=get_test_transforms() if mode in ("train", "test") else None,
predict_transforms=get_predict_transforms() if mode == "predict" else None predict_transforms=get_predict_transforms() if mode == "predict" else None

@ -607,7 +607,7 @@ wheels = [
[[package]] [[package]]
name = "model-v" name = "model-v"
version = "0.1.0" version = "1.0.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "colorlog" }, { name = "colorlog" },

Loading…
Cancel
Save