You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							194 lines
						
					
					
						
							7.2 KiB
						
					
					
				
			
		
		
	
	
							194 lines
						
					
					
						
							7.2 KiB
						
					
					
				import torch
 | 
						|
import numpy as np
 | 
						|
from typing import Sequence
 | 
						|
 | 
						|
from monai.utils.misc import fall_back_tuple
 | 
						|
from monai.data.meta_tensor import MetaTensor
 | 
						|
from monai.data.utils import get_random_patch, get_valid_patch_size
 | 
						|
from monai.transforms import Randomizable, RandCropd, Crop  # type: ignore
 | 
						|
 | 
						|
from core.logger import get_logger
 | 
						|
 | 
						|
logger = get_logger(__name__)
 | 
						|
 | 
						|
 | 
						|
def _compute_multilabel_bbox(
 | 
						|
    mask: np.ndarray
 | 
						|
) -> tuple[list[int], list[int], list[int], list[int]] | None:
 | 
						|
    """
 | 
						|
    Compute per-channel bounding-box constraints and return lists of limits for each axis.
 | 
						|
 | 
						|
    Args:
 | 
						|
        mask: multi-channel instance mask of shape (C, H, W).
 | 
						|
 | 
						|
    Returns:
 | 
						|
        A tuple of four lists:
 | 
						|
            - top_mins: list of r_max for each non-empty channel
 | 
						|
            - top_maxs: list of r_min for each non-empty channel
 | 
						|
            - left_mins: list of c_max for each non-empty channel
 | 
						|
            - left_maxs: list of c_min for each non-empty channel
 | 
						|
        Or None if mask contains no positive labels.
 | 
						|
    """
 | 
						|
    channels, rows, cols = np.nonzero(mask)
 | 
						|
    if channels.size == 0:
 | 
						|
        return None
 | 
						|
 | 
						|
    top_mins: list[int] = []
 | 
						|
    top_maxs: list[int] = []
 | 
						|
    left_mins: list[int] = []
 | 
						|
    left_maxs: list[int] = []
 | 
						|
    C = mask.shape[0]
 | 
						|
    for ch in range(C):
 | 
						|
        rs, cs = np.nonzero(mask[ch])
 | 
						|
        if rs.size == 0:
 | 
						|
            continue
 | 
						|
        r_min, r_max = int(rs.min()), int(rs.max())
 | 
						|
        c_min, c_max = int(cs.min()), int(cs.max())
 | 
						|
        # For each channel, record the row/col extents
 | 
						|
        top_mins.append(r_max)
 | 
						|
        top_maxs.append(r_min)
 | 
						|
        left_mins.append(c_max)
 | 
						|
        left_maxs.append(c_min)
 | 
						|
 | 
						|
    return top_mins, top_maxs, left_mins, left_maxs
 | 
						|
 | 
						|
 | 
						|
