fix: renumbering in __compute_flow_from_masks (it was by batch, it became by batch and channels)

master
laynholt 2 weeks ago
parent 3177b33a92
commit cec2fcaf3f

@ -1772,16 +1772,18 @@ class CellSegmentator:
# shape (batch, H, W) -> (batch, 1, H, W)
_true_masks = _true_masks[:, np.newaxis, :, :]
batch_size, *_ = _true_masks.shape
batch_size, num_channels, *_ = _true_masks.shape
# Renumber labels to ensure uniqueness
renumbered: np.ndarray = np.stack([fastremap.renumber(_true_masks[i], in_place=True)[0]
for i in range(batch_size)])
for b in range(batch_size):
for c in range(num_channels):
fastremap.renumber(_true_masks[b, c], in_place=True)
# Compute vector flows per image
flow_vectors = np.stack([self.__compute_flow_from_mask(renumbered[i])
flow_vectors = np.stack([self.__compute_flow_from_mask(_true_masks[i])
for i in range(batch_size)])
return np.concatenate((flow_vectors, renumbered), axis=1).astype(np.float32)
return np.concatenate((flow_vectors, _true_masks), axis=1).astype(np.float32)
def __compute_flow_from_mask(

Loading…
Cancel
Save