import torch
import numpy as np
from typing import Sequence
from monai.utils.misc import fall_back_tuple
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms import Randomizable, RandCropd, Crop # type: ignore
from core.logger import get_logger
logger = get_logger(__name__)
def _compute_multilabel_bbox(
mask: np.ndarray
) -> tuple[list[int], list[int], list[int], list[int]] | None:
"""
Compute per-channel bounding-box constraints and return lists of limits for each axis.
Args:
mask: multi-channel instance mask of shape (C, H, W).
Returns:
A tuple of four lists:
- top_mins: list of r_max for each non-empty channel
- top_maxs: list of r_min for each non-empty channel
- left_mins: list of c_max for each non-empty channel
- left_maxs: list of c_min for each non-empty channel
Or None if mask contains no positive labels.
"""
channels, rows, cols = np.nonzero(mask)
if channels.size == 0:
return None
top_mins: list[int] = []
top_maxs: list[int] = []
left_mins: list[int] = []
left_maxs: list[int] = []
C = mask.shape[0]
for ch in range(C):
rs, cs = np.nonzero(mask[ch])
if rs.size == 0:
continue
r_min, r_max = int(rs.min()), int(rs.max())
c_min, c_max = int(cs.min()), int(cs.max())
# For each channel, record the row/col extents
top_mins.append(r_max)
top_maxs.append(r_min)
left_mins.append(c_max)
left_maxs.append(c_min)
return top_mins, top_maxs, left_mins, left_maxs
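# Example (illustrative only; the tiny 2-channel 4x4 mask below is an assumption
# made for this sketch, not data from the module):
#   m = np.zeros((2, 4, 4), dtype=np.uint8)
#   m[0, 1, 1] = 1        # channel 0: single pixel at row 1, col 1
#   m[1, 2:4, 0:2] = 1    # channel 1: rows 2-3, cols 0-1
#   _compute_multilabel_bbox(m)
#   -> ([1, 3], [1, 2], [1, 1], [1, 0])
#      i.e. per non-empty channel: the r_max values, r_min values, c_max values, c_min values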
class SpatialCropAllClasses(Randomizable, Crop):
"""
Cropper for multi-label instance masks and images: ensures each label-channel's
instances lie within the crop if possible.
Must be called on a mask tensor first to compute the crop, then on the image.
Args:
roi_size: desired crop size (height, width).
num_candidates: fallback samples when no single crop fits all instances.
lazy: defer actual cropping.
"""
def __init__(
self,
roi_size: Sequence[int],
num_candidates: int = 10,
lazy: bool = False,
) -> None:
super().__init__(lazy=lazy)
self.roi_size = tuple(roi_size)
self.num_candidates = num_candidates
        self._slices: tuple[slice, ...] | None = None
        self._img: np.ndarray | None = None  # mask array captured on the first call
def randomize(self, img_size: Sequence[int]) -> None: # type: ignore
"""
Choose crop offsets so that each non-empty channel is included if possible.
"""
        if self._img is None:
            # Invoked before a mask has been seen (e.g. by a dict wrapper that
            # randomizes up front): clear any previous crop and wait for the mask.
            self._slices = None
            return
        height, width = img_size
        crop_h, crop_w = self.roi_size
        max_top = max(0, height - crop_h)
        max_left = max(0, width - crop_w)
        # Per-channel bounding-box constraints from the mask
        mask = self._img
bboxes = _compute_multilabel_bbox(mask)
if bboxes is None:
# no labels: random patch using MONAI utils
logger.warning("No labels found; using random patch.")
# determine actual patch size (fallback)
self._size = fall_back_tuple(self.roi_size, img_size)
# compute valid size for random patch
valid_size = get_valid_patch_size(img_size, self._size)
# directly get random patch slices
self._slices = get_random_patch(img_size, valid_size, self.R)
return
else:
top_mins, top_maxs, left_mins, left_maxs = bboxes
            # Convert the per-channel extents into an allowable window of offsets
            # global_top_min = max over channels of (r_max - crop_h + 1)
            global_top_min = max(0, max(r_max - crop_h + 1 for r_max in top_mins))
            # global_top_max = min over channels of r_min (clamped to the image)
            global_top_max = min(min(top_maxs), max_top)
            # same derivation for the left offset
            global_left_min = max(0, max(c_max - crop_w + 1 for c_max in left_mins))
            global_left_max = min(min(left_maxs), max_left)
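            # For example (illustrative numbers, not from the original code): with
            # crop_h = 32 and a channel occupying rows 40..59, the top offset must
            # lie in [59 - 32 + 1, 40] = [28, 40] for that channel to fit entirely.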
if global_top_min <= global_top_max and global_left_min <= global_left_max:
# there is a window covering all channels fully
top = self.R.randint(global_top_min, global_top_max + 1)
left = self.R.randint(global_left_min, global_left_max + 1)
else:
# fallback: sample candidates to maximize channel coverage
logger.warning(
f"Cannot fit all instances; sampling {self.num_candidates} candidates."
)
best_cover = -1
best_top = best_left = 0
C = mask.shape[0]
for _ in range(self.num_candidates):
cand_top = self.R.randint(0, max_top + 1)
cand_left = self.R.randint(0, max_left + 1)
window = mask[:, cand_top : cand_top + crop_h, cand_left : cand_left + crop_w]
cover = sum(int(window[ch].any()) for ch in range(C))
if cover > best_cover:
best_cover = cover
best_top, best_left = cand_top, cand_left
logger.info(f"Selected crop covering {best_cover}/{C} channels.")
top, left = best_top, best_left
        # Store spatial slices only, for reuse on both mask and image;
        # Crop.__call__ prepends the channel slice itself, matching the
        # format returned by get_random_patch in the no-label branch.
        self._slices = (
            slice(top, top + crop_h),
            slice(left, left + crop_w),
        )
    def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor:  # type: ignore
        """
        On the first call (the mask) the crop window is computed; later calls
        (the image) reuse it. Raises if a mask is not provided first.
        ``randomize`` is accepted for compatibility with dict wrappers such as
        RandCropd (which may pass ``randomize=False``); randomization is always
        driven by the mask call, so the flag itself is unused.
        """
        # Determine the spatial shape (peek_pending_shape already excludes the channel dim)
        img_size = (
            img.peek_pending_shape()
            if isinstance(img, MetaTensor)
            else img.shape[1:]
        )
        # The first call must be the mask, so the crop window can be computed
        if self._slices is None:
            if img.dtype in (torch.uint8, torch.int16, torch.int32, torch.int64):
                # integer dtype: treat this tensor as the multi-channel mask
                self._img = img.cpu().numpy()
                self.randomize(img_size)
                self._img = None  # release the mask so the next sample starts fresh
            else:
                raise RuntimeError(
                    "Mask tensor must be passed first for computing crop bounds."
                )
# Now apply stored slice
if self._slices is None:
raise RuntimeError("Crop slices not computed; call on mask first.")
lazy_exec = self.lazy if lazy is None else lazy
return super().__call__(img=img, slices=self._slices, lazy=lazy_exec)
class RandSpatialCropAllClassesd(RandCropd):
"""
Dict-based wrapper: applies SpatialCropAllClasses to mask then image.
Requires mask present or raises.
"""
def __init__(
self,
keys: Sequence,
roi_size: Sequence[int],
num_candidates: int = 10,
allow_missing_keys: bool = False,
lazy: bool = False,
):
cropper = SpatialCropAllClasses(
roi_size=roi_size,
num_candidates=num_candidates,
lazy=lazy,
)
super().__init__(
keys=keys,
cropper=cropper,
allow_missing_keys=allow_missing_keys,
lazy=lazy,
)
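# Usage sketch (illustrative only; the shapes, key names and seed below are
# assumptions made for this example, not part of the original module).
# Running the file directly exercises the cropper on synthetic data.
if __name__ == "__main__":
    # Two labelled instances that both fit inside a 32x32 window; channel 2 stays empty.
    mask_np = np.zeros((3, 64, 64), dtype=np.uint8)
    mask_np[0, 5:15, 5:15] = 1
    mask_np[1, 20:30, 18:28] = 1
    mask = torch.from_numpy(mask_np)
    image = torch.rand(1, 64, 64)  # float tensor -> treated as the image
    cropper = SpatialCropAllClasses(roi_size=(32, 32))
    cropper.set_random_state(seed=0)
    cropped_mask = cropper(mask)    # first call: mask, computes the crop window
    cropped_image = cropper(image)  # second call: image, same window applied
    print("mask:", tuple(cropped_mask.shape), "image:", tuple(cropped_image.shape))
    # The dict-based transform is typically composed into a pipeline, e.g.:
    #   xform = RandSpatialCropAllClassesd(keys=("label", "image"), roi_size=(32, 32))
    #   out = xform({"label": mask, "image": image})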