diff --git a/.gitignore b/.gitignore
index 9d851fe..adb7fff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,8 @@
 __pycache__/
 **/__pycache__/
 .vscode/
-*.json
\ No newline at end of file
+*.json
+
+outputs/
+weights/
+wandb/
\ No newline at end of file
diff --git a/config/__init__.py b/config/__init__.py
index d890918..8f58cfc 100644
--- a/config/__init__.py
+++ b/config/__init__.py
@@ -1,3 +1,5 @@
-from .config import Config
+from .config import Config, ComponentConfig
+from .wandb_config import WandbConfig
+from .dataset_config import DatasetConfig
 
-__all__ = ["Config"]
\ No newline at end of file
+__all__ = ["Config", "WandbConfig", "DatasetConfig", "ComponentConfig"]
\ No newline at end of file
diff --git a/config/config.py b/config/config.py
index acd6ee3..df017c6 100644
--- a/config/config.py
+++ b/config/config.py
@@ -2,6 +2,7 @@ import json
 from typing import Any, Dict, Optional
 from pydantic import BaseModel
 
+from .wandb_config import WandbConfig
 from .dataset_config import DatasetConfig
 
 
@@ -33,6 +34,7 @@ class ComponentConfig(BaseModel):
 class Config(BaseModel):
     model: ComponentConfig
     dataset_config: DatasetConfig
+    wandb_config: WandbConfig
     criterion: Optional[ComponentConfig] = None
     optimizer: Optional[ComponentConfig] = None
     scheduler: Optional[ComponentConfig] = None
@@ -57,6 +59,7 @@ class Config(BaseModel):
             data["optimizer"] = self.optimizer.dump()
         if self.scheduler is not None:
             data["scheduler"] = self.scheduler.dump()
+        data["wandb"] = self.wandb_config.model_dump()
 
         return data
 
@@ -88,8 +91,9 @@ class Config(BaseModel):
         with open(file_path, "r", encoding="utf-8") as f:
             data = json.load(f)
 
-        # Parse dataset_config using its Pydantic model.
+        # Parse dataset_config and wandb_config using their Pydantic models.
         dataset_config = DatasetConfig(**data.get("dataset_config", {}))
        wandb_config = WandbConfig(**data.get("wandb", {}))
 
         # Helper function to parse registry fields.
         def parse_field(component_data: Dict[str, Any], registry_getter) -> Optional[ComponentConfig]:
@@ -119,5 +123,6 @@ class Config(BaseModel):
             dataset_config=dataset_config,
             criterion=parsed_criterion,
             optimizer=parsed_optimizer,
-            scheduler=parsed_scheduler
+            scheduler=parsed_scheduler,
+            wandb_config=wandb_config
         )
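Note: a minimal sketch of how the new top-level "wandb" block flows through `Config.load_json` after this change (the JSON path below is hypothetical):

    from config import Config

    # `load_json` now also reads the optional top-level "wandb" block into a
    # WandbConfig; when the block is absent it defaults to use_wandb=False.
    config = Config.load_json("config/templates/train/example.json")  # hypothetical file
    print(config.wandb_config.use_wandb)
    # `asdict()` round-trips the block back out under the "wandb" key.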
diff --git a/config/dataset_config.py b/config/dataset_config.py
index 4ea53d2..9a43858 100644
--- a/config/dataset_config.py
+++ b/config/dataset_config.py
@@ -7,10 +7,11 @@ class DatasetCommonConfig(BaseModel):
     """
     Common configuration fields shared by both training and testing.
     """
-    seed: Optional[int] = 0  # Seed for splitting if data is not pre-split (and all random operations)
-    device: str = "cuda0"  # Device used for training/testing (e.g., 'cpu' or 'cuda')
+    seed: Optional[int] = 0  # Seed for splitting if data is not pre-split (and all random operations)
+    device: str = "cuda:0"  # Device used for training/testing (e.g., 'cpu' or 'cuda')
     use_tta: bool = False  # Flag to use Test-Time Augmentation (TTA)
     use_amp: bool = False  # Flag to use Automatic Mixed Precision (AMP)
+    masks_subdir: str = ""  # Subdirectory where the required masks are located, e.g. 'masks/cars'
     predictions_dir: str = "."  # Directory to save predictions
 
     @model_validator(mode="after")
@@ -62,8 +63,8 @@ class DatasetTrainingConfig(BaseModel):
     split: TrainingSplitInfo = TrainingSplitInfo()
 
     train_size: Union[int, float] = 0.7  # Training data size (int for static, float in (0,1] for dynamic)
-    valid_size: Union[int, float] = 0.2  # Validation data size (int for static, float in (0,1] for dynamic)
-    test_size: Union[int, float] = 0.1  # Testing data size (int for static, float in (0,1] for dynamic)
+    valid_size: Union[int, float] = 0.1  # Validation data size (int for static, float in (0,1] for dynamic)
+    test_size: Union[int, float] = 0.2  # Testing data size (int for static, float in (0,1] for dynamic)
     train_offset: int = 0  # Offset for training data
     valid_offset: int = 0  # Offset for validation data
     test_offset: int = 0  # Offset for testing data
@@ -99,7 +100,7 @@ class DatasetTrainingConfig(BaseModel):
         - If is_split is False, validates split (all_data_dir must be non-empty and exist).
         """
         if any(isinstance(s, float) for s in (self.train_size, self.valid_size, self.test_size)):
-            if (self.train_size + self.valid_size + self.test_size) > 1:
+            if (self.train_size + self.valid_size + self.test_size) > 1 and not self.is_split:
                 raise ValueError("The total sample size with dynamically defined sizes must be <= 1")
 
         if not self.is_split:
@@ -214,34 +215,6 @@ class DatasetTestingConfig(BaseModel):
         return self
 
 
-class WandbConfig(BaseModel):
-    """
-    Configuration for Weights & Biases logging.
-    """
-    use_wandb: bool = False  # Whether to enable WandB logging
-    project: Optional[str] = None  # WandB project name
-    entity: Optional[str] = None  # WandB entity (user or team)
-    name: Optional[str] = None  # Name of the run
-    tags: Optional[list[str]] = None  # List of tags for the run
-    notes: Optional[str] = None  # Notes or description for the run
-    save_code: bool = True  # Whether to save the code to WandB
-
-    @model_validator(mode="after")
-    def validate_wandb(cls) -> "WandbConfig":
-        if cls.use_wandb:
-            if not cls.project:
-                raise ValueError("When use_wandb=True, 'project' must be provided")
-            if not cls.entity:
-                raise ValueError("When use_wandb=True, 'entity' must be provided")
-        return cls
-
-    def asdict(self) -> Dict[str, Any]:
-        """
-        Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
-        """
-        return self.model_dump(exclude_none=True, exclude={"use_wandb"})
-
-
 class DatasetConfig(BaseModel):
     """
     Main dataset configuration that groups fields into nested models for a structured and readable JSON.
@@ -250,7 +223,6 @@ class DatasetConfig(BaseModel):
     common: DatasetCommonConfig = DatasetCommonConfig()
     training: DatasetTrainingConfig = DatasetTrainingConfig()
     testing: DatasetTestingConfig = DatasetTestingConfig()
-    wandb: WandbConfig = WandbConfig()
 
     @model_validator(mode="after")
     def validate_config(self) -> "DatasetConfig":
@@ -265,15 +237,11 @@ class DatasetConfig(BaseModel):
             if (self.training.is_split and self.training.pre_split.test_dir) or (not self.training.is_split):
                 if self.training.test_size > 0 and not self.common.predictions_dir:
                     raise ValueError("predictions_dir must be provided when test_size is non-zero")
-                if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
-                    raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
         else:
             if self.testing is None:
                 raise ValueError("Testing configuration must be provided when is_training is False")
             if self.testing.test_size > 0 and not self.common.predictions_dir:
                 raise ValueError("predictions_dir must be provided when test_size is non-zero")
-            if self.common.predictions_dir and not os.path.exists(self.common.predictions_dir):
-                raise ValueError(f"Path for predictions_dir does not exist: {self.common.predictions_dir}")
         return self
 
     def model_dump(self, **kwargs) -> Dict[str, Any]:
@@ -286,12 +254,10 @@ class DatasetConfig(BaseModel):
                 "is_training": self.is_training,
                 "common": self.common.model_dump(),
                 "training": self.training.model_dump() if self.training else {},
-                "wandb": self.wandb.model_dump()
             }
         else:
             return {
                 "is_training": self.is_training,
                 "common": self.common.model_dump(),
                 "testing": self.testing.model_dump() if self.testing else {},
-                "wandb": self.wandb.model_dump()
             }
diff --git a/config/wandb_config.py b/config/wandb_config.py
new file mode 100644
index 0000000..82d10c5
--- /dev/null
+++ b/config/wandb_config.py
@@ -0,0 +1,29 @@
+from pydantic import BaseModel, model_validator
+from typing import Any, Dict, Optional
+
+
+class WandbConfig(BaseModel):
+    """
+    Configuration for Weights & Biases logging.
+    """
+    use_wandb: bool = False  # Whether to enable WandB logging
+    project: Optional[str] = None  # WandB project name
+    group: Optional[str] = None  # WandB group name
+    entity: Optional[str] = None  # WandB entity (user or team)
+    name: Optional[str] = None  # Name of the run
+    tags: Optional[list[str]] = None  # List of tags for the run
+    notes: Optional[str] = None  # Notes or description for the run
+    save_code: bool = True  # Whether to save the code to WandB
+
+    @model_validator(mode="after")
+    def validate_wandb(self) -> "WandbConfig":
+        if self.use_wandb:
+            if not self.project:
+                raise ValueError("When use_wandb=True, 'project' must be provided")
+        return self
+
+    def asdict(self) -> Dict[str, Any]:
+        """
+        Return a dict of all W&B parameters, excluding 'use_wandb' and any None values.
+        """
+        return self.model_dump(exclude_none=True, exclude={"use_wandb"})
\ No newline at end of file
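Note: the deleted copy validated `cls` inside a mode="after" validator (a bug, since after-validators receive the model instance) and required `entity`; the relocated copy validates `self` and only requires `project`. A minimal usage sketch, assuming W&B credentials are already configured:

    import wandb
    from config import WandbConfig

    cfg = WandbConfig(use_wandb=True, project="cell-seg", name="baseline")
    if cfg.use_wandb:
        # asdict() drops 'use_wandb' and all None fields, so it can be
        # splatted directly into wandb.init.
        wandb.init(**cfg.asdict())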
diff --git a/core/data/transforms/__init__.py b/core/data/transforms/__init__.py
index 9fa9255..74335b8 100644
--- a/core/data/transforms/__init__.py
+++ b/core/data/transforms/__init__.py
@@ -169,16 +169,21 @@ def get_predict_transforms():
     """
     pred_transforms = Compose(
         [
-            # Load the image data in (H, W, C) format.
-            CustomLoadImage(image_only=False),
+            # Load image data in (H, W, C) format (allow missing keys).
+            CustomLoadImaged(keys=["image"], allow_missing_keys=True, image_only=False),
             # Normalize the (H, W, C) image using the specified percentiles.
-            CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
-            # Ensure the image is in channel-first format.
-            EnsureChannelFirst(channel_dim=-1),  # image shape: (C, H, W)
+            CustomNormalizeImaged(
+                keys=["image"],
+                allow_missing_keys=True,
+                channel_wise=False,
+                percentiles=[0.0, 99.5],
+            ),
+            # Ensure image is in channel-first format.
+            EnsureChannelFirstd(keys=["image"], allow_missing_keys=True, channel_dim=-1),
             # Scale image intensities.
-            ScaleIntensity(),
-            # Convert the image to the required tensor type.
-            EnsureType(data_type="tensor"),
+            ScaleIntensityd(keys=["image"], allow_missing_keys=True),
+            # Ensure that the data types are correct.
+            EnsureTyped(keys=["image"], allow_missing_keys=True),
         ]
     )
     return pred_transforms
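Note: the array-style transforms were swapped for their dictionary-based `*d` counterparts, so prediction consumes the same keyed samples as training. A minimal sketch, assuming `CustomLoadImaged` loads from a file path like MONAI's LoadImaged and `img.tiff` is a hypothetical input:

    from core.data import get_predict_transforms

    # MONAI dict transforms operate on keyed samples, so unlabeled prediction
    # data flows through the same Compose pipeline as labeled training data.
    transforms = get_predict_transforms()
    sample = transforms({"image": "img.tiff"})  # -> {"image": tensor of shape (C, H, W)}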
data["image"][c] = img_orig + img_changed diff --git a/core/losses/base.py b/core/losses/base.py index 65ddbad..0b014e2 100644 --- a/core/losses/base.py +++ b/core/losses/base.py @@ -6,7 +6,7 @@ from typing import Dict, Any, Optional from monai.metrics.cumulative_average import CumulativeAverage -class BaseLoss(abc.ABC): +class BaseLoss(nn.Module, abc.ABC): """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction.""" def __init__(self, params: Optional[BaseModel] = None): diff --git a/core/losses/bce.py b/core/losses/bce.py index 32b2a96..f7c9484 100644 --- a/core/losses/bce.py +++ b/core/losses/bce.py @@ -28,7 +28,7 @@ class BCELossParams(BaseModel): loss_kwargs = self.model_dump() if not self.with_logits: loss_kwargs.pop("pos_weight", None) # Remove pos_weight if using BCELoss - loss_kwargs.pop("with_logits", None) + loss_kwargs.pop("with_logits", None) weight = loss_kwargs.get("weight") pos_weight = loss_kwargs.get("pos_weight") diff --git a/core/segmentator.py b/core/segmentator.py index 142484a..7b07172 100644 --- a/core/segmentator.py +++ b/core/segmentator.py @@ -12,6 +12,7 @@ import fastremap import fill_voids from skimage import morphology from skimage.segmentation import find_boundaries +from scipy.special import expit from scipy.ndimage import mean, find_objects from monai.data.dataset import Dataset @@ -53,10 +54,11 @@ logger = get_logger() class CellSegmentator: def __init__(self, config: Config) -> None: + self._device: torch.device = torch.device(config.dataset_config.common.device or "cpu") + self.__set_seed(config.dataset_config.common.seed) self.__parse_config(config) - self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu") self._scaler = ( torch.amp.GradScaler(self._device.type) # type: ignore if self._dataset_setup.is_training and self._dataset_setup.common.use_amp @@ -153,7 +155,7 @@ class CellSegmentator: # Train dataloader train_dataset = self.__get_dataset( images_dir=os.path.join(train_dir, 'images'), - masks_dir=os.path.join(train_dir, 'masks'), + masks_dir=os.path.join(train_dir, 'masks', self._dataset_setup.common.masks_subdir), transforms=train_transforms, # type: ignore size=self._dataset_setup.training.train_size, offset=train_offset, @@ -168,7 +170,7 @@ class CellSegmentator: raise RuntimeError("Validation directory or size is not properly configured.") valid_dataset = self.__get_dataset( images_dir=os.path.join(valid_dir, 'images'), - masks_dir=os.path.join(valid_dir, 'masks'), + masks_dir=os.path.join(valid_dir, 'masks', self._dataset_setup.common.masks_subdir), transforms=valid_transforms, size=self._dataset_setup.training.valid_size, offset=valid_offset, @@ -183,7 +185,7 @@ class CellSegmentator: raise RuntimeError("Test directory or size is not properly configured.") test_dataset = self.__get_dataset( images_dir=os.path.join(test_dir, 'images'), - masks_dir=os.path.join(test_dir, 'masks'), + masks_dir=os.path.join(test_dir, 'masks', self._dataset_setup.common.masks_subdir), transforms=test_transforms, size=self._dataset_setup.training.test_size, offset=test_offset, @@ -210,7 +212,7 @@ class CellSegmentator: else: # Inference mode (no training) test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images') - test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks') + test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks', self._dataset_setup.common.masks_subdir) if test_transforms is not None: test_dataset = self.__get_dataset( @@ 
diff --git a/core/losses/base.py b/core/losses/base.py
index 65ddbad..0b014e2 100644
--- a/core/losses/base.py
+++ b/core/losses/base.py
@@ -6,7 +6,7 @@ from typing import Dict, Any, Optional
 from monai.metrics.cumulative_average import CumulativeAverage
 
 
-class BaseLoss(abc.ABC):
+class BaseLoss(nn.Module, abc.ABC):
     """Custom loss function combining BCEWithLogitsLoss and MSE losses for cell recognition and distinction."""
 
     def __init__(self, params: Optional[BaseModel] = None):
diff --git a/core/losses/bce.py b/core/losses/bce.py
index 32b2a96..f7c9484 100644
--- a/core/losses/bce.py
+++ b/core/losses/bce.py
@@ -28,7 +28,7 @@ class BCELossParams(BaseModel):
         loss_kwargs = self.model_dump()
         if not self.with_logits:
             loss_kwargs.pop("pos_weight", None)  # Remove pos_weight if using BCELoss
-            loss_kwargs.pop("with_logits", None)
+        loss_kwargs.pop("with_logits", None)
 
         weight = loss_kwargs.get("weight")
         pos_weight = loss_kwargs.get("pos_weight")
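Note: deriving BaseLoss from nn.Module (in addition to abc.ABC) makes loss objects proper torch modules, so registered buffers such as class weights follow `.to(device)` and the loss is callable like a layer. A minimal sketch of the pattern, with hypothetical names:

    import abc
    import torch
    import torch.nn as nn

    class ExampleLoss(nn.Module, abc.ABC):
        def __init__(self) -> None:
            super().__init__()
            # Registered buffers move together with the module across devices.
            self.register_buffer("pos_weight", torch.tensor(2.0))

        def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            return nn.functional.binary_cross_entropy_with_logits(
                logits, target, pos_weight=self.pos_weight
            )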
""" - # Create a checkpoint dictionary containing the model’s state_dict - checkpoint = { - 'state_dict': self._model.state_dict() - } # Write the checkpoint to disk - torch.save(checkpoint, checkpoint_path) + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + torch.save(self._model.state_dict(), checkpoint_path) def __parse_config(self, config: Config) -> None: @@ -492,15 +516,23 @@ class CellSegmentator: if scheduler: logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2)) logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2)) + logger.info("Wandb Config:\n%s", pformat(config.wandb_config.model_dump(), indent=2)) logger.info("==========================================") # Initialize model using the model registry self._model = ModelRegistry.get_model_class(model.name)(model.params) # Loads model weights from a specified checkpoint - if config.dataset_config.is_training: - if config.dataset_config.training.pretrained_weights: - self.load_from_checkpoint(config.dataset_config.training.pretrained_weights) + pretrained_weights = ( + config.dataset_config.training.pretrained_weights + if config.dataset_config.is_training + else config.dataset_config.testing.pretrained_weights + ) + if pretrained_weights: + self.load_from_checkpoint(pretrained_weights) + logger.info(f"Loaded pre-trained weights from: {pretrained_weights}") + + self._model = self._model.to(self._device) # Initialize loss criterion if specified self._criterion = ( @@ -555,6 +587,7 @@ class CellSegmentator: logger.info(f"├─ Seed: {common.seed}") logger.info(f"├─ Device: {common.device}") logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}") + logger.info(f"├─ Masks subdirectory: {common.masks_subdir}") logger.info(f"└─ Predictions output dir: {common.predictions_dir}") if config.dataset_config.is_training: @@ -592,18 +625,21 @@ class CellSegmentator: logger.info(f" ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}") logger.info(f" └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}") - wandb_cfg = config.dataset_config.wandb - if wandb_cfg.use_wandb: + self._wandb_config = config.wandb_config + if self._wandb_config.use_wandb: logger.info("[W&B]") - logger.info(f"├─ Project: {wandb_cfg.project}") - logger.info(f"├─ Entity: {wandb_cfg.entity}") - if wandb_cfg.name: - logger.info(f"├─ Run name: {wandb_cfg.name}") - if wandb_cfg.tags: - logger.info(f"├─ Tags: {', '.join(wandb_cfg.tags)}") - if wandb_cfg.notes: - logger.info(f"├─ Notes: {wandb_cfg.notes}") - logger.info(f"└─ Save code: {'yes' if wandb_cfg.save_code else 'no'}") + logger.info(f"├─ Project: {self._wandb_config.project}") + if self._wandb_config.group: + logger.info(f"├─ Group: {self._wandb_config.group}") + if self._wandb_config.entity: + logger.info(f"├─ Entity: {self._wandb_config.entity}") + if self._wandb_config.name: + logger.info(f"├─ Run name: {self._wandb_config.name}") + if self._wandb_config.tags: + logger.info(f"├─ Tags: {', '.join(self._wandb_config.tags)}") + if self._wandb_config.notes: + logger.info(f"├─ Notes: {self._wandb_config.notes}") + logger.info(f"└─ Save code: {'yes' if self._wandb_config.save_code else 'no'}") else: logger.info("[W&B] Logging disabled") @@ -657,13 +693,19 @@ class CellSegmentator: ValueError: If dataset is too small for requested size or offset. 
""" # Collect sorted list of image paths - images = sorted(glob.glob(images_dir)) + images = sorted( + glob.glob(os.path.join(images_dir, '*.tif')) + + glob.glob(os.path.join(images_dir, '*.tiff')) + ) if not images: raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'") if masks_dir is not None: # Collect and validate sorted list of mask paths - masks = sorted(glob.glob(masks_dir)) + masks = sorted( + glob.glob(os.path.join(masks_dir, '*.tif')) + + glob.glob(os.path.join(masks_dir, '*.tiff')) + ) if len(images) != len(masks): raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})") @@ -720,7 +762,7 @@ class CellSegmentator: tablefmt="fancy_grid" ) print(table, "\n") - if self._dataset_setup.wandb.use_wandb: + if self._wandb_config.use_wandb: wandb.log(results, step=step) @@ -765,8 +807,8 @@ class CellSegmentator: # Iterate over batches batch_counter = 0 for batch in tqdm(loader, desc=desc): - inputs = batch["img"].to(self._device) - targets = batch["label"].to(self._device) + inputs = batch["image"].to(self._device) + targets = batch["mask"].to(self._device) # Zero gradients for training if self._optimizer is not None: @@ -787,7 +829,10 @@ class CellSegmentator: flow_targets = self.__compute_flows_from_masks(targets) # Compute loss for this batch - batch_loss = self._criterion(raw_output, flow_targets) # type: ignore + batch_loss = self._criterion( + raw_output, + torch.from_numpy(flow_targets).to(device=raw_output.device) + ) # Post-process and compute F1 during validation and testing if mode in ("valid", "test"): @@ -842,6 +887,9 @@ class CellSegmentator: epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric( # type: ignore tp_array, fp_array, fn_array, reduction="micro" ) + epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric( # type: ignore + tp_array, fp_array, fn_array, reduction="imagewise" + ) epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric( # type: ignore tp_array, fp_array, fn_array, reduction="macro" ) @@ -914,7 +962,7 @@ class CellSegmentator: instance_masks[idx] = self.__segment_instances( probability_map=probabilities[idx], flow=gradflow[idx], - prob_threshold=0.0, + prob_threshold=0.5, flow_threshold=0.4, min_object_size=15 ) @@ -1159,7 +1207,7 @@ class CellSegmentator: Returns: np.ndarray: Sigmoid of the input. 
""" - return 1 / (1 + np.exp(-z)) + return expit(z) def __save_prediction_masks( @@ -1191,7 +1239,7 @@ class CellSegmentator: # Convert tensors to numpy def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: if isinstance(x, torch.Tensor): - return x.detach().cpu().numpy() + return x.cpu().numpy() return x image_array = to_numpy(image_obj) if image_obj is not None else None @@ -1201,11 +1249,11 @@ class CellSegmentator: # Handle batch dimension: (B, C, H, W) if pred_array.ndim == 4: for idx in range(pred_array.shape[0]): - batch_sample = dict(sample) + batch_sample: Dict[str, Any] = {} if image_array is not None and image_array.ndim == 4: batch_sample["image"] = image_array[idx] - if isinstance(image_meta, list): - batch_sample["image_meta_dict"] = image_meta[idx] + if isinstance(image_meta, dict) and "filename_or_obj" in image_meta: + batch_sample["image_meta_dict"] = image_meta["filename_or_obj"][idx] if mask_array is not None and mask_array.ndim == 4: batch_sample["mask"] = mask_array[idx] self.__save_prediction_masks( @@ -1216,8 +1264,8 @@ class CellSegmentator: return # Determine base filename - if image_meta and "filename_or_obj" in image_meta: - base_name = os.path.splitext(os.path.basename(image_meta["filename_or_obj"]))[0] + if isinstance(image_meta, (str, os.PathLike)): + base_name = os.path.splitext(os.path.basename(image_meta))[0] else: # Use provided start_index when metadata missing base_name = f"prediction_{start_index:04d}" @@ -1228,8 +1276,8 @@ class CellSegmentator: channel_mask = pred_array[channel_idx] # File names - mask_filename = f"{base_name}_ch{channel_idx:02d}.tif" - plot_filename = f"{base_name}_ch{channel_idx:02d}.png" + mask_filename = f"{base_name}_ch{channel_idx:01d}.tif" + plot_filename = f"{base_name}_ch{channel_idx:01d}.png" mask_path = os.path.join(masks_dir, mask_filename) plot_path = os.path.join(plots_dir, plot_filename) @@ -1402,78 +1450,81 @@ class CellSegmentator: flows = np.zeros((2*channels, height, width), np.float32) for channel in range(channels): - padded_height, padded_width = height + 2, width + 2 - - # Pad the mask with a 1-pixel border - masks_padded = torch.from_numpy(mask.astype(np.int64)).to(self._device) - masks_padded = F.pad(masks_padded, (1, 1, 1, 1)) - - # Get coordinates of all non-zero pixels in the padded mask - y, x = torch.nonzero(masks_padded, as_tuple=True) - y = y.int(); x = x.int() # ensure integer type - - # Generate 8-connected neighbors (including center) via broadcasted offsets - offsets = torch.tensor([ - [ 0, 0], # center - [-1, 0], # up - [ 1, 0], # down - [ 0, -1], # left - [ 0, 1], # right - [-1, -1], # up-left - [-1, 1], # up-right - [ 1, -1], # down-left - [ 1, 1], # down-right - ], dtype=torch.int32, device=self._device) # (9, 2) - - # coords: (N, 2) - coords = torch.stack((y, x), dim=1) - - # neighbors: (9, N, 2) - neighbors = offsets[:, None, :] + coords[None, :, :] - - # transpose into (2, 9, N) for the GPU kernel - neighbors = neighbors.permute(2, 0, 1) # first dim is y/x, second is neighbor index - - # Build connectivity mask: True where neighbor label == center label - center_labels = masks_padded[y, x][None, :] # (1, N) - neighbor_labels = masks_padded[neighbors[0], neighbors[1]] # (9, N) - is_neighbor = neighbor_labels == center_labels # (9, N) - - # Compute object slices and pack into array for get_centers - slices = find_objects(mask) - slices_arr = np.array([ - [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop] - for i, sl in enumerate(slices) if sl is not None - ], dtype=int) + 
@@ -1402,78 +1450,81 @@ class CellSegmentator:
         flows = np.zeros((2*channels, height, width), np.float32)
 
         for channel in range(channels):
-            padded_height, padded_width = height + 2, width + 2
-
-            # Pad the mask with a 1-pixel border
-            masks_padded = torch.from_numpy(mask.astype(np.int64)).to(self._device)
-            masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
-
-            # Get coordinates of all non-zero pixels in the padded mask
-            y, x = torch.nonzero(masks_padded, as_tuple=True)
-            y = y.int(); x = x.int()  # ensure integer type
-
-            # Generate 8-connected neighbors (including center) via broadcasted offsets
-            offsets = torch.tensor([
-                [ 0,  0],  # center
-                [-1,  0],  # up
-                [ 1,  0],  # down
-                [ 0, -1],  # left
-                [ 0,  1],  # right
-                [-1, -1],  # up-left
-                [-1,  1],  # up-right
-                [ 1, -1],  # down-left
-                [ 1,  1],  # down-right
-            ], dtype=torch.int32, device=self._device)  # (9, 2)
-
-            # coords: (N, 2)
-            coords = torch.stack((y, x), dim=1)
-
-            # neighbors: (9, N, 2)
-            neighbors = offsets[:, None, :] + coords[None, :, :]
-
-            # transpose into (2, 9, N) for the GPU kernel
-            neighbors = neighbors.permute(2, 0, 1)  # first dim is y/x, second is neighbor index
-
-            # Build connectivity mask: True where neighbor label == center label
-            center_labels = masks_padded[y, x][None, :]  # (1, N)
-            neighbor_labels = masks_padded[neighbors[0], neighbors[1]]  # (9, N)
-            is_neighbor = neighbor_labels == center_labels  # (9, N)
-
-            # Compute object slices and pack into array for get_centers
-            slices = find_objects(mask)
-            slices_arr = np.array([
-                [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
-                for i, sl in enumerate(slices) if sl is not None
-            ], dtype=int)
+            mask_channel = mask[channel]
 
-            # Compute centers (pixel indices) and extents via the provided helper
-            centers, ext = self.__get_mask_centers_and_extents(mask, slices_arr)
-            # Move centers to GPU and shift by +1 for padding
-            meds_p = torch.from_numpy(centers).to(self._device).long() + 1  # (M, 2); +1 for padding
-
-            # Determine number of diffusion iterations
-            n_iter = 2 * ext.max()
-
-            # Run the GPU diffusion kernel
-            mu = self.__propagate_centers_gpu(
-                neighbor_indices=neighbors,
-                center_indices=meds_p.T,
-                valid_neighbor_mask=is_neighbor,
-                output_shape=(padded_height, padded_width),
-                num_iterations=n_iter
-            )
+            if mask_channel.max() > 0:
+                padded_height, padded_width = height + 2, width + 2
+
+                # Pad the mask with a 1-pixel border
+                masks_padded = torch.from_numpy(mask_channel.astype(np.int64)).to(self._device)
+                masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
+
+                # Get coordinates of all non-zero pixels in the padded mask
+                y, x = torch.nonzero(masks_padded, as_tuple=True)
+                y = y.int(); x = x.int()  # ensure integer type
+
+                # Generate 8-connected neighbors (including center) via broadcasted offsets
+                offsets = torch.tensor([
+                    [ 0,  0],  # center
+                    [-1,  0],  # up
+                    [ 1,  0],  # down
+                    [ 0, -1],  # left
+                    [ 0,  1],  # right
+                    [-1, -1],  # up-left
+                    [-1,  1],  # up-right
+                    [ 1, -1],  # down-left
+                    [ 1,  1],  # down-right
+                ], dtype=torch.int32, device=self._device)  # (9, 2)
+
+                # coords: (N, 2)
+                coords = torch.stack((y, x), dim=1)
+
+                # neighbors: (9, N, 2)
+                neighbors = offsets[:, None, :] + coords[None, :, :]
+
+                # transpose into (2, 9, N) for the GPU kernel
+                neighbors = neighbors.permute(2, 0, 1)  # first dim is y/x, second is neighbor index
+
+                # Build connectivity mask: True where neighbor label == center label
+                center_labels = masks_padded[y, x][None, :]  # (1, N)
+                neighbor_labels = masks_padded[neighbors[0], neighbors[1]]  # (9, N)
+                is_neighbor = neighbor_labels == center_labels  # (9, N)
+
+                # Compute object slices and pack into array for get_centers
+                slices = find_objects(mask_channel)
+                slices_arr = np.array([
+                    [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
+                    for i, sl in enumerate(slices, start=1) if sl is not None
+                ], dtype=np.int16)
+
+                # Compute centers (pixel indices) and extents via the provided helper
+                centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr)
+                # Move centers to GPU and shift by +1 for padding
+                meds_p = torch.from_numpy(centers).to(self._device).long() + 1  # (M, 2); +1 for padding
+
+                # Determine number of diffusion iterations
+                n_iter = 2 * ext.max()
+
+                # Run the GPU diffusion kernel
+                mu = self.__propagate_centers_gpu(
+                    neighbor_indices=neighbors,
+                    center_indices=meds_p.T,
+                    valid_neighbor_mask=is_neighbor,
+                    output_shape=(padded_height, padded_width),
+                    num_iterations=n_iter
+                )
 
-            # Cast to float64 and normalize flow vectors
-            mu = mu.astype(np.float64)
-            mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60
+                # Cast to float64 and normalize flow vectors
+                mu = mu.astype(np.float64)
+                mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60
+
+                # Remove the padding and write into final output
+                flow_output = np.zeros((2, height, width), dtype=np.float32)
+                ys_np = y.cpu().numpy() - 1
+                xs_np = x.cpu().numpy() - 1
+                flow_output[:, ys_np, xs_np] = mu
+                flows[2*channel: 2*channel + 2] = flow_output
 
-            # Remove the padding and write into final output
-            flow_output = np.zeros((2, height, width), dtype=np.float32)
-            ys_np = y.cpu().numpy() - 1
-            xs_np = x.cpu().numpy() - 1
-            flow_output[:, ys_np, xs_np] = mu
-            flows[2*channel: 2*channel + 2] = flow_output
-
         return flows
@@ -1624,8 +1675,10 @@ class CellSegmentator:
         # Initialize flow_field with two channels: dy and dx
         flow_field = np.zeros((2, height, width), dtype=np.float64)
 
+        mask_channel = mask[channel]
+
         # Find bounding box for each labeled mask
-        mask_slices = find_objects(mask)
+        mask_slices = find_objects(mask_channel)
 
         # centers: List[Tuple[int, int]] = []
         # Iterate over mask labels in parallel
@@ -1642,7 +1695,7 @@ class CellSegmentator:
 
             # Get local coordinates of mask pixels within the patch
             local_rows, local_cols = np.nonzero(
-                mask[row_slice, col_slice] == (label_idx + 1)
+                mask_channel[row_slice, col_slice] == (label_idx + 1)
             )
             # Shift coords by +1 for the border padding
             local_rows = local_rows.astype(np.int32) + 1
@@ -1774,8 +1827,8 @@ class CellSegmentator:
         Generate instance segmentation masks from probability and flow fields.
 
         Args:
-            probability_map: 3D array (channels, height, width) of cell probabilities.
-            flow: 3D array (2*channels, height, width) of forward flow vectors.
+            probability_map: 3D array `(C, H, W)` of cell probabilities.
+            flow: 3D array `(2*C, H, W)` of forward flow vectors.
             prob_threshold: threshold to binarize probability_map. (Default 0.0)
             flow_threshold: threshold for filtering bad flow masks. (Default 0.4)
             num_iters: number of iterations for flow-following. (Default 200)
@@ -1802,6 +1855,9 @@ class CellSegmentator:
             # Extract binary mask for this channel
             channel_mask = probability_mask[channel_index]
 
+            if not channel_mask.sum():
+                continue
+
             nonzero_coords = np.stack(np.nonzero(channel_mask))
 
             # Follow the flow vectors to generate coordinate mappings
@@ -1810,12 +1866,6 @@ class CellSegmentator:
                 initial_coords=nonzero_coords,
                 num_iters=num_iters
             )
-            # If flow following fails, leave this channel empty
-            if flow_coordinates is None:
-                labeled_instances[channel_index] = np.zeros(
-                    probability_map.shape[1:], dtype=np.uint16
-                )
-                continue
 
             if not torch.is_tensor(flow_coordinates):
                 flow_coordinates = torch.from_numpy(
@@ -1851,11 +1901,6 @@ class CellSegmentator:
                 )
 
                 labeled_instances[channel_index] = channel_instances_mask
-            else:
-                # No valid instances found, leave the channel empty
-                labeled_instances[channel_index] = np.zeros(
-                    probability_map.shape[1:], dtype=np.uint16
-                )
 
         return labeled_instances
 
@@ -1923,7 +1968,7 @@ class CellSegmentator:
             )
             # Update each coordinate and clamp to valid range
             for i in range(dims):
-                pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0)
+                pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0)
 
             # Denormalize back to original pixel coordinates
             pts = (pts + 1) * 0.5 * max_idx_pt
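Note: `sampled[0, i]` applied the first point's sampled displacement to every point, while `sampled[:, i]` applies each point's own displacement. A toy illustration of the difference, assuming `sampled` stacks per-point displacements as (num_points, dims), which is what the fixed indexing implies:

    import torch

    pts_i = torch.zeros(3)                # coordinate i of three points
    sampled = torch.tensor([[0.1, 0.2],
                            [0.3, 0.4],
                            [0.5, 0.6]])  # (num_points, dims)
    i = 0
    print(pts_i + sampled[0, i])  # tensor([0.1, 0.1, 0.1])  first point broadcast (bug)
    print(pts_i + sampled[:, i])  # tensor([0.1, 0.3, 0.5])  per-point update (fix)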
@@ -2072,16 +2117,21 @@ class CellSegmentator:
             raise
 
         # Step 3: Find peaks via 5x5 max-pooling
-        k = 5
-        pooled = F.max_pool2d(
+        # k = 5
+        # pooled = F.max_pool2d(
+        #     histogram.float().unsqueeze(0).unsqueeze(1),
+        #     kernel_size=k,
+        #     stride=1,
+        #     padding=k // 2
+        # ).squeeze()
+        pooled = self.__max_pool_nd(
             histogram.unsqueeze(0),
-            kernel_size=k,
-            stride=1,
-            padding=k // 2
+            kernel_size=5
         ).squeeze()
+
         # Seeds are positions where histogram equals local max and count > threshold
-        seed_positions = torch.nonzero((histogram - pooled == 0) & (histogram > 10))
-        if seed_positions.numel() == 0:
+        seed_positions = torch.nonzero((histogram - pooled == 0) * (histogram > 10))
+        if seed_positions.shape[0] == 0:
             logger.warning("No seeds found: returning empty mask")
             return np.zeros(original_shape, dtype=np.uint16)
 
@@ -2106,13 +2156,14 @@ class CellSegmentator:
         seed_masks[:, 5, 5] = 1
         # Iterative dilation and thresholding
         for _ in range(5):
-            seed_masks = F.max_pool2d(
-                seed_masks,
-                kernel_size=3,
-                stride=1,
-                padding=1
-            )
-            seed_masks = seed_masks & (patches > 2)
+            # seed_masks = F.max_pool2d(
+            #     seed_masks.float().unsqueeze(0),
+            #     kernel_size=3,
+            #     stride=1,
+            #     padding=1
+            # ).squeeze(0).int()
+            seed_masks = self.__max_pool_nd(seed_masks, kernel_size=3)
+            seed_masks *= (patches > 2)
 
         # Compute final mask coordinates
         final_coords = []
         for idx in range(num_seeds):
@@ -2133,7 +2184,7 @@ class CellSegmentator:
 
         # Step 6: Map to original image and remove oversized masks
         mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32)
-        mask_final[valid_indices] = mask_values
+        mask_final[tuple(valid_indices)] = mask_values
 
         # Prune masks that are too large
         labels, counts = fastremap.unique(mask_final, return_counts=True)
@@ -2146,6 +2197,96 @@ class CellSegmentator:
 
         return mask_final
 
 
+    def __max_pool1d(
+        self,
+        input_tensor: Tensor,
+        kernel_size: int = 5,
+        axis: int = 1,
+        output_tensor: Optional[Tensor] = None
+    ) -> Tensor:
+        """
+        Memory-efficient 1D max pooling along a specified axis using in-place updates.
+        Requires:
+          - stride = 1
+          - padding = kernel_size // 2
+          - odd kernel_size >= 3
+
+        Args:
+            input_tensor (Tensor): Source tensor for pooling.
+            kernel_size (int): Size of the pooling window (must be odd and >= 3).
+            axis (int): Axis along which to compute 1D max pooling.
+            output_tensor (Optional[Tensor]): Tensor to store the result.
+                If None, a clone of input_tensor is used.
+
+        Returns:
+            Tensor: The pooled tensor, same shape as input_tensor.
+        """
+        # Initialize or copy data into the output tensor
+        if output_tensor is None:
+            output = input_tensor.clone()
+        else:
+            output = output_tensor
+            output.copy_(input_tensor)
+
+        # Number of elements along the chosen axis and half-window size
+        dimension_size = input_tensor.shape[axis]
+        half_window = kernel_size // 2
+
+        # Slide window offsets from -half_window to +half_window
+        for offset in range(-half_window, half_window + 1):
+            # Compute slice indices depending on axis
+            if axis == 1:
+                target_slice = output[:, max(-offset, 0): min(dimension_size - offset, dimension_size)]
+                source_slice = input_tensor[:, max(offset, 0): min(dimension_size + offset, dimension_size)]
+            elif axis == 2:
+                target_slice = output[:, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
+                source_slice = input_tensor[:, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
+            elif axis == 3:
+                target_slice = output[:, :, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
+                source_slice = input_tensor[:, :, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
+            else:
+                raise ValueError(f"Unsupported axis {axis} for 1D pooling")
+
+            # In-place element-wise maximum
+            torch.maximum(target_slice, source_slice, out=target_slice)
+
+        return output
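Note: a quick self-check that two orthogonal 1D max-pool passes reproduce a dense 2D max pool with stride 1 and an odd kernel (a sketch; the private `__max_pool1d` is name-mangled, so an equivalent free function is defined inline here):

    import torch
    import torch.nn.functional as F

    def max_pool1d(x: torch.Tensor, kernel_size: int, axis: int) -> torch.Tensor:
        # Same sliding-window max as __max_pool1d, written with unfold for brevity.
        pad = kernel_size // 2
        x = x.movedim(axis, -1)
        x = F.pad(x, (pad, pad), value=float('-inf'))
        return x.unfold(-1, kernel_size, 1).max(dim=-1).values.movedim(-1, axis)

    x = torch.rand(1, 32, 32)  # (batch, H, W)
    ref = F.max_pool2d(x.unsqueeze(1), kernel_size=5, stride=1, padding=2).squeeze(1)
    # Separable max: a 5x5 window max equals a row-wise then a column-wise 1D max.
    out = max_pool1d(max_pool1d(x, 5, axis=1), 5, axis=2)
    assert torch.equal(out, ref)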
+ """ + # Determine number of spatial dimensions (excluding batch axis) + num_spatial_dims = input_tensor.ndim - 1 + + # First pass: pool along axis=1 + pooled = self.__max_pool1d(input_tensor, kernel_size=kernel_size, axis=1) + # Second pass: pool along axis=2 + pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=2) + + # If 3D data, apply a third pass along axis=3 + if num_spatial_dims == 3: + pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=3) + elif num_spatial_dims != 2: + raise ValueError("max_pool_nd only supports 2D or 3D spatial data") + + return pooled + + def __remove_inconsistent_flow_masks( self, mask: np.ndarray, @@ -2175,11 +2316,7 @@ class CellSegmentator: """ # If mask is very large and running on CUDA, check memory num_pixels = mask.size - if ( - num_pixels > 10000 * 10000 - - and self._device.type == 'cuda' - ): + if num_pixels > 10000 * 10000 and self._device.type == 'cuda': # Clear unused GPU cache torch.cuda.empty_cache() # Determine PyTorch version @@ -2296,15 +2433,8 @@ class CellSegmentator: logger.error(msg) raise ValueError(msg) - # Optionally remove masks smaller than minimum_size - if minimum_size >= 0: - # Compute label counts (skipping background at index 0) - labels, counts = fastremap.unique(masks, return_counts=True) - # Identify labels to remove: those with count < minimum_size - small_labels = labels[counts < minimum_size] - if small_labels.size > 0: - masks = fastremap.mask(masks, small_labels) - fastremap.renumber(masks, in_place=True) + # Initial pruning of too-small masks + masks = self._prune_small_masks(masks, minimum_size) # Find bounding boxes for each mask label object_slices = find_objects(masks) @@ -2325,12 +2455,36 @@ class CellSegmentator: output_masks[slc][filled_region] = new_label new_label += 1 - # Final pruning of small masks after filling (optional) - if minimum_size >= 0: - labels, counts = fastremap.unique(output_masks, return_counts=True) - small_labels = labels[counts < minimum_size] - if small_labels.size > 0: - output_masks = fastremap.mask(output_masks, small_labels) - fastremap.renumber(output_masks, in_place=True) + # Final pruning after hole filling + output_masks = self._prune_small_masks(output_masks, minimum_size) + return output_masks + + + def _prune_small_masks( + self, + masks: np.ndarray, + minimum_size: int + ) -> np.ndarray: + """ + Remove labeled regions in `masks` whose pixel count is below `minimum_size`. - return output_masks \ No newline at end of file + Args: + masks (np.ndarray): Integer mask array (any shape), 0=background. + minimum_size (int): Minimum pixel count; labels smaller are removed. If <0, skip pruning. + + Returns: + np.ndarray: Mask array with small labels suppressed and labels renumbered. 
+ """ + if minimum_size < 0: + return masks + + labels, counts = fastremap.unique(masks, return_counts=True) + # Skip background label at index 0 + non_bg_labels = labels[1:] + non_bg_counts = counts[1:] + # Identify labels to remove + small_labels = non_bg_labels[non_bg_counts < minimum_size] + if small_labels.size > 0: + masks = fastremap.mask(masks, small_labels) + fastremap.renumber(masks, in_place=True) + return masks \ No newline at end of file diff --git a/core/utils/measures.py b/core/utils/measures.py index c157e06..53135ea 100644 --- a/core/utils/measures.py +++ b/core/utils/measures.py @@ -11,6 +11,8 @@ from skimage import segmentation from scipy.optimize import linear_sum_assignment from typing import Dict, List, Tuple, Any, Union +from core.logger import get_logger + __all__ = [ "compute_batch_segmentation_f1_metrics", "compute_batch_segmentation_average_precision_metrics", "compute_batch_segmentation_tp_fp_fn", @@ -18,7 +20,9 @@ __all__ = [ "compute_segmentation_tp_fp_fn", "compute_confusion_matrix", "compute_f1_score", "compute_average_precision_score" ] - + +logger = get_logger() + def compute_f1_score( true_positives: int, @@ -92,7 +96,7 @@ def compute_confusion_matrix( # If no predictions were made, return zeros (with a printout for debugging). if num_predictions == 0: - print("No segmentation results!") + logger.warning("No segmentation results!") return 0, 0, 0 # Compute the IoU matrix and ignore the background (first row and column). @@ -586,7 +590,7 @@ def _process_instance_matching( # If no predictions are found, return with all ground truth as false negatives. if num_prediction == 0: - print("No segmentation results!") + logger.warning("No segmentation results!") result = {'tp': 0, 'fp': 0, 'fn': num_ground_truth} if return_masks: tp_mask = np.zeros_like(ground_truth_mask, dtype=np.uint8) diff --git a/generate_config.py b/generate_config.py index aa5dd7a..2bb91ec 100644 --- a/generate_config.py +++ b/generate_config.py @@ -1,8 +1,7 @@ import os from typing import Tuple -from config.config import * -from config.dataset_config import DatasetConfig +from config import Config, WandbConfig, DatasetConfig, ComponentConfig from core import ( ModelRegistry, CriterionRegistry, OptimizerRegistry, SchedulerRegistry @@ -47,7 +46,8 @@ def main(): if is_training is False: config = Config( model=ComponentConfig(name=chosen_model, params=model_instance), - dataset_config=dataset_config + dataset_config=dataset_config, + wandb_config=WandbConfig() ) # Construct a base filename from the selected registry names. @@ -76,6 +76,7 @@ def main(): config = Config( model=ComponentConfig(name=chosen_model, params=model_instance), dataset_config=dataset_config, + wandb_config=WandbConfig(), criterion=ComponentConfig(name=chosen_criterion, params=criterion_instance), optimizer=ComponentConfig(name=chosen_optimizer, params=optimizer_instance), scheduler=ComponentConfig(name=chosen_scheduler, params=scheduler_instance) diff --git a/main.py b/main.py index eea6098..c6e1970 100644 --- a/main.py +++ b/main.py @@ -1,41 +1,78 @@ import os +import argparse import wandb -from config.config import Config +from config import Config from core.data import * from core.segmentator import CellSegmentator -if __name__ == "__main__": - config_path = 'config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json' - # config_path = 'config/templates/predict/ModelV.json' - +def main(): + parser = argparse.ArgumentParser( + description="Train or predict cell segmentator with specified config file." 
+    )
+    parser.add_argument(
+        '-c', '--config',
+        type=str,
+        default='config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json',
+        help='Path to the JSON config file'
+    )
+    parser.add_argument(
+        '-m', '--mode',
+        choices=['train', 'test', 'predict'],
+        default='train',
+        help='Run mode: train, test or predict'
+    )
+    args = parser.parse_args()
+
+    mode = args.mode
+    config_path = args.config
 
     config = Config.load_json(config_path)
-    # config = Config.load_json(config_path)
-
-    if config.dataset_config.wandb.use_wandb:
-        # Initialize W&B
-        wandb.init(config=config.asdict(), **config.dataset_config.wandb.asdict())
+
+    if mode == 'train' and not config.dataset_config.is_training:
+        raise ValueError(
+            "Config is not set for training (is_training=False), but mode 'train' was requested."
+        )
+    if mode in ('test', 'predict') and config.dataset_config.is_training:
+        raise ValueError(
+            f"Config is set for training (is_training=True), but mode '{mode}' was requested."
+        )
+
+    if config.wandb_config.use_wandb:
+        # Initialize W&B
+        wandb.init(config=config.asdict(), **config.wandb_config.asdict())
         # How many batches to wait before logging training status
         wandb.config.log_interval = 10
-
+
     segmentator = CellSegmentator(config)
-    segmentator.create_dataloaders()
-
+    segmentator.create_dataloaders(
+        train_transforms=get_train_transforms() if mode == "train" else None,
+        valid_transforms=get_valid_transforms() if mode == "train" else None,
+        test_transforms=get_test_transforms() if mode in ("train", "test") else None,
+        predict_transforms=get_predict_transforms() if mode == "predict" else None
+    )
+
     # Watch parameters & gradients of model
-    if config.dataset_config.wandb.use_wandb:
+    if config.wandb_config.use_wandb:
         wandb.watch(segmentator._model, log="all", log_graph=True)
-
+
+    # Run training (or prediction, if implemented)
     segmentator.run()
-
-    weights_dir = "weights" if not config.dataset_config.wandb.use_wandb else wandb.run.dir  # type: ignore
-    saving_path = os.path.join(
-        weights_dir, os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
-    )
-    segmentator.save_checkpoint(saving_path)
-
-    if config.dataset_config.wandb.use_wandb:
-        wandb.save(saving_path)
-
-
+
+    if config.dataset_config.is_training:
+        # Prepare saving path
+        weights_dir = (
+            wandb.run.dir if config.wandb_config.use_wandb else "weights"  # type: ignore
+        )
+        saving_path = os.path.join(
+            weights_dir,
+            os.path.basename(config.dataset_config.common.predictions_dir) + '.pth'
+        )
+        segmentator.save_checkpoint(saving_path)
+
+        if config.wandb_config.use_wandb:
+            wandb.save(saving_path)
+
+
+if __name__ == "__main__":
+    main()
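Note: example invocations of the new CLI; both config paths are the template paths already referenced in main.py, and any other generated config works the same way:

    # Train with the default training template
    python main.py --config config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing.json --mode train

    # Run inference with the prediction template
    python main.py -c config/templates/predict/ModelV.json -m predict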