diff --git a/core/segmentator.py b/core/segmentator.py index d4533c6..747646d 100644 --- a/core/segmentator.py +++ b/core/segmentator.py @@ -1772,16 +1772,18 @@ class CellSegmentator: # shape (batch, H, W) -> (batch, 1, H, W) _true_masks = _true_masks[:, np.newaxis, :, :] - batch_size, *_ = _true_masks.shape + batch_size, num_channels, *_ = _true_masks.shape # Renumber labels to ensure uniqueness - renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0] - for i in range(batch_size)]) + for b in range(batch_size): + for c in range(num_channels): + fastremap.renumber(_true_masks[b, c], in_place=True) + # Compute vector flows per image - flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i]) + flow_vectors = np.stack([self.__compute_flow_from_mask(_true_masks[i]) for i in range(batch_size)]) - return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32) + return np.concatenate((flow_vectors, _true_masks), axis=1).astype(np.float32) def __compute_flow_from_mask(