class SpatialCropAllClasses(Randomizable, Crop):
    """
    Cropper for multi-label instance masks and images: ensures each label-channel's
    instances lie within the crop if possible.

    Must be called on a mask tensor first to compute the crop, then on the image:
    the first call (integer-dtype mask) computes and stores the crop slices, and
    every later call re-applies those same slices.

    Args:
        roi_size: desired crop size (height, width).
        num_candidates: fallback samples when no single crop fits all instances.
        lazy: defer actual cropping.
    """
    def __init__(
        self,
        roi_size: Sequence[int],
        num_candidates: int = 10,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy=lazy)
        # Crop size as an immutable (height, width) tuple.
        self.roi_size = tuple(roi_size)
        self.num_candidates = num_candidates
        # Crop slices computed from the mask on the first __call__ and reused
        # for subsequent tensors (e.g. the image).
        # NOTE(review): _slices is never reset to None afterwards, so a second
        # sample pushed through the same instance reuses the first sample's
        # crop instead of recomputing — confirm a fresh instance is created
        # per sample (or that a reset happens elsewhere).
        self._slices: tuple[slice, ...] | None = None

    def randomize(self, img_size: Sequence[int]) -> None: # type: ignore
        """
        Choose crop offsets so that each non-empty channel is included if possible.

        Reads the mask from ``self._img`` (a numpy array set by ``__call__``) and
        stores the chosen crop in ``self._slices``.

        Args:
            img_size: spatial size (height, width) of the tensor being cropped.
        """
        # NOTE(review): relies on self._img having been assigned by __call__
        # beforehand; calling randomize() directly (e.g. via a dict-transform
        # wrapper's randomize) before any __call__ would raise AttributeError —
        # verify the caller's ordering.
        height, width = img_size
        crop_h, crop_w = self.roi_size
        # Largest valid top/left offsets that keep the crop inside the image.
        max_top = max(0, height - crop_h)
        max_left = max(0, width - crop_w)

        # Compute per-channel bbox constraints
        mask = self._img
        bboxes = _compute_multilabel_bbox(mask)
        if bboxes is None:
            # no labels: random patch using MONAI utils
            logger.warning("No labels found; using random patch.")
            # determine actual patch size (fallback)
            self._size = fall_back_tuple(self.roi_size, img_size)
            # compute valid size for random patch
            valid_size = get_valid_patch_size(img_size, self._size)
            # directly get random patch slices
            # NOTE(review): this branch stores spatial-only slices, whereas the
            # labeled branch below prepends a channel slice(None) — confirm both
            # layouts are accepted by Crop.__call__ in the MONAI version in use.
            self._slices = get_random_patch(img_size, valid_size, self.R)
            return
        else:
            top_mins, top_maxs, left_mins, left_maxs = bboxes
            # Convert to allowable windows.
            # A crop starting at `top` contains rows [top, top + crop_h); to
            # include a channel whose rows span [r_min, r_max] we need
            # r_max - crop_h + 1 <= top <= r_min. Intersect over all channels:
            # top_min_global = max(r_max - crop_h +1 for each channel)
            global_top_min = max(0, max(r_max - crop_h + 1 for r_max in top_mins))
            # top_max_global = min(r_min for each channel)
            global_top_max = min(min(top_maxs), max_top)
            # same for left
            global_left_min = max(0, max(c_max - crop_w + 1 for c_max in left_mins))
            global_left_max = min(min(left_maxs), max_left)

            if global_top_min <= global_top_max and global_left_min <= global_left_max:
                # there is a window covering all channels fully
                top = self.R.randint(global_top_min, global_top_max + 1)
                left = self.R.randint(global_left_min, global_left_max + 1)
            else:
                # fallback: sample candidates to maximize channel coverage
                logger.warning(
                    f"Cannot fit all instances; sampling {self.num_candidates} candidates."
                )
                best_cover = -1
                best_top = best_left = 0
                C = mask.shape[0]
                for _ in range(self.num_candidates):
                    # Uniformly sample a valid offset and count how many
                    # channels have at least one labeled pixel inside it.
                    cand_top = self.R.randint(0, max_top + 1)
                    cand_left = self.R.randint(0, max_left + 1)
                    window = mask[:, cand_top : cand_top + crop_h, cand_left : cand_left + crop_w]
                    cover = sum(int(window[ch].any()) for ch in range(C))
                    if cover > best_cover:
                        best_cover = cover
                        best_top, best_left = cand_top, cand_left
                logger.info(f"Selected crop covering {best_cover}/{C} channels.")
                top, left = best_top, best_left

        # store slices for use on both mask and image
        self._slices = (
            slice(None),
            slice(top, top + crop_h),
            slice(left, left + crop_w),
        )

    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore
        """
        On first call (mask), computes crop. On subsequent (image), just applies.
        Raises if mask not provided first.

        Args:
            img: channel-first tensor; the first call must be an integer-dtype mask.
            lazy: overrides the instance-level ``lazy`` flag when not None.

        Raises:
            RuntimeError: if the first tensor seen is not an integer-dtype mask.
        """
        # Determine tensor shape (spatial dims only; channel dim dropped).
        # peek_pending_shape accounts for pending lazy transforms on MetaTensor.
        img_size = (
            img.peek_pending_shape()[1:]
            if isinstance(img, MetaTensor)
            else img.shape[1:]
        )
        # First call must be mask to compute slices
        if self._slices is None:
            # NOTE(review): the is_floating_point check is redundant given the
            # explicit integer-dtype membership test that follows.
            if not torch.is_floating_point(img) and img.dtype in (torch.uint8, torch.int16, torch.int32, torch.int64):
                # assume integer mask
                self._img = img.cpu().numpy()
                self.randomize(img_size)
            else:
                raise RuntimeError(
                    "Mask tensor must be passed first for computing crop bounds."
                )
        # Now apply stored slice
        # (defensive re-check: randomize() always sets _slices on success,
        # so this branch should be unreachable)
        if self._slices is None:
            raise RuntimeError("Crop slices not computed; call on mask first.")
        lazy_exec = self.lazy if lazy is None else lazy
        return super().__call__(img=img, slices=self._slices, lazy=lazy_exec)
 | 
						|
 | 
						|
 | 
						|
class RandSpatialCropAllClassesd(RandCropd):
    """
    Dictionary-based random crop that keeps every label channel inside the
    window when possible.

    Delegates the actual crop computation to :class:`SpatialCropAllClasses`,
    which must see the mask entry before the image entry; a missing mask
    results in an error from the cropper.

    Args:
        keys: dict keys to crop (mask key must be processed first).
        roi_size: desired crop size (height, width).
        num_candidates: fallback samples when no single crop fits all instances.
        allow_missing_keys: skip keys absent from the input dict.
        lazy: defer actual cropping.
    """
    def __init__(
        self,
        keys: Sequence,
        roi_size: Sequence[int],
        num_candidates: int = 10,
        allow_missing_keys: bool = False,
        lazy: bool = False,
    ):
        # Build the shared cropper inline; RandCropd applies it to each key in order.
        super().__init__(
            keys=keys,
            cropper=SpatialCropAllClasses(
                roi_size=roi_size,
                num_candidates=num_candidates,
                lazy=lazy,
            ),
            allow_missing_keys=allow_missing_keys,
            lazy=lazy,
        )
 |