@@ -1,4 +1,3 @@
import torch
import random
import numpy as np
from numba import njit, prange
@@ -10,23 +9,41 @@ from torch.utils.data import DataLoader
import fastremap
import fill_voids
from skimage import morphology
from skimage.segmentation import find_boundaries
from scipy.ndimage import mean, find_objects

from monai.data.dataset import Dataset
from monai.transforms import *  # type: ignore
from monai.inferers.utils import sliding_window_inference
from monai.metrics.cumulative_average import CumulativeAverage

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import os
import glob
import copy
import tifffile as tiff

from pprint import pformat
from typing import Any, Dict, Literal, Optional, Tuple, List, Union

from tabulate import tabulate
from tqdm import tqdm
import wandb

from config import Config
from core.models import *
from core.losses import *
from core.optimizers import *
from core.schedulers import *
from core.utils import (
    compute_batch_segmentation_tp_fp_fn,
    compute_f1_score,
    compute_average_precision_score
)
from core.logger import get_logger
@@ -40,6 +57,11 @@ class CellSegmentator:
        self.__parse_config(config)

        self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu")
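        # GradScaler is created only for AMP training runs: it scales the loss
        # before backward() so that float16 gradients do not underflow, and it
        # is skipped entirely for inference-only configurations.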
        self._scaler = (
            torch.amp.GradScaler(self._device.type)  # type: ignore
            if self._dataset_setup.is_training and self._dataset_setup.common.use_amp
            else None
        )

        self._train_dataloader: Optional[DataLoader] = None
        self._valid_dataloader: Optional[DataLoader] = None
@@ -215,16 +237,236 @@ class CellSegmentator:
        logger.info(f"Loaded prediction dataset with {len(predict_dataset)} samples.")

    def print_data_info(
        self,
        loader_type: Literal["train", "valid", "test", "predict"],
        index: Optional[int] = None
    ) -> None:
        """
        Print statistics for a single sample from the specified dataloader.

        Args:
            loader_type: One of "train", "valid", "test", "predict".
            index: The sample index; if None, a random index is selected.
        """
        # Retrieve the dataloader attribute, e.g. self._train_dataloader
        loader: Optional[torch.utils.data.DataLoader] = getattr(self, f"_{loader_type}_dataloader", None)
        if loader is None:
            logger.error(f"Dataloader '{loader_type}' is not initialized.")
            return

        dataset = loader.dataset
        total = len(dataset)  # type: ignore
        if total == 0:
            logger.error(f"Dataset for '{loader_type}' is empty.")
            return

        # Choose an index
        idx = index if index is not None else random.randint(0, total - 1)
        if not (0 <= idx < total):
            logger.error(f"Index {idx} is out of range [0, {total}).")
            return

        # Fetch the sample and apply transforms; a dict with
        # {'image': ..., 'mask': ...} or {'image': ...} is expected
        sample = dataset[idx]
        img = sample["image"]
        # Convert a tensor to numpy if needed
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()
        img = np.asarray(img)

        # Compute image statistics
        img_min, img_max = img.min(), img.max()
        img_mean, img_std = float(img.mean()), float(img.std())
        img_shape = img.shape

        # Prepare the log lines
        lines = [
            "=" * 40,
            f"Dataloader: '{loader_type}', sample index: {idx} / {total - 1}",
            f"Image — shape: {img_shape}, min: {img_min:.4f}, max: {img_max:.4f}, mean: {img_mean:.4f}, std: {img_std:.4f}"
        ]

        # For 'predict', no mask is available
        if loader_type != "predict":
            mask = sample.get("mask", None)
            if mask is not None:
                # Convert a tensor to numpy if needed
                if isinstance(mask, torch.Tensor):
                    mask = mask.cpu().numpy()
                mask = np.asarray(mask)
                m_min, m_max = mask.min(), mask.max()
                m_mean, m_std = float(mask.mean()), float(mask.std())
                m_shape = mask.shape
                lines.append(
                    f"Mask — shape: {m_shape}, min: {m_min:.4f}, "
                    f"max: {m_max:.4f}, mean: {m_mean:.4f}, std: {m_std:.4f}"
                )
            else:
                lines.append("Mask — not available for this sample.")

        lines.append("=" * 40)

        # Output via the logger
        for line in lines:
            logger.info(line)
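
    # Usage sketch (illustrative, assuming a constructed CellSegmentator):
    #     segmentator.print_data_info("train")           # random training sample
    #     segmentator.print_data_info("valid", index=0)  # first validation sample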

    def train(self) -> None:
        """
        Train the model over multiple epochs, including validation and testing.
        """
        # Ensure training is enabled in the dataset setup
        if not self._dataset_setup.is_training:
            raise RuntimeError("Dataset setup indicates training is disabled.")

        # Determine the device name for logging
        if self._device.type == "cpu":
            device_name = "cpu"
        else:
            idx = self._device.index if self._device.index is not None else torch.cuda.current_device()
            device_name = torch.cuda.get_device_name(idx)

        logger.info(f"\n{'=' * 50}\n")
        logger.info(f"Training starts on device: {device_name}")
        logger.info(f"\n{'=' * 50}")

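        # Track the best validation F1 seen so far, so that the weights from the
        # strongest epoch (not necessarily the last one) can be restored afterwards.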
        best_f1_score = 0.0
        best_weights = None

        for epoch in range(1, self._dataset_setup.training.num_epochs + 1):
            train_metrics = self.__run_epoch("train", epoch)
            self.__print_with_logging(train_metrics, epoch)

            # Step the scheduler after each training epoch
            if self._scheduler is not None:
                self._scheduler.step()

            # Periodic validation
            if epoch % self._dataset_setup.training.val_freq == 0:
                if self._valid_dataloader is not None:
                    valid_metrics = self.__run_epoch("valid", epoch)
                    self.__print_with_logging(valid_metrics, epoch)

                    # Update the best model on improved F1
                    f1 = valid_metrics.get("valid_f1_score", 0.0)
                    if f1 > best_f1_score:
                        best_f1_score = f1
                        # Deep copy the weights to avoid reference issues
                        best_weights = copy.deepcopy(self._model.state_dict())
                        logger.info(f"Updated best model weights with F1 score: {f1:.4f}")

        # Restore the best model weights if available
        if best_weights is not None:
            self._model.load_state_dict(best_weights)

        if self._test_dataloader is not None:
            test_metrics = self.__run_epoch("test")
            self.__print_with_logging(test_metrics, 0)

    def evaluate(self) -> None:
        """
        Run a full test epoch and display/log the resulting metrics.
        """
        test_metrics = self.__run_epoch("test")
        self.__print_with_logging(test_metrics, 0)

    def predict(self) -> None:
        """
        Run inference on the predict set and save the resulting instance masks.
        """
        # Ensure the predict DataLoader has been set
        if self._predict_dataloader is None:
            raise RuntimeError("DataLoader for mode 'predict' is not set.")

        batch_counter = 0
        for batch in tqdm(self._predict_dataloader, desc="Predicting"):
            # Move input images to the configured device (CPU/GPU)
            inputs = batch["img"].to(self._device)

            # Use automatic mixed precision if enabled in the dataset setup
            with torch.amp.autocast(  # type: ignore
                self._device.type,
                enabled=self._dataset_setup.common.use_amp
            ):
                # Disable gradient computation for inference
                with torch.no_grad():
                    # Run the model's forward pass in 'predict' mode
                    raw_output = self.__run_inference(inputs, mode="predict")

                    # Convert logits/probabilities to discrete instance masks;
                    # ground truth is not passed here, only predictions are needed
                    preds, _ = self.__post_process_predictions(raw_output)

                    # Save the predicted masks, using batch_counter to index files
                    self.__save_prediction_masks(batch, preds, batch_counter)

            # Increment the counter by batch size for unique file naming
            batch_counter += inputs.shape[0]

    def run(self) -> None:
        """
        Orchestrate the full workflow:
        - If training is enabled in the dataset setup, start training.
        - Otherwise, if a test DataLoader is provided, run evaluation.
        - Else, if a prediction DataLoader is provided, run inference/prediction.
        - If neither loader is available in non-training mode, raise an error.
        """
        # 1) TRAINING PATH
        if self._dataset_setup.is_training:
            # Launch the full training loop (with validation, scheduler steps, etc.)
            self.train()
            return

        # 2) NON-TRAINING PATH (TEST or PREDICT)
        # Prefer test if available
        if self._test_dataloader is not None:
            # Run a single evaluation epoch on the test set and log metrics
            self.evaluate()
            return

        # If there is no test loader, fall back to prediction if available
        if self._predict_dataloader is not None:
            # Run inference on the predict set and save outputs
            self.predict()
            return

        # 3) ERROR: no appropriate loader found
        raise RuntimeError(
            "Neither test nor predict DataLoader is set for non-training mode."
        )
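
    # Typical entry point (illustrative sketch):
    #     segmentator = CellSegmentator(config)
    #     segmentator.run()   # trains, evaluates, or predicts depending on the config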

    def load_from_checkpoint(self, checkpoint_path: str) -> None:
        """
        Load model weights from a specified checkpoint into the current model.

        Args:
            checkpoint_path (str): Path to the checkpoint file containing the model weights.
        """
        # Load the checkpoint onto the correct device (CPU or GPU)
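        # weights_only=True restricts unpickling to tensors and other basic types,
        # which is safer when loading checkpoints from untrusted sources.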
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=True)
        # Load the state dict into the model, allowing for missing keys
        self._model.load_state_dict(checkpoint['state_dict'], strict=False)

    def save_checkpoint(self, checkpoint_path: str) -> None:
        """
        Save the current model weights to a checkpoint file.

        Args:
            checkpoint_path (str): Path where the checkpoint file will be saved.
        """
        # Create a checkpoint dictionary containing the model's state_dict
        checkpoint = {
            'state_dict': self._model.state_dict()
        }
        # Write the checkpoint to disk
        torch.save(checkpoint, checkpoint_path)
    def __parse_config(self, config: Config) -> None:
@@ -255,6 +497,11 @@ class CellSegmentator:
        # Initialize the model using the model registry
        self._model = ModelRegistry.get_model_class(model.name)(model.params)

        # Load model weights from a checkpoint if pretrained weights are specified
        if config.dataset_config.is_training:
            if config.dataset_config.training.pretrained_weights:
                self.load_from_checkpoint(config.dataset_config.training.pretrained_weights)

        # Initialize the loss criterion if specified
        self._criterion = (
            CriterionRegistry.get_criterion_class(criterion.name)(params=criterion.params)
@@ -262,6 +509,15 @@ class CellSegmentator:
            else None
        )

        # Ensure the model and criterion agree on the number of classes
        if hasattr(self._criterion, "num_classes"):
            nc_model = self._model.num_classes
            nc_crit = getattr(self._criterion, "num_classes")
            if nc_model != nc_crit:
                raise ValueError(
                    f"Number of classes mismatch: model.num_classes={nc_model} "
                    f"but criterion.num_classes={nc_crit}"
                )

        # Initialize the optimizer if specified
        self._optimizer = (
            OptimizerRegistry.get_optimizer_class(optimizer.name)(
@@ -298,6 +554,7 @@ class CellSegmentator:
        logger.info("[COMMON]")
        logger.info(f"├─ Seed: {common.seed}")
        logger.info(f"├─ Device: {common.device}")
        logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
        logger.info(f"└─ Predictions output dir: {common.predictions_dir}")

        if config.dataset_config.is_training:
@@ -306,7 +563,6 @@ class CellSegmentator:
            logger.info(f"├─ Batch size: {training.batch_size}")
            logger.info(f"├─ Epochs: {training.num_epochs}")
            logger.info(f"├─ Validation frequency: {training.val_freq}")
            logger.info(f"├─ Use AMP: {'yes' if training.use_amp else 'no'}")
            logger.info(f"├─ Pretrained weights: {training.pretrained_weights or 'None'}")

            if training.is_split:
@@ -434,6 +690,624 @@ class CellSegmentator:
        return Dataset(data, transforms)

    def __print_with_logging(self, results: Dict[str, float], step: int) -> None:
        """
        Print metrics in a tabular format and log them to W&B.

        Args:
            results (Dict[str, float]): Dictionary of metric names and values.
            step (int): Epoch index used as the W&B step.
        """
        table = tabulate(
            tabular_data=results.items(),
            headers=["Metric", "Value"],
            floatfmt=".4f",
            tablefmt="fancy_grid"
        )
        print(table, "\n")
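        # Note: W&B expects steps to be non-decreasing across log calls; metrics
        # logged with a smaller step than a previous call (e.g. test metrics sent
        # with step 0 after training) may be ignored by the backend.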
        wandb.log(results, step=step)

    def __run_epoch(
        self,
        mode: Literal["train", "valid", "test"],
        epoch: Optional[int] = None
    ) -> Dict[str, float]:
        """
        Execute one epoch of training, validation, or testing.

        Args:
            mode (str): One of 'train', 'valid', or 'test'.
            epoch (int, optional): Current epoch number, used for logging.

        Returns:
            Dict[str, float]: Loss metrics, plus F1 and mAP for valid/test.
        """
        # Ensure the required components are available
        if mode in ("train", "valid") and (self._optimizer is None or self._criterion is None):
            raise RuntimeError("Optimizer and loss function must be initialized for train/valid.")

        # Set the model mode and choose the appropriate data loader
        if mode == "train":
            self._model.train()
            loader = self._train_dataloader
        else:
            self._model.eval()
            loader = self._valid_dataloader if mode == "valid" else self._test_dataloader

        # Check that the data loader is provided
        if loader is None:
            raise RuntimeError(f"DataLoader for mode '{mode}' is not set.")

        all_tp, all_fp, all_fn = [], [], []

        # Prepare the tqdm description with epoch info if available
        if epoch is not None:
            desc = f"Epoch {epoch}/{self._dataset_setup.training.num_epochs} [{mode}]"
        else:
            desc = f"Epoch ({mode})"

        # Iterate over batches
        batch_counter = 0
        for batch in tqdm(loader, desc=desc):
            inputs = batch["img"].to(self._device)
            targets = batch["label"].to(self._device)

            # Zero gradients for training
            if self._optimizer is not None:
                self._optimizer.zero_grad()

            # Mixed-precision context if enabled
            with torch.amp.autocast(  # type: ignore
                self._device.type,
                enabled=self._dataset_setup.common.use_amp
            ):
                # Only compute gradients in the training phase
                with torch.set_grad_enabled(mode == "train"):
                    # Forward pass
                    raw_output = self.__run_inference(inputs, mode)

                    if self._criterion is not None:
                        # Convert label masks to flow representations
                        flow_targets = self.__compute_flows_from_masks(targets)

                        # Compute the loss for this batch
                        batch_loss = self._criterion(raw_output, flow_targets)  # type: ignore

                    # Post-process and compute F1 during validation and testing
                    if mode in ("valid", "test"):
                        preds, labels_post = self.__post_process_predictions(
                            raw_output, targets
                        )

                        # Collect statistics on the batch
                        tp, fp, fn = self.__compute_stats(
                            predicted_masks=preds,
                            ground_truth_masks=labels_post,  # type: ignore
                            iou_threshold=0.5
                        )
                        all_tp.append(tp)
                        all_fp.append(fp)
                        all_fn.append(fn)

                        if mode == "test":
                            self.__save_prediction_masks(
                                batch, preds, batch_counter
                            )

            # Backpropagation and optimizer step in training
            if mode == "train":
                if self._dataset_setup.common.use_amp and self._scaler is not None:
                    self._scaler.scale(batch_loss).backward()  # type: ignore
                    self._scaler.unscale_(self._optimizer.optim)  # type: ignore
                    self._scaler.step(self._optimizer.optim)  # type: ignore
                    self._scaler.update()
                else:
                    batch_loss.backward()  # type: ignore
                    if self._optimizer is not None:
                        self._optimizer.step()

            batch_counter += inputs.shape[0]

        if self._criterion is not None:
            # Collect the loss metrics
            epoch_metrics = {f"{mode}_{name}": value for name, value in self._criterion.get_loss_metrics().items()}
            # Reset the internal loss-metrics accumulator
            self._criterion.reset_metrics()
        else:
            epoch_metrics = {}

        # Include F1 and mAP for validation and testing
        if mode in ("valid", "test"):
            # Concatenate across batches: shape (num_batches * B, C)
            tp_array = np.vstack(all_tp)
            fp_array = np.vstack(all_fp)
            fn_array = np.vstack(all_fn)

            epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(  # type: ignore
                tp_array, fp_array, fn_array, reduction="micro"
            )
            epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(  # type: ignore
                tp_array, fp_array, fn_array, reduction="macro"
            )

        return epoch_metrics

    def __run_inference(
        self,
        inputs: torch.Tensor,
        mode: Literal["train", "valid", "test", "predict"] = "train"
    ) -> torch.Tensor:
        """
        Perform model inference for the different stages.

        Args:
            inputs (torch.Tensor): Input tensor of shape (B, C, H, W).
            mode (Literal[...]): One of "train", "valid", "test", "predict".

        Returns:
            torch.Tensor: Model output tensor.
        """
        if mode != "train":
            # Use sliding-window inference for the non-training phases
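            # sliding_window_inference tiles the input into 512x512 patches with
            # 50% overlap and blends the overlapping patch outputs with Gaussian
            # weights, which suppresses seam artifacts at patch borders.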
            outputs = sliding_window_inference(
                inputs,
                roi_size=512,
                sw_batch_size=4,
                predictor=self._model,
                padding_mode="constant",
                mode="gaussian",
                overlap=0.5,
            )
        else:
            # Direct forward pass during training
            outputs = self._model(inputs)
        return outputs  # type: ignore

    def __post_process_predictions(
        self,
        raw_outputs: torch.Tensor,
        ground_truth: Optional[torch.Tensor] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Post-process raw network outputs to extract instance segmentation masks.

        Args:
            raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W).
            ground_truth (torch.Tensor, optional): Ground truth masks of shape (B, C, H, W).

        Returns:
            Tuple[np.ndarray, Optional[np.ndarray]]:
                - instance_masks: Instance-wise masks array of shape (B, C, H, W).
                - labels_np: Converted ground truth of shape (B, C, H, W), or None if
                  ground_truth was not provided.
        """
        # Move outputs to CPU and convert to numpy
        outputs_np = raw_outputs.cpu().numpy()
        # Split channels: gradient flows first, then class logits
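        # Expected channel layout per sample: the first 2 * num_classes channels
        # carry the (Y, X) flow pair for each class; the last num_classes
        # channels carry the class logits.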
        gradflow = outputs_np[:, :2 * self._model.num_classes]
        logits = outputs_np[:, -self._model.num_classes:]
        # Apply a sigmoid to the logits to get probabilities
        probabilities = self.__sigmoid(logits)

        batch_size, _, height, width = probabilities.shape
        # Prepare a container for the instance masks
        instance_masks = np.zeros((batch_size, self._model.num_classes, height, width), dtype=np.uint16)
        for idx in range(batch_size):
            instance_masks[idx] = self.__segment_instances(
                probability_map=probabilities[idx],
                flow=gradflow[idx],
                prob_threshold=0.0,
                flow_threshold=0.4,
                min_object_size=15
            )

        # Convert the ground truth to numpy
        labels_np = ground_truth.cpu().numpy() if ground_truth is not None else None
        return instance_masks, labels_np

    def __compute_stats(
        self,
        predicted_masks: np.ndarray,
        ground_truth_masks: np.ndarray,
        iou_threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute batch-wise true positives, false positives, and false negatives
        for instance segmentation, using a configurable IoU threshold.

        Args:
            predicted_masks (np.ndarray): Predicted instance masks of shape (B, C, H, W).
            ground_truth_masks (np.ndarray): Ground truth instance masks of shape (B, C, H, W).
            iou_threshold (float): Intersection-over-Union threshold for matching
                predictions to ground truths (default: 0.5).

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]:
                - tp: True positives per batch and class, shape (B, C)
                - fp: False positives per batch and class, shape (B, C)
                - fn: False negatives per batch and class, shape (B, C)
        """
        stats = compute_batch_segmentation_tp_fp_fn(
            batch_ground_truth=ground_truth_masks,
            batch_prediction=predicted_masks,
            iou_threshold=iou_threshold,
            remove_boundary_objects=True
        )
        tp = stats["tp"]
        fp = stats["fp"]
        fn = stats["fn"]
        return tp, fp, fn

    def __compute_f1_metric(
        self,
        true_positives: np.ndarray,
        false_positives: np.ndarray,
        false_negatives: np.ndarray,
        reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
    ) -> Union[float, np.ndarray]:
        """
        Compute the F1-score from batch-wise TP/FP/FN using various aggregation schemes.

        Args:
            true_positives: Array of TP counts per sample and class.
            false_positives: Array of FP counts per sample and class.
            false_negatives: Array of FN counts per sample and class.
            reduction:
                - 'none': return F1 for each sample and class → shape (batch_size, num_classes)
                - 'micro': global F1 over all samples and classes
                - 'imagewise': F1 per sample (summed over classes), then averaged over samples
                - 'macro': average class-wise F1 (classes summed over the batch)
                - 'weighted': class-wise F1 weighted by support (TP + FN)

        Returns:
            float for reductions 'micro', 'imagewise', 'macro', 'weighted';
            np.ndarray of shape (batch_size, num_classes) if reduction='none'.
        """
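        # Throughout, F1 is derived from the counts as F1 = 2*TP / (2*TP + FP + FN);
        # the reduction mode only changes the level at which TP/FP/FN are pooled
        # before the formula is applied.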
        batch_size, num_classes = true_positives.shape

        # 1) No reduction: compute F1 for each (sample, class) pair
        if reduction == "none":
            f1_matrix = np.zeros((batch_size, num_classes), dtype=float)
            for i in range(batch_size):
                for c in range(num_classes):
                    tp_val = int(true_positives[i, c])
                    fp_val = int(false_positives[i, c])
                    fn_val = int(false_negatives[i, c])
                    _, _, f1_val = compute_f1_score(tp_val, fp_val, fn_val)
                    f1_matrix[i, c] = f1_val
            return f1_matrix

        # 2) Micro: sum all TP/FP/FN and compute a single F1
        if reduction == "micro":
            tp_total = int(true_positives.sum())
            fp_total = int(false_positives.sum())
            fn_total = int(false_negatives.sum())
            _, _, f1_global = compute_f1_score(tp_total, fp_total, fn_total)
            return f1_global

        # 3) Imagewise: compute per-sample F1 (summed over classes), then average
        if reduction == "imagewise":
            f1_per_image = np.zeros(batch_size, dtype=float)
            for i in range(batch_size):
                tp_i = int(true_positives[i].sum())
                fp_i = int(false_positives[i].sum())
                fn_i = int(false_negatives[i].sum())
                _, _, f1_i = compute_f1_score(tp_i, fp_i, fn_i)
                f1_per_image[i] = f1_i
            return float(f1_per_image.mean())

        # For macro/weighted, first aggregate per class across the batch
        tp_per_class = true_positives.sum(axis=0).astype(int)  # shape (num_classes,)
        fp_per_class = false_positives.sum(axis=0).astype(int)
        fn_per_class = false_negatives.sum(axis=0).astype(int)

        # 4) Macro: average F1 across classes equally
        if reduction == "macro":
            f1_per_class = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                _, _, f1_c = compute_f1_score(
                    tp_per_class[c],
                    fp_per_class[c],
                    fn_per_class[c]
                )
                f1_per_class[c] = f1_c
            return float(f1_per_class.mean())

        # 5) Weighted: class-wise F1 weighted by support = TP + FN
        if reduction == "weighted":
            f1_per_class = np.zeros(num_classes, dtype=float)
            support = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                tp_c = tp_per_class[c]
                fp_c = fp_per_class[c]
                fn_c = fn_per_class[c]
                _, _, f1_c = compute_f1_score(tp_c, fp_c, fn_c)
                f1_per_class[c] = f1_c
                support[c] = tp_c + fn_c
            total_support = support.sum()
            if total_support == 0:
                # Fall back to unweighted macro if there are no positives
                return float(f1_per_class.mean())
            weights = support / total_support
            return float((f1_per_class * weights).sum())

        raise ValueError(f"Unknown reduction mode: {reduction}")

    def __compute_average_precision_metric(
        self,
        true_positives: np.ndarray,
        false_positives: np.ndarray,
        false_negatives: np.ndarray,
        reduction: Literal["micro", "macro", "weighted", "imagewise", "none"] = "micro"
    ) -> Union[float, np.ndarray]:
        """
        Compute Average Precision (AP) from batch-wise TP/FP/FN using various aggregation schemes.

        AP is defined here as:
            AP = TP / (TP + FP + FN)

        Args:
            true_positives: Array of true positives per sample and class.
            false_positives: Array of false positives per sample and class.
            false_negatives: Array of false negatives per sample and class.
            reduction:
                - 'none': return AP for each sample and class → shape (batch_size, num_classes)
                - 'micro': global AP over all samples and classes
                - 'imagewise': AP per sample (summing stats over classes), then averaged over the batch
                - 'macro': average class-wise AP (each class summed over the batch)
                - 'weighted': class-wise AP weighted by support (TP + FN)

        Returns:
            float for reductions 'micro', 'imagewise', 'macro', 'weighted';
            np.ndarray of shape (batch_size, num_classes) if reduction='none'.
        """
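        # Note: this count-based definition (used, e.g., by the Cellpose evaluation
        # code) differs from the area-under-the-precision-recall-curve definition
        # of AP that is common in object detection.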
        batch_size, num_classes = true_positives.shape

        # 1) No reduction: AP per (sample, class) pair
        if reduction == "none":
            ap_matrix = np.zeros((batch_size, num_classes), dtype=float)
            for i in range(batch_size):
                for c in range(num_classes):
                    tp_val = int(true_positives[i, c])
                    fp_val = int(false_positives[i, c])
                    fn_val = int(false_negatives[i, c])
                    ap_val = compute_average_precision_score(tp_val, fp_val, fn_val)
                    ap_matrix[i, c] = ap_val
            return ap_matrix

        # 2) Micro: sum all TP/FP/FN and compute one AP
        if reduction == "micro":
            tp_total = int(true_positives.sum())
            fp_total = int(false_positives.sum())
            fn_total = int(false_negatives.sum())
            return compute_average_precision_score(tp_total, fp_total, fn_total)

        # 3) Imagewise: compute per-sample AP (summed over classes), then average
        if reduction == "imagewise":
            ap_per_image = np.zeros(batch_size, dtype=float)
            for i in range(batch_size):
                tp_i = int(true_positives[i].sum())
                fp_i = int(false_positives[i].sum())
                fn_i = int(false_negatives[i].sum())
                ap_per_image[i] = compute_average_precision_score(tp_i, fp_i, fn_i)
            return float(ap_per_image.mean())

        # For macro and weighted, first aggregate per class across the batch
        tp_per_class = true_positives.sum(axis=0).astype(int)  # shape (num_classes,)
        fp_per_class = false_positives.sum(axis=0).astype(int)
        fn_per_class = false_negatives.sum(axis=0).astype(int)

        # 4) Macro: average AP across classes equally
        if reduction == "macro":
            ap_per_class = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                ap_per_class[c] = compute_average_precision_score(
                    tp_per_class[c],
                    fp_per_class[c],
                    fn_per_class[c]
                )
            return float(ap_per_class.mean())

        # 5) Weighted: class-wise AP weighted by support = TP + FN
        if reduction == "weighted":
            ap_per_class = np.zeros(num_classes, dtype=float)
            support = np.zeros(num_classes, dtype=float)
            for c in range(num_classes):
                tp_c = tp_per_class[c]
                fp_c = fp_per_class[c]
                fn_c = fn_per_class[c]
                ap_per_class[c] = compute_average_precision_score(tp_c, fp_c, fn_c)
                support[c] = tp_c + fn_c
            total_support = support.sum()
            if total_support == 0:
                # Fall back to unweighted macro if there are no positive instances
                return float(ap_per_class.mean())
            weights = support / total_support
            return float((ap_per_class * weights).sum())

        raise ValueError(f"Unknown reduction mode: {reduction}")

    @staticmethod
    def __sigmoid(z: np.ndarray) -> np.ndarray:
        """
        Numerically stable sigmoid activation for numpy arrays.

        Args:
            z (np.ndarray): Input array.

        Returns:
            np.ndarray: Sigmoid of the input.
        """
        # Branch on the sign of z so that np.exp never receives a large positive
        # argument, avoiding overflow warnings for extreme logits.
        out = np.empty_like(z, dtype=np.float64)
        positive = z >= 0
        out[positive] = 1.0 / (1.0 + np.exp(-z[positive]))
        exp_z = np.exp(z[~positive])
        out[~positive] = exp_z / (1.0 + exp_z)
        return out

    def __save_prediction_masks(
        self,
        sample: Dict[str, Any],
        predicted_mask: Union[np.ndarray, torch.Tensor],
        start_index: int = 0,
    ) -> None:
        """
        Save multi-channel predicted masks as TIFFs and corresponding
        visualizations as PNGs in separate folders.

        Args:
            sample (Dict[str, Any]): Batch sample from MONAI LoadImaged
                (contains 'image', optional 'mask', and 'image_meta_dict').
            predicted_mask (np.ndarray or torch.Tensor): Array of shape (C, H, W) or (B, C, H, W).
            start_index (int): Starting index for naming when metadata is missing.
        """
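        # Resulting layout (assuming predictions_dir="predictions" and an input
        # image named sample.tif):
        #     predictions/sample_ch00.tif        16-bit instance mask, channel 0
        #     predictions/plots/sample_ch00.png  side-by-side visualization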
        # Determine the base paths
        base_output_dir = self._dataset_setup.common.predictions_dir
        masks_dir = base_output_dir
        plots_dir = os.path.join(base_output_dir, "plots")
        os.makedirs(masks_dir, exist_ok=True)
        os.makedirs(plots_dir, exist_ok=True)

        # Extract the image (C, H, W) or batch of images (B, C, H, W), and metadata
        image_obj = sample.get("image")  # Expected shape: (C, H, W) or (B, C, H, W)
        mask_obj = sample.get("mask")  # Expected shape: (C, H, W) or (B, C, H, W)
        image_meta = sample.get("image_meta_dict")

        # Convert tensors to numpy
        def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
            if isinstance(x, torch.Tensor):
                return x.detach().cpu().numpy()
            return x

        image_array = to_numpy(image_obj) if image_obj is not None else None
        mask_array = to_numpy(mask_obj) if mask_obj is not None else None
        pred_array = to_numpy(predicted_mask)

        # Handle the batch dimension: (B, C, H, W) is processed recursively per sample
        if pred_array.ndim == 4:
            for idx in range(pred_array.shape[0]):
                batch_sample = dict(sample)
                if image_array is not None and image_array.ndim == 4:
                    batch_sample["image"] = image_array[idx]
                if isinstance(image_meta, list):
                    batch_sample["image_meta_dict"] = image_meta[idx]
                if mask_array is not None and mask_array.ndim == 4:
                    batch_sample["mask"] = mask_array[idx]
                self.__save_prediction_masks(
                    batch_sample,
                    pred_array[idx],
                    start_index=start_index + idx
                )
            return

        # Determine the base filename
        if image_meta and "filename_or_obj" in image_meta:
            base_name = os.path.splitext(os.path.basename(image_meta["filename_or_obj"]))[0]
        else:
            # Use the provided start_index when metadata is missing
            base_name = f"prediction_{start_index:04d}"

        # At this point pred_array has shape (C, H, W)
        num_channels = pred_array.shape[0]
        for channel_idx in range(num_channels):
            channel_mask = pred_array[channel_idx]

            # File names
            mask_filename = f"{base_name}_ch{channel_idx:02d}.tif"
            plot_filename = f"{base_name}_ch{channel_idx:02d}.png"
            mask_path = os.path.join(masks_dir, mask_filename)
            plot_path = os.path.join(plots_dir, plot_filename)

            # Save the mask as a 16-bit TIFF
            tiff.imwrite(mask_path, channel_mask.astype(np.uint16), compression="zlib")

            # Extract the corresponding true mask channel if it exists
            true_mask = None
            if mask_array is not None and mask_array.ndim == 3:
                true_mask = mask_array[channel_idx]

            # Generate and save the visualization
            self.__plot_mask(
                file_path=plot_path,
                image_data=image_array,  # type: ignore
                predicted_mask=channel_mask,
                true_mask=true_mask,
            )

    def __plot_mask(
        self,
        file_path: str,
        image_data: np.ndarray,
        predicted_mask: np.ndarray,
        true_mask: Optional[np.ndarray] = None,
    ) -> None:
        """
        Create and save a grid visualization: 1x3 if no true mask is given,
        or 2x3 if a true mask is provided.
        """
        img = np.moveaxis(image_data, 0, -1) if image_data.ndim == 3 else image_data

        if true_mask is None:
            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            plt.subplots_adjust(wspace=0.02, hspace=0)
            self.__plot_panels(axs, img, predicted_mask, 'red',
                               ('Original Image', 'Predicted Mask', 'Predicted Contours'))
        else:
            fig, axs = plt.subplots(2, 3, figsize=(15, 10))
            plt.subplots_adjust(wspace=0.02, hspace=0.1)
            # Row 0: predicted
            self.__plot_panels(axs[0], img, predicted_mask, 'red',
                               ('Original Image', 'Predicted Mask', 'Predicted Contours'))
            # Row 1: ground truth
            self.__plot_panels(axs[1], img, true_mask, 'blue',
                               ('Original Image', 'True Mask', 'True Contours'))
        fig.savefig(file_path, bbox_inches='tight', dpi=300)
        plt.close(fig)

    def __plot_panels(
        self,
        axes,
        img: np.ndarray,
        mask: np.ndarray,
        contour_color: str,
        titles: Tuple[str, ...]
    ):
        """
        Plot a row of three panels: the original image, the mask, and the mask
        boundaries overlaid on the image.

        Args:
            axes: List/array of three Axes objects.
            img: Image array (H, W) or (H, W, C).
            mask: Label mask (H, W).
            contour_color: Color for the boundary overlay.
            titles: (title_img, title_mask, title_contours).
        """
        ax0, ax1, ax2 = axes

        # Panel 1: Original image
        ax0.imshow(img, cmap='gray' if img.ndim == 2 else None)
        ax0.set_title(titles[0])
        ax0.axis('off')

        # Compute the boundaries once
        boundaries = find_boundaries(mask, mode='thick')

        # Panel 2: Mask with black boundaries
        cmap = plt.get_cmap("gist_ncar")
        cmap = mcolors.ListedColormap([
            cmap(i / len(np.unique(mask))) for i in range(len(np.unique(mask)))
        ])
        ax1.imshow(mask, cmap=cmap)
        ax1.contour(boundaries, colors='black', linewidths=0.5)
        ax1.set_title(titles[1])
        ax1.axis('off')

        # Panel 3: Original image with colored contour overlay
        ax2.imshow(img)
        # Draw the boundaries as contour lines
        ax2.contour(boundaries, colors=contour_color, linewidths=0.5)
        ax2.set_title(titles[2])
        ax2.axis('off')

    def __compute_flows_from_masks(
        self,
        true_masks: Tensor
@@ -445,9 +1319,8 @@ class CellSegmentator:
            true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks.

        Returns:
            Numpy array of concatenated [flow_vectors, renumbered_true_masks] per image,
            where renumbered_true_masks holds the labels, flow_vectors[0] is the Y flow,
            and flow_vectors[1] is the X flow.
        """
        # Move to CPU numpy
        _true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16)
@@ -467,7 +1340,7 @@ class CellSegmentator:
        flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i])
                                 for i in range(batch_size)])

        return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32)
    def __compute_flow_from_mask(
@@ -879,7 +1752,7 @@ class CellSegmentator:
        prob_threshold: float = 0.0,
        flow_threshold: float = 0.4,
        num_iters: int = 200,
        min_object_size: int = 15
    ) -> np.ndarray:
        """
        Generate instance segmentation masks from probability and flow fields.
@@ -890,7 +1763,7 @@ class CellSegmentator:
            prob_threshold: threshold to binarize probability_map. (Default 0.0)
            flow_threshold: threshold for filtering bad flow masks. (Default 0.4)
            num_iters: number of iterations for flow-following. (Default 200)
            min_object_size: minimum area an instance must have to be kept. (Default 15)

        Returns:
            3D array of uint16 instance labels for each channel.