From cec2fcaf3fbee6dc36ba586427a1de62af4f94d1 Mon Sep 17 00:00:00 2001 From: laynholt Date: Thu, 16 Oct 2025 13:59:20 +0000 Subject: [PATCH] fix: renumbering in __compute_flow_from_masks (it was by batch, it became by batch and channels) --- core/segmentator.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/segmentator.py b/core/segmentator.py index d4533c6..747646d 100644 --- a/core/segmentator.py +++ b/core/segmentator.py @@ -1772,16 +1772,18 @@ class CellSegmentator: # shape (batch, H, W) -> (batch, 1, H, W) _true_masks = _true_masks[:, np.newaxis, :, :] - batch_size, *_ = _true_masks.shape + batch_size, num_channels, *_ = _true_masks.shape # Renumber labels to ensure uniqueness - renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0] - for i in range(batch_size)]) + for b in range(batch_size): + for c in range(num_channels): + fastremap.renumber(_true_masks[b, c], in_place=True) + # Compute vector flows per image - flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i]) + flow_vectors = np.stack([self.__compute_flow_from_mask(_true_masks[i]) for i in range(batch_size)]) - return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32) + return np.concatenate((flow_vectors, _true_masks), axis=1).astype(np.float32) def __compute_flow_from_mask(