@@ -12,6 +12,7 @@ import fastremap
 import fill_voids
 from skimage import morphology
 from skimage.segmentation import find_boundaries
+from scipy.special import expit
 from scipy.ndimage import mean, find_objects
 
 from monai.data.dataset import Dataset
@@ -53,10 +54,11 @@ logger = get_logger()
 
 class CellSegmentator:
     def __init__(self, config: Config) -> None:
-        self._device: torch.device = torch.device(config.dataset_config.common.device or "cpu")
-
         self.__set_seed(config.dataset_config.common.seed)
         self.__parse_config(config)
 
+        self._device: torch.device = torch.device(self._dataset_setup.common.device or "cpu")
+
         self._scaler = (
             torch.amp.GradScaler(self._device.type)  # type: ignore
             if self._dataset_setup.is_training and self._dataset_setup.common.use_amp
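Note: the GradScaler constructed here is one half of the AMP setup; the other half is the torch.amp.autocast block used in the predict and train/validation loops further down. A minimal sketch of how the two cooperate (hypothetical model, optimizer and loader names, not from this PR):

    scaler = torch.amp.GradScaler("cuda")
    for inputs, targets in loader:
        optimizer.zero_grad()
        with torch.amp.autocast("cuda"):
            loss = criterion(model(inputs), targets)  # forward runs in mixed precision
        scaler.scale(loss).backward()   # scale up to avoid fp16 gradient underflow
        scaler.step(optimizer)          # unscales first; skips the step on inf/NaN
        scaler.update()                 # adapts the scale factor for the next step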
@@ -153,7 +155,7 @@ class CellSegmentator:
             # Train dataloader
             train_dataset = self.__get_dataset(
                 images_dir=os.path.join(train_dir, 'images'),
-                masks_dir=os.path.join(train_dir, 'masks'),
+                masks_dir=os.path.join(train_dir, 'masks', self._dataset_setup.common.masks_subdir),
                 transforms=train_transforms,  # type: ignore
                 size=self._dataset_setup.training.train_size,
                 offset=train_offset,
@@ -168,7 +170,7 @@ class CellSegmentator:
                 raise RuntimeError("Validation directory or size is not properly configured.")
             valid_dataset = self.__get_dataset(
                 images_dir=os.path.join(valid_dir, 'images'),
-                masks_dir=os.path.join(valid_dir, 'masks'),
+                masks_dir=os.path.join(valid_dir, 'masks', self._dataset_setup.common.masks_subdir),
                 transforms=valid_transforms,
                 size=self._dataset_setup.training.valid_size,
                 offset=valid_offset,
@@ -183,7 +185,7 @@ class CellSegmentator:
                 raise RuntimeError("Test directory or size is not properly configured.")
             test_dataset = self.__get_dataset(
                 images_dir=os.path.join(test_dir, 'images'),
-                masks_dir=os.path.join(test_dir, 'masks'),
+                masks_dir=os.path.join(test_dir, 'masks', self._dataset_setup.common.masks_subdir),
                 transforms=test_transforms,
                 size=self._dataset_setup.training.test_size,
                 offset=test_offset,
@@ -210,7 +212,7 @@ class CellSegmentator:
         else:
             # Inference mode (no training)
             test_images = os.path.join(self._dataset_setup.testing.test_dir, 'images')
-            test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks')
+            test_masks = os.path.join(self._dataset_setup.testing.test_dir, 'masks', self._dataset_setup.common.masks_subdir)
 
             if test_transforms is not None:
                 test_dataset = self.__get_dataset(
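All four dataset call sites above now resolve masks one level deeper via the new common.masks_subdir setting; for example, with a hypothetical masks_subdir of "nuclei":

    os.path.join("dataset/test", "masks", "nuclei")  # -> "dataset/test/masks/nuclei"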
@@ -385,7 +387,7 @@ class CellSegmentator:
         batch_counter = 0
         for batch in tqdm(self._predict_dataloader, desc="Predicting"):
             # Move input images to the configured device (CPU/GPU)
-            inputs = batch["img"].to(self._device)
+            inputs = batch["image"].to(self._device)
 
             # Use automatic mixed precision if enabled in dataset setup
             with torch.amp.autocast(  # type: ignore
@@ -443,15 +445,40 @@ class CellSegmentator:
 
     def load_from_checkpoint(self, checkpoint_path: str) -> None:
         """
-        Loads model weights from a specified checkpoint into the current model.
+        Loads model weights from a specified checkpoint into the current model,
+        but only for parameters whose shapes match. Parameters with mismatched
+        shapes (e.g., classification heads with different output sizes) remain
+        at their initialized values.
 
         Args:
            checkpoint_path (str): Path to the checkpoint file containing the model weights.
         """
-        # Load the checkpoint onto the correct device (CPU or GPU)
-        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=True)
-
-        # Load the state dict into the model, allowing for missing keys
-        self._model.load_state_dict(checkpoint['state_dict'], strict=False)
+        # Load the checkpoint (state_dict) from file onto CPU
+        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+        # Extract nested state_dict if present
+        state_dict = checkpoint.get("state_dict", checkpoint)
+
+        # Get the current model's parameter dictionary
+        model_dict = self._model.state_dict()
+
+        # Filter pretrained parameters to those matching in name and shape
+        pretrained_dict = {
+            k: v for k, v in state_dict.items()
+            if k in model_dict and v.size() == model_dict[k].size()
+        }
+
+        # Log how many parameters are loaded, skipped, or missing
+        skipped = [k for k in state_dict if k not in pretrained_dict]
+        missing = [k for k in model_dict if k not in pretrained_dict]
+        logger.info(
+            f"Loaded {len(pretrained_dict)} parameters;"
+            f" skipped {len(skipped)} params from checkpoint;"
+            f" {len(missing)} params remain uninitialized in model."
+        )
+
+        # Update the model's state_dict and load it
+        model_dict.update(pretrained_dict)
+        self._model.load_state_dict(model_dict)
 
 
     def save_checkpoint(self, checkpoint_path: str) -> None:
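The shape-filtering idiom above, reduced to a standalone sketch with two hypothetical heads of different widths. Note also that checkpoint.get("state_dict", checkpoint) is what keeps both the old nested checkpoint layout and the bare state_dict written by the new save_checkpoint below loadable:

    import torch
    from torch import nn

    model = nn.Linear(16, 4)                              # current head: 4 outputs
    ckpt = {"state_dict": nn.Linear(16, 2).state_dict()}  # checkpoint head: 2 outputs
    state_dict = ckpt.get("state_dict", ckpt)             # unwrap if nested
    own = model.state_dict()
    compatible = {k: v for k, v in state_dict.items()
                  if k in own and v.size() == own[k].size()}  # both entries dropped here
    own.update(compatible)
    model.load_state_dict(own)  # mismatched head keeps its fresh initialization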
@@ -461,12 +488,9 @@ class CellSegmentator:
         Args:
             checkpoint_path (str): Path where the checkpoint file will be saved.
         """
-        # Create a checkpoint dictionary containing the model’s state_dict
-        checkpoint = {
-            'state_dict': self._model.state_dict()
-        }
-        # Write the checkpoint to disk
-        torch.save(checkpoint, checkpoint_path)
+        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
+        torch.save(self._model.state_dict(), checkpoint_path)
 
 
     def __parse_config(self, config: Config) -> None:
@@ -492,15 +516,23 @@ class CellSegmentator:
         if scheduler:
             logger.info("Scheduler Config:\n%s", pformat(scheduler.dump(), indent=2))
         logger.info("Dataset Config:\n%s", pformat(config.dataset_config.model_dump(), indent=2))
+        logger.info("Wandb Config:\n%s", pformat(config.wandb_config.model_dump(), indent=2))
         logger.info("==========================================")
 
         # Initialize model using the model registry
         self._model = ModelRegistry.get_model_class(model.name)(model.params)
 
         # Loads model weights from a specified checkpoint
-        if config.dataset_config.is_training:
-            if config.dataset_config.training.pretrained_weights:
-                self.load_from_checkpoint(config.dataset_config.training.pretrained_weights)
+        pretrained_weights = (
+            config.dataset_config.training.pretrained_weights
+            if config.dataset_config.is_training
+            else config.dataset_config.testing.pretrained_weights
+        )
+        if pretrained_weights:
+            self.load_from_checkpoint(pretrained_weights)
+            logger.info(f"Loaded pre-trained weights from: {pretrained_weights}")
 
         self._model = self._model.to(self._device)
 
         # Initialize loss criterion if specified
         self._criterion = (
@@ -555,6 +587,7 @@ class CellSegmentator:
         logger.info(f"├─ Seed: {common.seed}")
         logger.info(f"├─ Device: {common.device}")
         logger.info(f"├─ Use AMP: {'yes' if common.use_amp else 'no'}")
+        logger.info(f"├─ Masks subdirectory: {common.masks_subdir}")
         logger.info(f"└─ Predictions output dir: {common.predictions_dir}")
 
         if config.dataset_config.is_training:
@@ -592,18 +625,21 @@ class CellSegmentator:
             logger.info(f"  ├─ Ensemble model 1: {testing.ensemble_pretrained_weights1}")
             logger.info(f"  └─ Ensemble model 2: {testing.ensemble_pretrained_weights2}")
 
-        wandb_cfg = config.dataset_config.wandb
-        if wandb_cfg.use_wandb:
+        self._wandb_config = config.wandb_config
+        if self._wandb_config.use_wandb:
             logger.info("[W&B]")
-            logger.info(f"├─ Project: {wandb_cfg.project}")
-            logger.info(f"├─ Entity: {wandb_cfg.entity}")
-            if wandb_cfg.name:
-                logger.info(f"├─ Run name: {wandb_cfg.name}")
-            if wandb_cfg.tags:
-                logger.info(f"├─ Tags: {', '.join(wandb_cfg.tags)}")
-            if wandb_cfg.notes:
-                logger.info(f"├─ Notes: {wandb_cfg.notes}")
-            logger.info(f"└─ Save code: {'yes' if wandb_cfg.save_code else 'no'}")
+            logger.info(f"├─ Project: {self._wandb_config.project}")
+            if self._wandb_config.group:
+                logger.info(f"├─ Group: {self._wandb_config.group}")
+            if self._wandb_config.entity:
+                logger.info(f"├─ Entity: {self._wandb_config.entity}")
+            if self._wandb_config.name:
+                logger.info(f"├─ Run name: {self._wandb_config.name}")
+            if self._wandb_config.tags:
+                logger.info(f"├─ Tags: {', '.join(self._wandb_config.tags)}")
+            if self._wandb_config.notes:
+                logger.info(f"├─ Notes: {self._wandb_config.notes}")
+            logger.info(f"└─ Save code: {'yes' if self._wandb_config.save_code else 'no'}")
         else:
             logger.info("[W&B] Logging disabled")
 
@@ -657,13 +693,19 @@ class CellSegmentator:
             ValueError: If dataset is too small for requested size or offset.
         """
         # Collect sorted list of image paths
-        images = sorted(glob.glob(images_dir))
+        images = sorted(
+            glob.glob(os.path.join(images_dir, '*.tif')) +
+            glob.glob(os.path.join(images_dir, '*.tiff'))
+        )
         if not images:
             raise FileNotFoundError(f"No images found in path or pattern: '{images_dir}'")
 
         if masks_dir is not None:
             # Collect and validate sorted list of mask paths
-            masks = sorted(glob.glob(masks_dir))
+            masks = sorted(
+                glob.glob(os.path.join(masks_dir, '*.tif')) +
+                glob.glob(os.path.join(masks_dir, '*.tiff'))
+            )
             if len(images) != len(masks):
                 raise ValueError(f"Number of masks ({len(masks)}) does not match number of images ({len(images)})")
 
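Because pairing is positional (sorted image list against sorted mask list), the length check cannot catch a single misnamed file. A stricter check, not part of this PR and assuming masks reuse their image's basename, could compare the names directly:

    image_names = [os.path.splitext(os.path.basename(p))[0] for p in images]
    mask_names = [os.path.splitext(os.path.basename(p))[0] for p in masks]
    if image_names != mask_names:
        raise ValueError("Image and mask basenames do not line up")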
@@ -720,7 +762,7 @@ class CellSegmentator:
                 tablefmt="fancy_grid"
             )
             print(table, "\n")
-        if self._dataset_setup.wandb.use_wandb:
+        if self._wandb_config.use_wandb:
             wandb.log(results, step=step)
 
@@ -765,8 +807,8 @@ class CellSegmentator:
         # Iterate over batches
         batch_counter = 0
         for batch in tqdm(loader, desc=desc):
-            inputs = batch["img"].to(self._device)
-            targets = batch["label"].to(self._device)
+            inputs = batch["image"].to(self._device)
+            targets = batch["mask"].to(self._device)
 
             # Zero gradients for training
             if self._optimizer is not None:
@@ -787,7 +829,10 @@ class CellSegmentator:
                 flow_targets = self.__compute_flows_from_masks(targets)
 
             # Compute loss for this batch
-            batch_loss = self._criterion(raw_output, flow_targets)  # type: ignore
+            batch_loss = self._criterion(
+                raw_output,
+                torch.from_numpy(flow_targets).to(device=raw_output.device)
+            )
 
             # Post-process and compute F1 during validation and testing
             if mode in ("valid", "test"):
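__compute_flows_from_masks returns a NumPy array (see the flows hunk below), so the target has to be materialized as a tensor on the logits' device before the loss; torch.from_numpy is a zero-copy view and only the .to(...) call moves data:

    flow_t = torch.from_numpy(flow_targets)        # shares the ndarray's buffer
    flow_t = flow_t.to(device=raw_output.device)   # single host-to-device copy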
@@ -842,6 +887,9 @@ class CellSegmentator:
             epoch_metrics[f"{mode}_f1_score"] = self.__compute_f1_metric(  # type: ignore
                 tp_array, fp_array, fn_array, reduction="micro"
             )
+            epoch_metrics[f"{mode}_f1_score_iw"] = self.__compute_f1_metric(  # type: ignore
+                tp_array, fp_array, fn_array, reduction="imagewise"
+            )
             epoch_metrics[f"{mode}_mAP"] = self.__compute_average_precision_metric(  # type: ignore
                 tp_array, fp_array, fn_array, reduction="macro"
             )
@@ -914,7 +962,7 @@ class CellSegmentator:
                 instance_masks[idx] = self.__segment_instances(
                     probability_map=probabilities[idx],
                     flow=gradflow[idx],
-                    prob_threshold=0.0,
+                    prob_threshold=0.5,
                     flow_threshold=0.4,
                     min_object_size=15
                 )
@@ -1159,7 +1207,7 @@ class CellSegmentator:
         Returns:
             np.ndarray: Sigmoid of the input.
         """
-        return 1 / (1 + np.exp(-z))
+        return expit(z)
 
 
     def __save_prediction_masks(
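expit is the numerically safe form of the same function: the handwritten expression raises an overflow RuntimeWarning once np.exp(-z) exceeds the float64 range, while scipy.special.expit evaluates cleanly everywhere:

    import numpy as np
    from scipy.special import expit

    z = np.array([-1000.0, 0.0, 1000.0])
    expit(z)                  # -> array([0. , 0.5, 1. ]), no warnings
    # 1 / (1 + np.exp(-z))    # RuntimeWarning: overflow encountered in exp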
@@ -1191,7 +1239,7 @@ class CellSegmentator:
         # Convert tensors to numpy
         def to_numpy(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
             if isinstance(x, torch.Tensor):
                 return x.detach().cpu().numpy()
-            return x.cpu().numpy()
+            return x
 
         image_array = to_numpy(image_obj) if image_obj is not None else None
@@ -1201,11 +1249,11 @@ class CellSegmentator:
         # Handle batch dimension: (B, C, H, W)
         if pred_array.ndim == 4:
             for idx in range(pred_array.shape[0]):
-                batch_sample = dict(sample)
+                batch_sample: Dict[str, Any] = {}
                 if image_array is not None and image_array.ndim == 4:
                     batch_sample["image"] = image_array[idx]
-                    if isinstance(image_meta, list):
-                        batch_sample["image_meta_dict"] = image_meta[idx]
+                    if isinstance(image_meta, dict) and "filename_or_obj" in image_meta:
+                        batch_sample["image_meta_dict"] = image_meta["filename_or_obj"][idx]
                 if mask_array is not None and mask_array.ndim == 4:
                     batch_sample["mask"] = mask_array[idx]
                 self.__save_prediction_masks(
@@ -1216,8 +1264,8 @@ class CellSegmentator:
             return
 
         # Determine base filename
-        if image_meta and "filename_or_obj" in image_meta:
-            base_name = os.path.splitext(os.path.basename(image_meta["filename_or_obj"]))[0]
+        if isinstance(image_meta, (str, os.PathLike)):
+            base_name = os.path.splitext(os.path.basename(image_meta))[0]
         else:
             # Use provided start_index when metadata missing
             base_name = f"prediction_{start_index:04d}"
@@ -1228,8 +1276,8 @@ class CellSegmentator:
             channel_mask = pred_array[channel_idx]
 
             # File names
-            mask_filename = f"{base_name}_ch{channel_idx:02d}.tif"
-            plot_filename = f"{base_name}_ch{channel_idx:02d}.png"
+            mask_filename = f"{base_name}_ch{channel_idx:01d}.tif"
+            plot_filename = f"{base_name}_ch{channel_idx:01d}.png"
             mask_path = os.path.join(masks_dir, mask_filename)
             plot_path = os.path.join(plots_dir, plot_filename)
 
@@ -1402,78 +1450,81 @@ class CellSegmentator:
         flows = np.zeros((2*channels, height, width), np.float32)
 
         for channel in range(channels):
-            padded_height, padded_width = height + 2, width + 2
-
-            # Pad the mask with a 1-pixel border
-            masks_padded = torch.from_numpy(mask.astype(np.int64)).to(self._device)
-            masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
-
-            # Get coordinates of all non-zero pixels in the padded mask
-            y, x = torch.nonzero(masks_padded, as_tuple=True)
-            y = y.int(); x = x.int()  # ensure integer type
-
-            # Generate 8-connected neighbors (including center) via broadcasted offsets
-            offsets = torch.tensor([
-                [ 0,  0],  # center
-                [-1,  0],  # up
-                [ 1,  0],  # down
-                [ 0, -1],  # left
-                [ 0,  1],  # right
-                [-1, -1],  # up-left
-                [-1,  1],  # up-right
-                [ 1, -1],  # down-left
-                [ 1,  1],  # down-right
-            ], dtype=torch.int32, device=self._device)  # (9, 2)
-
-            # coords: (N, 2)
-            coords = torch.stack((y, x), dim=1)
-
-            # neighbors: (9, N, 2)
-            neighbors = offsets[:, None, :] + coords[None, :, :]
-
-            # transpose into (2, 9, N) for the GPU kernel
-            neighbors = neighbors.permute(2, 0, 1)  # first dim is y/x, second is neighbor index
-
-            # Build connectivity mask: True where neighbor label == center label
-            center_labels = masks_padded[y, x][None, :]                 # (1, N)
-            neighbor_labels = masks_padded[neighbors[0], neighbors[1]]  # (9, N)
-            is_neighbor = neighbor_labels == center_labels              # (9, N)
-
-            # Compute object slices and pack into array for get_centers
-            slices = find_objects(mask)
-            slices_arr = np.array([
-                [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
-                for i, sl in enumerate(slices) if sl is not None
-            ], dtype=int)
-
-            # Compute centers (pixel indices) and extents via the provided helper
-            centers, ext = self.__get_mask_centers_and_extents(mask, slices_arr)
-            # Move centers to GPU and shift by +1 for padding
-            meds_p = torch.from_numpy(centers).to(self._device).long() + 1  # (M, 2); +1 for padding
-
-            # Determine number of diffusion iterations
-            n_iter = 2 * ext.max()
-
-            # Run the GPU diffusion kernel
-            mu = self.__propagate_centers_gpu(
-                neighbor_indices=neighbors,
-                center_indices=meds_p.T,
-                valid_neighbor_mask=is_neighbor,
-                output_shape=(padded_height, padded_width),
-                num_iterations=n_iter
-            )
-
-            # Cast to float64 and normalize flow vectors
-            mu = mu.astype(np.float64)
-            mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60
-
-            # Remove the padding and write into final output
-            flow_output = np.zeros((2, height, width), dtype=np.float32)
-            ys_np = y.cpu().numpy() - 1
-            xs_np = x.cpu().numpy() - 1
-            flow_output[:, ys_np, xs_np] = mu
-            flows[2*channel: 2*channel + 2] = flow_output
+            mask_channel = mask[channel]
+
+            if mask_channel.max() > 0:
+                padded_height, padded_width = height + 2, width + 2
+
+                # Pad the mask with a 1-pixel border
+                masks_padded = torch.from_numpy(mask_channel.astype(np.int64)).to(self._device)
+                masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
+
+                # Get coordinates of all non-zero pixels in the padded mask
+                y, x = torch.nonzero(masks_padded, as_tuple=True)
+                y = y.int(); x = x.int()  # ensure integer type
+
+                # Generate 8-connected neighbors (including center) via broadcasted offsets
+                offsets = torch.tensor([
+                    [ 0,  0],  # center
+                    [-1,  0],  # up
+                    [ 1,  0],  # down
+                    [ 0, -1],  # left
+                    [ 0,  1],  # right
+                    [-1, -1],  # up-left
+                    [-1,  1],  # up-right
+                    [ 1, -1],  # down-left
+                    [ 1,  1],  # down-right
+                ], dtype=torch.int32, device=self._device)  # (9, 2)
+
+                # coords: (N, 2)
+                coords = torch.stack((y, x), dim=1)
+
+                # neighbors: (9, N, 2)
+                neighbors = offsets[:, None, :] + coords[None, :, :]
+
+                # transpose into (2, 9, N) for the GPU kernel
+                neighbors = neighbors.permute(2, 0, 1)  # first dim is y/x, second is neighbor index
+
+                # Build connectivity mask: True where neighbor label == center label
+                center_labels = masks_padded[y, x][None, :]                 # (1, N)
+                neighbor_labels = masks_padded[neighbors[0], neighbors[1]]  # (9, N)
+                is_neighbor = neighbor_labels == center_labels              # (9, N)
+
+                # Compute object slices and pack into array for get_centers
+                slices = find_objects(mask_channel)
+                slices_arr = np.array([
+                    [i, sl[0].start, sl[0].stop, sl[1].start, sl[1].stop]
+                    for i, sl in enumerate(slices, start=1) if sl is not None
+                ], dtype=np.int16)
+
+                # Compute centers (pixel indices) and extents via the provided helper
+                centers, ext = self.__get_mask_centers_and_extents(mask_channel, slices_arr)
+                # Move centers to GPU and shift by +1 for padding
+                meds_p = torch.from_numpy(centers).to(self._device).long() + 1  # (M, 2); +1 for padding
+
+                # Determine number of diffusion iterations
+                n_iter = 2 * ext.max()
+
+                # Run the GPU diffusion kernel
+                mu = self.__propagate_centers_gpu(
+                    neighbor_indices=neighbors,
+                    center_indices=meds_p.T,
+                    valid_neighbor_mask=is_neighbor,
+                    output_shape=(padded_height, padded_width),
+                    num_iterations=n_iter
+                )
+
+                # Cast to float64 and normalize flow vectors
+                mu = mu.astype(np.float64)
+                mu /= np.sqrt((mu**2).sum(axis=0)) + 1e-60
+
+                # Remove the padding and write into final output
+                flow_output = np.zeros((2, height, width), dtype=np.float32)
+                ys_np = y.cpu().numpy() - 1
+                xs_np = x.cpu().numpy() - 1
+                flow_output[:, ys_np, xs_np] = mu
+                flows[2*channel: 2*channel + 2] = flow_output
 
         return flows
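Three fixes land in this hunk: flows are computed on mask[channel] rather than the whole stack, channels without labels skip the diffusion kernel entirely (the if mask_channel.max() > 0 guard), and enumerate(slices, start=1) appears to repair an off-by-one, since scipy's find_objects stores the slice for label i at index i - 1:

    import numpy as np
    from scipy.ndimage import find_objects

    mask = np.array([[1, 1, 0],
                     [0, 0, 2]])
    slices = find_objects(mask)               # slices[0] -> label 1, slices[1] -> label 2
    for label, sl in enumerate(slices, start=1):
        print(label, mask[sl])                # each bounding box paired with its true label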
@@ -1624,8 +1675,10 @@ class CellSegmentator:
         # Initialize flow_field with two channels: dy and dx
         flow_field = np.zeros((2, height, width), dtype=np.float64)
 
+        mask_channel = mask[channel]
+
         # Find bounding box for each labeled mask
-        mask_slices = find_objects(mask)
+        mask_slices = find_objects(mask_channel)
         # centers: List[Tuple[int, int]] = []
 
         # Iterate over mask labels in parallel
@@ -1642,7 +1695,7 @@ class CellSegmentator:
 
             # Get local coordinates of mask pixels within the patch
             local_rows, local_cols = np.nonzero(
-                mask[row_slice, col_slice] == (label_idx + 1)
+                mask_channel[row_slice, col_slice] == (label_idx + 1)
             )
             # Shift coords by +1 for the border padding
             local_rows = local_rows.astype(np.int32) + 1
@@ -1774,8 +1827,8 @@ class CellSegmentator:
         Generate instance segmentation masks from probability and flow fields.
 
         Args:
-            probability_map: 3D array (channels, height, width) of cell probabilities.
-            flow: 3D array (2*channels, height, width) of forward flow vectors.
+            probability_map: 3D array `(C, H, W)` of cell probabilities.
+            flow: 3D array `(2*C, H, W)` of forward flow vectors.
             prob_threshold: threshold to binarize probability_map. (Default 0.0)
             flow_threshold: threshold for filtering bad flow masks. (Default 0.4)
             num_iters: number of iterations for flow-following. (Default 200)
@@ -1802,6 +1855,9 @@ class CellSegmentator:
             # Extract binary mask for this channel
             channel_mask = probability_mask[channel_index]
 
+            if not channel_mask.sum():
+                continue
+
             nonzero_coords = np.stack(np.nonzero(channel_mask))
 
             # Follow the flow vectors to generate coordinate mappings
@@ -1810,12 +1866,6 @@ class CellSegmentator:
                 initial_coords=nonzero_coords,
                 num_iters=num_iters
             )
-            # If flow following fails, leave this channel empty
-            if flow_coordinates is None:
-                labeled_instances[channel_index] = np.zeros(
-                    probability_map.shape[1:], dtype=np.uint16
-                )
-                continue
 
             if not torch.is_tensor(flow_coordinates):
                 flow_coordinates = torch.from_numpy(
@@ -1851,11 +1901,6 @@ class CellSegmentator:
                 )
 
                 labeled_instances[channel_index] = channel_instances_mask
-            else:
-                # No valid instances found, leave the channel empty
-                labeled_instances[channel_index] = np.zeros(
-                    probability_map.shape[1:], dtype=np.uint16
-                )
 
         return labeled_instances
 
@@ -1923,7 +1968,7 @@ class CellSegmentator:
             )
             # Update each coordinate and clamp to valid range
             for i in range(dims):
-                pts[..., i] = torch.clamp(pts[..., i] + sampled[0, i], -1.0, 1.0)
+                pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0)
 
             # Denormalize back to original pixel coordinates
             pts = (pts + 1) * 0.5 * max_idx_pt
@@ -2072,16 +2117,21 @@ class CellSegmentator:
             raise
 
         # Step 3: Find peaks via 5x5 max-pooling
-        k = 5
-        pooled = F.max_pool2d(
-            histogram.float().unsqueeze(0).unsqueeze(1),
-            kernel_size=k,
-            stride=1,
-            padding=k // 2
+        # k = 5
+        # pooled = F.max_pool2d(
+        #     histogram.float().unsqueeze(0).unsqueeze(1),
+        #     kernel_size=k,
+        #     stride=1,
+        #     padding=k // 2
+        # ).squeeze()
+        pooled = self.__max_pool_nd(
+            histogram.unsqueeze(0),
+            kernel_size=5
         ).squeeze()
 
         # Seeds are positions where histogram equals local max and count > threshold
-        seed_positions = torch.nonzero((histogram - pooled == 0) & (histogram > 10))
-        if seed_positions.numel() == 0:
+        seed_positions = torch.nonzero((histogram - pooled == 0) * (histogram > 10))
+        if seed_positions.shape[0] == 0:
             logger.warning("No seeds found: returning empty mask")
             return np.zeros(original_shape, dtype=np.uint16)
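The seed test itself is unchanged: a bin is a peak when it equals the max of its 5x5 neighborhood and holds more than 10 hits; on bool tensors the elementwise product `*` is equivalent to `&`. A reduced 3x3 sketch of the pattern:

    import torch
    import torch.nn.functional as F

    h = torch.tensor([[0., 1., 0.],
                      [1., 5., 1.],
                      [0., 1., 0.]])
    pooled = F.max_pool2d(h[None, None], 3, stride=1, padding=1)[0, 0]
    peaks = torch.nonzero((h - pooled == 0) * (h > 2))   # -> tensor([[1, 1]])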
|
|
|
|
|
@ -2106,13 +2156,14 @@ class CellSegmentator:
|
|
|
|
|
seed_masks[:, 5, 5] = 1
|
|
|
|
|
# Iterative dilation and thresholding
|
|
|
|
|
for _ in range(5):
|
|
|
|
|
seed_masks = F.max_pool2d(
|
|
|
|
|
seed_masks,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=1
|
|
|
|
|
)
|
|
|
|
|
seed_masks = seed_masks & (patches > 2)
|
|
|
|
|
# seed_masks = F.max_pool2d(
|
|
|
|
|
# seed_masks.float().unsqueeze(0),
|
|
|
|
|
# kernel_size=3,
|
|
|
|
|
# stride=1,
|
|
|
|
|
# padding=1
|
|
|
|
|
# ).squeeze(0).int()
|
|
|
|
|
seed_masks = self.__max_pool_nd(seed_masks, kernel_size=3)
|
|
|
|
|
seed_masks *= (patches > 2)
|
|
|
|
|
# Compute final mask coordinates
|
|
|
|
|
final_coords = []
|
|
|
|
|
for idx in range(num_seeds):
|
|
|
|
@@ -2133,7 +2184,7 @@ class CellSegmentator:
 
         # Step 6: Map to original image and remove oversized masks
         mask_final = np.zeros(original_shape, dtype=np.uint16 if num_seeds < 2**16 else np.uint32)
-        mask_final[valid_indices] = mask_values
+        mask_final[tuple(valid_indices)] = mask_values
 
         # Prune masks that are too large
         labels, counts = fastremap.unique(mask_final, return_counts=True)
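The tuple(...) wrapper is the actual fix here: with valid_indices stacked as a (2, N) coordinate array, which is what the surrounding code appears to build, mask_final[valid_indices] fancy-indexes axis 0 only (selecting whole rows), whereas tuple(valid_indices) indexes (row, col) pairs:

    import numpy as np

    a = np.zeros((3, 3))
    idx = np.array([[0, 1],    # rows
                    [2, 0]])   # cols
    a[tuple(idx)] = 1          # sets a[0, 2] and a[1, 0]
    # a[idx] would instead gather whole rows along axis 0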
@@ -2146,6 +2197,96 @@ class CellSegmentator:
         return mask_final
 
 
+    def __max_pool1d(
+        self,
+        input_tensor: Tensor,
+        kernel_size: int = 5,
+        axis: int = 1,
+        output_tensor: Optional[Tensor] = None
+    ) -> Tensor:
+        """
+        Memory-efficient 1D max pooling along a specified axis using in-place updates.
+
+        Requires:
+            - stride = 1
+            - padding = kernel_size // 2
+            - odd kernel_size >= 3
+
+        Args:
+            input_tensor (Tensor): Source tensor for pooling.
+            kernel_size (int): Size of the pooling window (must be odd and >= 3).
+            axis (int): Axis along which to compute 1D max pooling.
+            output_tensor (Optional[Tensor]): Tensor to store the result.
+                If None, a clone of input_tensor is used.
+
+        Returns:
+            Tensor: The pooled tensor, same shape as input_tensor.
+        """
+        # Initialize or copy data into the output tensor
+        if output_tensor is None:
+            output = input_tensor.clone()
+        else:
+            output = output_tensor
+            output.copy_(input_tensor)
+
+        # Number of elements along the chosen axis and half-window size
+        dimension_size = input_tensor.shape[axis]
+        half_window = kernel_size // 2
+
+        # Slide window offsets from -half_window to +half_window
+        for offset in range(-half_window, half_window + 1):
+            # Compute slice indices depending on axis
+            if axis == 1:
+                target_slice = output[:, max(-offset, 0): min(dimension_size - offset, dimension_size)]
+                source_slice = input_tensor[:, max(offset, 0): min(dimension_size + offset, dimension_size)]
+            elif axis == 2:
+                target_slice = output[:, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
+                source_slice = input_tensor[:, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
+            elif axis == 3:
+                target_slice = output[:, :, :, max(-offset, 0): min(dimension_size - offset, dimension_size)]
+                source_slice = input_tensor[:, :, :, max(offset, 0): min(dimension_size + offset, dimension_size)]
+            else:
+                raise ValueError(f"Unsupported axis {axis} for 1D pooling")
+
+            # In-place element-wise maximum
+            torch.maximum(target_slice, source_slice, out=target_slice)
+
+        return output
+
+
+    def __max_pool_nd(
+        self,
+        input_tensor: Tensor,
+        kernel_size: int = 5
+    ) -> Tensor:
+        """
+        Memory-efficient N-dimensional max pooling for 2D or 3D spatial data.
+        Applies 1D max pooling sequentially over each spatial axis.
+
+        Args:
+            input_tensor (Tensor): Input tensor with shape
+                (batch_size, dim1, dim2, ..., dimN).
+            kernel_size (int): Size of the pooling window (must be odd and >= 3).
+
+        Returns:
+            Tensor: The pooled tensor, same shape as input_tensor.
+        """
+        # Determine number of spatial dimensions (excluding batch axis)
+        num_spatial_dims = input_tensor.ndim - 1
+
+        # First pass: pool along axis=1
+        pooled = self.__max_pool1d(input_tensor, kernel_size=kernel_size, axis=1)
+        # Second pass: pool along axis=2
+        pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=2)
+
+        # If 3D data, apply a third pass along axis=3
+        if num_spatial_dims == 3:
+            pooled = self.__max_pool1d(pooled, kernel_size=kernel_size, axis=3)
+        elif num_spatial_dims != 2:
+            raise ValueError("max_pool_nd only supports 2D or 3D spatial data")
+
+        return pooled
+
+
     def __remove_inconsistent_flow_masks(
         self,
         mask: np.ndarray,
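The shifted-slice loop works because stride-1 max pooling is separable: a k x k window max equals a k x 1 pass followed by a 1 x k pass, and each 1D pass reduces to k in-place torch.maximum updates over shifted views, avoiding the buffers F.max_pool2d may allocate for a large histogram. A quick check of the separability assumption:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 64, 64)
    ref = F.max_pool2d(x, kernel_size=5, stride=1, padding=2)
    sep = F.max_pool2d(x, kernel_size=(5, 1), stride=1, padding=(2, 0))
    sep = F.max_pool2d(sep, kernel_size=(1, 5), stride=1, padding=(0, 2))
    assert torch.equal(ref, sep)  # row pass then column pass matches the 2D pool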
@@ -2175,11 +2316,7 @@ class CellSegmentator:
         """
         # If mask is very large and running on CUDA, check memory
         num_pixels = mask.size
-        if (
-            num_pixels > 10000 * 10000
-            and self._device.type == 'cuda'
-        ):
+        if num_pixels > 10000 * 10000 and self._device.type == 'cuda':
             # Clear unused GPU cache
             torch.cuda.empty_cache()
             # Determine PyTorch version
@@ -2296,15 +2433,8 @@ class CellSegmentator:
             logger.error(msg)
             raise ValueError(msg)
 
-        # Optionally remove masks smaller than minimum_size
-        if minimum_size >= 0:
-            # Compute label counts (skipping background at index 0)
-            labels, counts = fastremap.unique(masks, return_counts=True)
-            # Identify labels to remove: those with count < minimum_size
-            small_labels = labels[counts < minimum_size]
-            if small_labels.size > 0:
-                masks = fastremap.mask(masks, small_labels)
-                fastremap.renumber(masks, in_place=True)
+        # Initial pruning of too-small masks
+        masks = self._prune_small_masks(masks, minimum_size)
 
         # Find bounding boxes for each mask label
         object_slices = find_objects(masks)
@@ -2325,12 +2455,36 @@ class CellSegmentator:
             output_masks[slc][filled_region] = new_label
             new_label += 1
 
-        # Final pruning of small masks after filling (optional)
-        if minimum_size >= 0:
-            labels, counts = fastremap.unique(output_masks, return_counts=True)
-            small_labels = labels[counts < minimum_size]
-            if small_labels.size > 0:
-                output_masks = fastremap.mask(output_masks, small_labels)
-                fastremap.renumber(output_masks, in_place=True)
+        # Final pruning after hole filling
+        output_masks = self._prune_small_masks(output_masks, minimum_size)
 
         return output_masks
 
 
+    def _prune_small_masks(
+        self,
+        masks: np.ndarray,
+        minimum_size: int
+    ) -> np.ndarray:
+        """
+        Remove labeled regions in `masks` whose pixel count is below `minimum_size`.
+
+        Args:
+            masks (np.ndarray): Integer mask array (any shape), 0=background.
+            minimum_size (int): Minimum pixel count; labels smaller are removed. If <0, skip pruning.
+
+        Returns:
+            np.ndarray: Mask array with small labels suppressed and labels renumbered.
+        """
+        if minimum_size < 0:
+            return masks
+
+        labels, counts = fastremap.unique(masks, return_counts=True)
+        # Skip background label at index 0
+        non_bg_labels = labels[1:]
+        non_bg_counts = counts[1:]
+        # Identify labels to remove
+        small_labels = non_bg_labels[non_bg_counts < minimum_size]
+        if small_labels.size > 0:
+            masks = fastremap.mask(masks, small_labels)
+            fastremap.renumber(masks, in_place=True)
+        return masks
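For reference, the fastremap sequence inside the new helper behaves like this on a toy label image (illustrative values only):

    import numpy as np
    import fastremap

    masks = np.array([[1, 1, 2],
                      [0, 0, 2],
                      [3, 0, 0]], dtype=np.uint16)
    labels, counts = fastremap.unique(masks, return_counts=True)  # [0,1,2,3], [4,2,2,1]
    small = labels[1:][counts[1:] < 2]        # background skipped -> [3]
    masks = fastremap.mask(masks, small)      # zero out label 3
    fastremap.renumber(masks, in_place=True)  # relabel remaining labels to 1..N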