parent 8885950b9e
commit 2413420620
@@ -0,0 +1,30 @@
import logging

import colorlog

__all__ = ["get_logger"]


def get_logger(name: str = "trainer") -> logging.Logger:
    """
    Creates and configures a logger with colored level names.
    INFO is light blue, DEBUG is green. Message text remains white.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    if not logger.handlers:
        handler = colorlog.StreamHandler()
        formatter = colorlog.ColoredFormatter(
            fmt="%(log_color)s[%(levelname)s]%(reset)s %(message)s",
            log_colors={
                "DEBUG": "green",
                "INFO": "light_blue",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "bold_red",
            },
            style="%"
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
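
A quick usage sketch (the module path core.logger matches the import used in the trainer file below; the example messages are illustrative only):

    from core.logger import get_logger

    logger = get_logger()                      # default logger name: "trainer"
    logger.debug("debug message")              # green [DEBUG] tag
    logger.info("info message")                # light blue [INFO] tag
    logger.warning("something looks off")      # yellow [WARNING] tag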
@@ -0,0 +1,423 @@
import torch
import random
import numpy as np
from monai.data.dataset import Dataset
from monai.transforms import *  # type: ignore
from torch.utils.data import DataLoader

import os
import glob
from pprint import pformat
from typing import Optional, Union

from config import Config
from core.models import *
from core.losses import *
from core.optimizers import *
from core.schedulers import *

from core.logger import get_logger


logger = get_logger()


class CellSegmentator:
    def __init__(self, config: Config) -> None:
        self.__set_seed(config.dataset_config.common.seed)
        self.__parse_config(config)

        self._train_dataloader: Optional[DataLoader] = None
        self._valid_dataloader: Optional[DataLoader] = None
        self._test_dataloader: Optional[DataLoader] = None
        self._predict_dataloader: Optional[DataLoader] = None


    def create_dataloaders(
        self,
        train_transforms: Optional[Compose] = None,
        valid_transforms: Optional[Compose] = None,
        test_transforms: Optional[Compose] = None,
        predict_transforms: Optional[Compose] = None
    ) -> None:
        """
        Creates train, validation, test, and prediction dataloaders based on dataset configuration
        and provided transforms.

        Args:
            train_transforms (Optional[Compose]): Transformations for training data.
            valid_transforms (Optional[Compose]): Transformations for validation data.
            test_transforms (Optional[Compose]): Transformations for testing data.
            predict_transforms (Optional[Compose]): Transformations for prediction data.

        Raises:
            ValueError: If required transforms are missing.
            RuntimeError: If critical dataset config values are missing.
        """
        if self._dataset_setup.is_training and train_transforms is None:
            raise ValueError("Training mode requires 'train_transforms' to be provided.")
        elif not self._dataset_setup.is_training and test_transforms is None and predict_transforms is None:
            raise ValueError("In inference mode, at least one of 'test_transforms' or 'predict_transforms' must be provided.")

        if self._dataset_setup.is_training:
            # Training mode: handle either pre-split datasets or splitting on the fly
            if self._dataset_setup.training.is_split:
                # Validate presence of validation transforms if validation directory and size are set
                if (
                    self._dataset_setup.training.pre_split.valid_dir and
                    self._dataset_setup.training.valid_size and
                    valid_transforms is None
                ):
                    raise ValueError("Validation transforms must be provided when using pre-split validation data.")

                # Use explicitly split directories
                train_dir = self._dataset_setup.training.pre_split.train_dir
                valid_dir = self._dataset_setup.training.pre_split.valid_dir
                test_dir = self._dataset_setup.training.pre_split.test_dir

                train_offset = self._dataset_setup.training.train_offset
                valid_offset = self._dataset_setup.training.valid_offset
                test_offset = self._dataset_setup.training.test_offset

                shuffle = False
            else:
                # Same validation for split mode with full data directory
                if (
                    self._dataset_setup.training.split.all_data_dir and
                    self._dataset_setup.training.valid_size and
                    valid_transforms is None
                ):
                    raise ValueError("Validation transforms must be provided when splitting dataset.")

                # Automatically split dataset from one directory
                train_dir = valid_dir = test_dir = self._dataset_setup.training.split.all_data_dir

                number_of_images = len(os.listdir(os.path.join(train_dir, 'images')))
                if number_of_images == 0:
                    raise FileNotFoundError(f"No images found in '{train_dir}/images'")

                # Calculate train/valid sizes
                train_size = (
                    self._dataset_setup.training.train_size
                    if isinstance(self._dataset_setup.training.train_size, int)
                    else int(number_of_images * self._dataset_setup.training.train_size)
                )
                valid_size = (
                    self._dataset_setup.training.valid_size
                    if isinstance(self._dataset_setup.training.valid_size, int)
                    else int(number_of_images * self._dataset_setup.training.valid_size)
                )

                train_offset = self._dataset_setup.training.train_offset
                valid_offset = self._dataset_setup.training.valid_offset + train_size
                test_offset = self._dataset_setup.training.test_offset + train_size + valid_size

                shuffle = True

            # Train dataloader
            train_dataset = self.__get_dataset(
                images_dir=os.path.join(train_dir, 'images'),
                masks_dir=os.path.join(train_dir, 'masks'),
                transforms=train_transforms,  # type: ignore
                size=self._dataset_setup.training.train_size,
                offset=train_offset,
                shuffle=shuffle
            )
            self._train_dataloader = DataLoader(train_dataset, batch_size=self._dataset_setup.training.batch_size, shuffle=True)
            logger.info(f"Loaded training dataset with {len(train_dataset)} samples.")

            # Validation dataloader
            if valid_transforms is not None:
                if not valid_dir or not self._dataset_setup.training.valid_size:
                    raise RuntimeError("Validation directory or size is not properly configured.")
                valid_dataset = self.__get_dataset(
                    images_dir=os.path.join(valid_dir, 'images'),
                    masks_dir=os.path.join(valid_dir, 'masks'),
                    transforms=valid_transforms,
                    size=self._dataset_setup.training.valid_size,
                    offset=valid_offset,
                    shuffle=shuffle
                )
                self._valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded validation dataset with {len(valid_dataset)} samples.")

            # Test dataloader
            if test_transforms is not None:
                if not test_dir or not self._dataset_setup.training.test_size:
                    raise RuntimeError("Test directory or size is not properly configured.")
                test_dataset = self.__get_dataset(
                    images_dir=os.path.join(test_dir, 'images'),
                    masks_dir=os.path.join(test_dir, 'masks'),
                    transforms=test_transforms,
                    size=self._dataset_setup.training.test_size,
                    offset=test_offset,
                    shuffle=shuffle
                )
                self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded test dataset with {len(test_dataset)} samples.")

            # Prediction dataloader
            if predict_transforms is not None:
                if not test_dir or not self._dataset_setup.training.test_size:
                    raise RuntimeError("Prediction directory or size is not properly configured.")
                predict_dataset = self.__get_dataset(
                    images_dir=os.path.join(test_dir, 'images'),
                    masks_dir=None,
                    transforms=predict_transforms,
                    size=self._dataset_setup.training.test_size,
                    offset=test_offset,
                    shuffle=shuffle
                )
                self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.")

        else:
            # Inference mode (no training)
            test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images')
            test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks')

            if test_transforms is not None:
                test_dataset = self.__get_dataset(
                    images_dir=test_images,
                    masks_dir=test_masks,
                    transforms=test_transforms,
                    size=self._dataset_setup.testing.test_size,
                    offset=self._dataset_setup.testing.test_offset,
                    shuffle=self._dataset_setup.testing.shuffle
                )
                self._test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded test dataset with {len(test_dataset)} samples.")

            if predict_transforms is not None:
                predict_dataset = self.__get_dataset(
                    images_dir=test_images,
                    masks_dir=None,
                    transforms=predict_transforms,
                    size=self._dataset_setup.testing.test_size,
                    offset=self._dataset_setup.testing.test_offset,
                    shuffle=self._dataset_setup.testing.shuffle
                )
                self._predict_dataloader = DataLoader(predict_dataset, batch_size=1, shuffle=False)
                logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.")
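
    # A minimal usage sketch of the method above, assuming MONAI dict transforms keyed on
    # "image"/"mask" (the keys produced by __get_dataset below). Construction of `config`
    # is omitted because the Config API is project-specific; the transform names are real
    # monai.transforms classes.
    #
    #   from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd
    #
    #   train_tf = Compose([
    #       LoadImaged(keys=["image", "mask"]),
    #       EnsureChannelFirstd(keys=["image", "mask"]),
    #       ScaleIntensityd(keys=["image"]),
    #   ])
    #   segmentator = CellSegmentator(config)
    #   segmentator.create_dataloaders(train_transforms=train_tf, valid_transforms=train_tf)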


    def train(self) -> None:
        pass


    def evaluate(self) -> None:
        pass


    def predict(self) -> None:
        pass


    def __parse_config(self, config: Config) -> None:
        """
        Parses the given configuration object to initialize model, criterion,
        optimizer, scheduler, and dataset setup.

        Args:
            config (Config): Configuration object with model, optimizer,
                scheduler, criterion, and dataset setup information.
        """
        model = config.model
        criterion = config.criterion
        optimizer = config.optimizer
        scheduler = config.scheduler

        # Log the full configuration dictionary
        full_config_dict = {
            "model": model.dump(),
            "criterion": criterion.dump() if criterion else None,
            "optimizer": optimizer.dump() if optimizer else None,
            "scheduler": scheduler.dump() if scheduler else None,
            "dataset_config": config.dataset_config.model_dump()
        }
        logger.info("========== Parsed Configuration ==========")
        logger.info(pformat(full_config_dict, width=120))
        logger.info("==========================================")

        # Initialize model using the model registry
        self._model = ModelRegistry.get_model_class(model.name)(model.params)
        logger.info(f"Initialized model: {model.name}")

        # Initialize loss criterion if specified
        self._criterion = (
            CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params)
            if criterion is not None
            else None
        )
        if self._criterion is not None and criterion is not None:
            logger.info(f"Initialized criterion: {criterion.name}")
        else:
            logger.info("Criterion: not specified")

        # Initialize optimizer if specified
        self._optimizer = (
            OptimizerRegistry.get_optimizer_class(optimizer.name)(
                model_params=self._model.parameters(),
                optim_params=optimizer.params
            )
            if optimizer is not None
            else None
        )
        if self._optimizer is not None and optimizer is not None:
            logger.info(f"Initialized optimizer: {optimizer.name}")
        else:
            logger.info("Optimizer: not specified")

        # Initialize scheduler only if both scheduler and optimizer are defined
        self._scheduler = (
            SchedulerRegistry.get_scheduler_class(scheduler.name)(
                optimizer=self._optimizer.optim,
                params=scheduler.params
            )
            if scheduler is not None and self._optimizer is not None and self._optimizer.optim is not None
            else None
        )
        if self._scheduler is not None and scheduler is not None:
            logger.info(f"Initialized scheduler: {scheduler.name}")
        else:
            logger.info("Scheduler: not specified")

        # Save dataset config
        self._dataset_setup = config.dataset_config
        logger.info("Dataset setup loaded")
        common = config.dataset_config.common
        logger.info(f"Seed: {common.seed}")
        logger.info(f"Device: {common.device}")
        logger.info(f"Predictions output dir: {common.predictions_dir}")

        if config.dataset_config.is_training:
            training = config.dataset_config.training
            logger.info("Mode: Training")
            logger.info(f"  Batch size: {training.batch_size}")
            logger.info(f"  Epochs: {training.num_epochs}")
            logger.info(f"  Validation frequency: {training.val_freq}")
            logger.info(f"  Use AMP: {'yes' if training.use_amp else 'no'}")
            logger.info(f"  Pretrained weights: {training.pretrained_weights}")

            if training.is_split:
                logger.info("  Using pre-split directories:")
                logger.info(f"    Train dir: {training.pre_split.train_dir}")
                logger.info(f"    Valid dir: {training.pre_split.valid_dir}")
                logger.info(f"    Test dir: {training.pre_split.test_dir}")
            else:
                logger.info("  Using unified dataset with splits:")
                logger.info(f"    All data dir: {training.split.all_data_dir}")
                logger.info(f"    Shuffle: {'yes' if training.split.shuffle else 'no'}")

            logger.info("  Dataset split:")
            logger.info(f"    Train size: {training.train_size}, offset: {training.train_offset}")
            logger.info(f"    Valid size: {training.valid_size}, offset: {training.valid_offset}")
            logger.info(f"    Test size: {training.test_size}, offset: {training.test_offset}")

        else:
            testing = config.dataset_config.testing
            logger.info("Mode: Inference")
            logger.info(f"  Test dir: {testing.test_dir}")
            logger.info(f"  Test size: {testing.test_size} (offset: {testing.test_offset})")
            logger.info(f"  Shuffle: {'yes' if testing.shuffle else 'no'}")
            logger.info(f"  Use ensemble: {'yes' if testing.use_ensemble else 'no'}")
            logger.info(f"  Pretrained weights:")
            logger.info(f"    Single model: {testing.pretrained_weights}")
            logger.info(f"    Ensemble model 1: {testing.ensemble_pretrained_weights1}")
            logger.info(f"    Ensemble model 2: {testing.ensemble_pretrained_weights2}")
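
    # The ModelRegistry / CriterionRegistry / OptimizerRegistry / SchedulerRegistry used
    # above are resolved by name. A minimal sketch of that pattern (the real registries in
    # core.models and friends may differ) could look like:
    #
    #   class ModelRegistry:
    #       _models: dict = {}
    #
    #       @classmethod
    #       def register(cls, name):
    #           def wrap(model_cls):
    #               cls._models[name] = model_cls
    #               return model_cls
    #           return wrap
    #
    #       @classmethod
    #       def get_model_class(cls, name):
    #           return cls._models[name]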


    def __set_seed(self, seed: Optional[int]) -> None:
        """
        Sets the random seed for reproducibility across Python, NumPy, and PyTorch.

        Args:
            seed (Optional[int]): Seed value. If None, no seeding is performed.
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            logger.info(f"Random seed set to {seed}")
        else:
            logger.info("Seed not set (None provided)")
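
    # For CUDA runs, `torch.cuda.manual_seed_all(seed)` could additionally be called in the
    # seeded branch above to seed every GPU generator explicitly; recent PyTorch versions
    # already derive CUDA seeds from torch.manual_seed, so this is an optional hardening step.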


    def __get_dataset(
        self,
        images_dir: str,
        masks_dir: Optional[str],
        transforms: Compose,
        size: Union[int, float],
        offset: int,
        shuffle: bool
    ) -> Dataset:
        """
        Loads and returns a dataset object from image and optional mask directories.

        Args:
            images_dir (str): Path to directory or glob pattern for input images.
            masks_dir (Optional[str]): Path to directory or glob pattern for masks.
            transforms (Compose): Transformations to apply to each image or pair.
            size (Union[int, float]): Either an integer or a fraction of the dataset.
            offset (int): Number of images to skip from the start.
            shuffle (bool): Whether to shuffle the dataset before slicing.

        Returns:
            Dataset: A dataset containing image and optional mask paths.

        Raises:
            FileNotFoundError: If no images are found.
            ValueError: If masks are provided but do not match image count.
            ValueError: If dataset is too small for requested size or offset.
        """
        # Collect sorted list of image paths
        images = sorted(glob.glob(images_dir))
        if not images:
            raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'")

        if masks_dir is not None:
            # Collect and validate sorted list of mask paths
            masks = sorted(glob.glob(masks_dir))
            if len(images) != len(masks):
                raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})")

        # Convert float size (fraction) to absolute count
        size = size if isinstance(size, int) else int(size * len(images))

        if size <= 0:
            raise ValueError(f"Size must be positive, got: {size}")

        if len(images) < size:
            raise ValueError(f"Not enough images ({len(images)}) for requested size ({size})")

        if len(images) < size + offset:
            raise ValueError(f"Offset ({offset}) + size ({size}) exceeds dataset length ({len(images)})")

        # Shuffle image-mask pairs if requested
        if shuffle:
            if masks_dir is not None:
                combined = list(zip(images, masks))  # type: ignore
                random.shuffle(combined)
                images, masks = zip(*combined)
            else:
                random.shuffle(images)

        # Apply offset and limit by size
        images = images[offset: offset + size]
        if masks_dir is not None:
            masks = masks[offset: offset + size]  # type: ignore

        # Prepare data structure for Dataset class
        if masks_dir is not None:
            data = [
                {"image": image, "mask": mask}
                for image, mask in zip(images, masks)  # type: ignore
            ]
        else:
            data = [{"image": image} for image in images]

        return Dataset(data, transforms)
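
The data dictionaries built above are exactly what MONAI's dict-based transforms and Dataset consume. A standalone sketch of the same layout, assuming a hypothetical directory tree and '*.png' file pattern:

    import glob
    from monai.data.dataset import Dataset
    from monai.transforms import Compose, LoadImaged

    images = sorted(glob.glob("data/train/images/*.png"))   # hypothetical paths
    masks = sorted(glob.glob("data/train/masks/*.png"))
    data = [{"image": img, "mask": msk} for img, msk in zip(images, masks)]
    dataset = Dataset(data, Compose([LoadImaged(keys=["image", "mask"])]))
    print(len(dataset))  # number of image/mask pairs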