You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
194 lines
7.2 KiB
194 lines
7.2 KiB
import torch
|
|
import numpy as np
|
|
from typing import Sequence
|
|
|
|
from monai.utils.misc import fall_back_tuple
|
|
from monai.data.meta_tensor import MetaTensor
|
|
from monai.data.utils import get_random_patch, get_valid_patch_size
|
|
from monai.transforms import Randomizable, RandCropd, Crop # type: ignore
|
|
|
|
from core.logger import get_logger
|
|
|
|
# Module-level logger obtained from the project's `core.logger` helper.
logger = get_logger(__name__)
|
|
|
|
|
|
def _compute_multilabel_bbox(
|
|
mask: np.ndarray
|
|
) -> tuple[list[int], list[int], list[int], list[int]] | None:
|
|
"""
|
|
Compute per-channel bounding-box constraints and return lists of limits for each axis.
|
|
|
|
Args:
|
|
mask: multi-channel instance mask of shape (C, H, W).
|
|
|
|
Returns:
|
|
A tuple of four lists:
|
|
- top_mins: list of r_max for each non-empty channel
|
|
- top_maxs: list of r_min for each non-empty channel
|
|
- left_mins: list of c_max for each non-empty channel
|
|
- left_maxs: list of c_min for each non-empty channel
|
|
Or None if mask contains no positive labels.
|
|
"""
|
|
channels, rows, cols = np.nonzero(mask)
|
|
if channels.size == 0:
|
|
return None
|
|
|
|
top_mins: list[int] = []
|
|
top_maxs: list[int] = []
|
|
left_mins: list[int] = []
|
|
left_maxs: list[int] = []
|
|
C = mask.shape[0]
|
|
for ch in range(C):
|
|
rs, cs = np.nonzero(mask[ch])
|
|
if rs.size == 0:
|
|
continue
|
|
r_min, r_max = int(rs.min()), int(rs.max())
|
|
c_min, c_max = int(cs.min()), int(cs.max())
|
|
# For each channel, record the row/col extents
|
|
top_mins.append(r_max)
|
|
top_maxs.append(r_min)
|
|
left_mins.append(c_max)
|
|
left_maxs.append(c_min)
|
|
|
|
return top_mins, top_maxs, left_mins, left_maxs
|
|
|
|
|
|
class SpatialCropAllClasses(Randomizable, Crop):
    """
    Cropper for multi-label instance masks and images: ensures each label-channel's
    instances lie within the crop if possible.

    Must be called on a mask tensor first to compute the crop, then on the image.
    The computed slices are cached on the instance and reused by every later
    call; nothing in this class clears them, so set ``_slices`` back to ``None``
    before cropping a different sample.

    Args:
        roi_size: desired crop size (height, width).
        num_candidates: fallback samples when no single crop fits all instances.
        lazy: defer actual cropping.
    """

    def __init__(
        self,
        roi_size: Sequence[int],
        num_candidates: int = 10,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy=lazy)
        self.roi_size = tuple(roi_size)
        self.num_candidates = num_candidates
        # Mask array captured by ``__call__``; ``randomize`` reads it.
        self._img: np.ndarray | None = None
        # Spatial (row, column) slices shared by the mask and the image.
        self._slices: tuple[slice, ...] | None = None

    def randomize(self, img_size: Sequence[int]) -> None:  # type: ignore
        """
        Choose crop offsets so that each non-empty channel is included if possible.

        Args:
            img_size: spatial size (height, width) of the tensor being cropped.

        Raises:
            RuntimeError: if no mask has been stored by ``__call__`` yet.
        """
        mask = self._img
        if mask is None:
            raise RuntimeError(
                "Mask tensor must be passed first for computing crop bounds."
            )

        height, width = img_size
        # Resolve non-positive roi entries to the image size and clamp the patch
        # to the image, so both branches below work with the same crop size.
        self._size = fall_back_tuple(self.roi_size, img_size)
        valid_size = get_valid_patch_size(img_size, self._size)
        crop_h, crop_w = valid_size
        max_top = max(0, height - crop_h)
        max_left = max(0, width - crop_w)

        bboxes = _compute_multilabel_bbox(mask)
        if bboxes is None:
            # no labels: random patch using MONAI utils
            logger.warning("No labels found; using random patch.")
            self._slices = get_random_patch(img_size, valid_size, self.R)
            return

        top_mins, top_maxs, left_mins, left_maxs = bboxes
        # A crop starting at `top` contains a channel's rows only when
        # r_max - crop_h + 1 <= top <= r_min; intersect that over all channels
        # and with the offsets that keep the crop inside the image.
        global_top_min = max(0, max(r_max - crop_h + 1 for r_max in top_mins))
        global_top_max = min(min(top_maxs), max_top)
        # same for left
        global_left_min = max(0, max(c_max - crop_w + 1 for c_max in left_mins))
        global_left_max = min(min(left_maxs), max_left)

        if global_top_min <= global_top_max and global_left_min <= global_left_max:
            # there is a window covering all channels fully
            top = self.R.randint(global_top_min, global_top_max + 1)
            left = self.R.randint(global_left_min, global_left_max + 1)
        else:
            top, left = self._sample_best_offsets(
                mask, crop_h, crop_w, max_top, max_left
            )

        # Store spatial slices only: ``Crop.__call__`` prepends the channel
        # slice itself, and ``get_random_patch`` in the no-label branch returns
        # the same spatial-only form. A leading ``slice(None)`` here would
        # shift the row/column slices by one axis.
        self._slices = (
            slice(top, top + crop_h),
            slice(left, left + crop_w),
        )

    def _sample_best_offsets(
        self,
        mask: np.ndarray,
        crop_h: int,
        crop_w: int,
        max_top: int,
        max_left: int,
    ) -> tuple[int, int]:
        """Draw ``num_candidates`` random offsets and keep the one whose window touches the most channels."""
        logger.warning(
            f"Cannot fit all instances; sampling {self.num_candidates} candidates."
        )
        num_channels = mask.shape[0]
        best_cover = -1
        best_top = best_left = 0
        for _ in range(self.num_candidates):
            # Draw order (top, then left) is kept so seeded runs stay reproducible.
            cand_top = self.R.randint(0, max_top + 1)
            cand_left = self.R.randint(0, max_left + 1)
            window = mask[:, cand_top : cand_top + crop_h, cand_left : cand_left + crop_w]
            cover = sum(int(window[ch].any()) for ch in range(num_channels))
            # Strict comparison: the first candidate reaching a coverage wins ties.
            if cover > best_cover:
                best_cover = cover
                best_top, best_left = cand_top, cand_left
        logger.info(f"Selected crop covering {best_cover}/{num_channels} channels.")
        return best_top, best_left

    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:  # type: ignore
        """
        On first call (mask), computes crop. On subsequent (image), just applies.

        Args:
            img: channel-first tensor. While no crop is cached it must be an
                integer mask (uint8, int16, int32 or int64).
            lazy: overrides the instance's ``lazy`` flag when not None.

        Returns:
            The result of ``Crop.__call__`` applied with the cached slices.

        Raises:
            RuntimeError: if no crop is cached and ``img`` is not an integer mask.
        """
        # Determine the spatial shape (pending shape for MetaTensor inputs).
        img_size = (
            img.peek_pending_shape()[1:]
            if isinstance(img, MetaTensor)
            else img.shape[1:]
        )
        # First call must be mask to compute slices
        if self._slices is None:
            if img.dtype not in (torch.uint8, torch.int16, torch.int32, torch.int64):
                raise RuntimeError(
                    "Mask tensor must be passed first for computing crop bounds."
                )
            # assume integer mask
            self._img = img.cpu().numpy()
            self.randomize(img_size)
        # Now apply stored slice
        if self._slices is None:
            raise RuntimeError("Crop slices not computed; call on mask first.")
        lazy_exec = self.lazy if lazy is None else lazy
        return super().__call__(img=img, slices=self._slices, lazy=lazy_exec)
|
|
|
|
|
|
class RandSpatialCropAllClassesd(RandCropd):
    """
    Dict-based wrapper: applies SpatialCropAllClasses to mask then image.

    Requires mask present or raises. The cropper computes its crop from the
    first tensor it receives, so ``keys`` must list the mask key before the
    image key(s).

    Args:
        keys: keys of the entries to crop, mask key first.
        roi_size: desired crop size (height, width).
        num_candidates: fallback samples when no single crop fits all instances.
        allow_missing_keys: do not raise if a key is missing from the data.
        lazy: defer actual cropping.
    """

    def __init__(
        self,
        keys: Sequence,
        roi_size: Sequence[int],
        num_candidates: int = 10,
        allow_missing_keys: bool = False,
        lazy: bool = False,
    ):
        cropper = SpatialCropAllClasses(
            roi_size=roi_size,
            num_candidates=num_candidates,
            lazy=lazy,
        )
        super().__init__(
            keys=keys,
            cropper=cropper,
            allow_missing_keys=allow_missing_keys,
            lazy=lazy,
        )

    def __call__(self, data, lazy: bool | None = None) -> dict:  # type: ignore
        """
        Crop every selected entry of ``data`` with one crop computed from the mask.

        Overrides the inherited ``__call__``. NOTE(review): per MONAI's
        ``RandCropd`` (1.x; confirm against the installed version) the inherited
        method calls ``cropper.randomize(img_size)`` before any mask is stored
        and passes ``randomize=False`` to the cropper, neither of which
        ``SpatialCropAllClasses`` supports.

        Args:
            data: mapping from keys to channel-first tensors.
            lazy: overrides the instance's ``lazy`` flag when not None.

        Returns:
            A shallow copy of ``data`` whose selected entries are cropped.

        Raises:
            RuntimeError: if the first selected entry is not an integer mask.
        """
        d = dict(data)
        lazy_ = self.lazy if lazy is None else lazy
        # Drop the crop cached from the previous sample so that every call
        # derives fresh slices from this sample's mask.
        self.cropper._slices = None
        for key in self.key_iterator(d):
            d[key] = self.cropper(d[key], lazy=lazy_)
        return d
|