|
|
|
|
@ -973,7 +973,7 @@ class CellSegmentator:
|
|
|
|
|
# Compute loss for this batch
|
|
|
|
|
batch_loss = self._criterion(
|
|
|
|
|
raw_output,
|
|
|
|
|
torch.from_numpy(flow_targets).to(device=raw_output.device)
|
|
|
|
|
{k: torch.from_numpy(v) for k, v in flow_targets.items()}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Post-process and compute F1 during validation and testing
|
|
|
|
|
@ -1027,7 +1027,7 @@ class CellSegmentator:
|
|
|
|
|
else:
|
|
|
|
|
epoch_metrics = {}
|
|
|
|
|
|
|
|
|
|
# Include F1 and mAP for validation and testing
|
|
|
|
|
# Include F1 and AP for validation and testing
|
|
|
|
|
if mode in ("valid", "test"):
|
|
|
|
|
# Concatenating by batch: shape (num_batches*B, C)
|
|
|
|
|
tp_array = np.vstack(all_tp)
|
|
|
|
|
@ -1046,13 +1046,13 @@ class CellSegmentator:
|
|
|
|
|
epoch_metrics[f"{mode}_f1_score_macro"] = self.__compute_f1_metric(
|
|
|
|
|
tp_array, fp_array, fn_array, reduction="macro"
|
|
|
|
|
)
|
|
|
|
|
epoch_metrics[f"{mode}_mAP_micro"] = self.__compute_average_precision_metric(
|
|
|
|
|
epoch_metrics[f"{mode}_AP_micro"] = self.__compute_average_precision_metric(
|
|
|
|
|
tp_array, fp_array, fn_array, reduction="micro"
|
|
|
|
|
)
|
|
|
|
|
epoch_metrics[f"{mode}_mAP_macro"] = self.__compute_average_precision_metric(
|
|
|
|
|
epoch_metrics[f"{mode}_AP_macro"] = self.__compute_average_precision_metric(
|
|
|
|
|
tp_array, fp_array, fn_array, reduction="macro"
|
|
|
|
|
)
|
|
|
|
|
epoch_metrics[f"{mode}_mAP_pc"] = self.__compute_average_precision_metric(
|
|
|
|
|
epoch_metrics[f"{mode}_AP_pc"] = self.__compute_average_precision_metric(
|
|
|
|
|
tp_array, fp_array, fn_array, reduction="per_class"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -1063,7 +1063,7 @@ class CellSegmentator:
|
|
|
|
|
self,
|
|
|
|
|
inputs: torch.Tensor,
|
|
|
|
|
mode: Literal["train", "valid", "test", "predict"] = "train"
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
) -> dict[str, torch.Tensor]:
|
|
|
|
|
"""
|
|
|
|
|
Perform model inference for different stages.
|
|
|
|
|
|
|
|
|
|
@ -1072,7 +1072,7 @@ class CellSegmentator:
|
|
|
|
|
stage (Literal[...]): One of "train", "valid", "test", "predict".
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
torch.Tensor: Model outputs tensor.
|
|
|
|
|
dict[str, torch.Tensor]: Model outputs tensor.
|
|
|
|
|
"""
|
|
|
|
|
if mode != "train":
|
|
|
|
|
# Use sliding window inference for non-training phases
|
|
|
|
|
@ -1093,14 +1093,16 @@ class CellSegmentator:
|
|
|
|
|
|
|
|
|
|
def __post_process_predictions(
|
|
|
|
|
self,
|
|
|
|
|
raw_outputs: torch.Tensor,
|
|
|
|
|
raw_outputs: dict[str, torch.Tensor],
|
|
|
|
|
ground_truth: torch.Tensor | None = None
|
|
|
|
|
) -> tuple[np.ndarray, np.ndarray | None]:
|
|
|
|
|
"""
|
|
|
|
|
Post-process raw network outputs to extract instance segmentation masks.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
raw_outputs (torch.Tensor): Raw model outputs of shape (B, С, H, W).
|
|
|
|
|
raw_outputs (dict[str, torch.Tensor]): Raw model outputs:
|
|
|
|
|
- "flow": (B, 2*C, H, W) per-class flow;
|
|
|
|
|
- "logits": (B, C, H, W) per-class logits.
|
|
|
|
|
ground_truth (torch.Tensor | None): Ground truth masks of shape (B, С, H, W).
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
@ -1109,11 +1111,9 @@ class CellSegmentator:
|
|
|
|
|
- labels_np: Converted ground truth of shape (B, С, H, W) or None if
|
|
|
|
|
ground_truth was not provided.
|
|
|
|
|
"""
|
|
|
|
|
# Move outputs to CPU and convert to numpy
|
|
|
|
|
outputs_np = raw_outputs.cpu().numpy()
|
|
|
|
|
# Split channels: gradient flows then class logits
|
|
|
|
|
gradflow = outputs_np[:, :2 * self._model.num_classes]
|
|
|
|
|
logits = outputs_np[:, -self._model.num_classes :]
|
|
|
|
|
gradflow = raw_outputs["flow"].cpu().numpy()
|
|
|
|
|
logits = raw_outputs["logits"].cpu().numpy()
|
|
|
|
|
# Apply sigmoid to logits to get probabilities
|
|
|
|
|
probabilities = self.__sigmoid(logits)
|
|
|
|
|
|
|
|
|
|
@ -1271,12 +1271,11 @@ class CellSegmentator:
|
|
|
|
|
if reduction == "macro":
|
|
|
|
|
f1_per_class = np.zeros(num_classes, dtype=float)
|
|
|
|
|
for c in range(num_classes):
|
|
|
|
|
_, _, f1_c = compute_f1_score(
|
|
|
|
|
_, _, f1_per_class[c] = compute_f1_score(
|
|
|
|
|
tp_per_class[c],
|
|
|
|
|
fp_per_class[c],
|
|
|
|
|
fn_per_class[c]
|
|
|
|
|
)
|
|
|
|
|
f1_per_class[c] = f1_c
|
|
|
|
|
return float(f1_per_class.mean())
|
|
|
|
|
|
|
|
|
|
# 6) Weighted: class-wise F1 weighted by support = TP + FN
|
|
|
|
|
@ -1381,7 +1380,7 @@ class CellSegmentator:
|
|
|
|
|
fp_per_class = false_positives.sum(axis=0).astype(int)
|
|
|
|
|
fn_per_class = false_negatives.sum(axis=0).astype(int)
|
|
|
|
|
|
|
|
|
|
# 4) Per-class: compute F1 for each class and return vector
|
|
|
|
|
# 4) Per-class: compute AP for each class and return vector
|
|
|
|
|
if reduction == "per_class":
|
|
|
|
|
ap_per_class = np.zeros(num_classes, dtype=float)
|
|
|
|
|
for c in range(num_classes):
|
|
|
|
|
@ -1752,7 +1751,7 @@ class CellSegmentator:
|
|
|
|
|
def __compute_flows_from_masks(
|
|
|
|
|
self,
|
|
|
|
|
true_masks: Tensor
|
|
|
|
|
) -> np.ndarray:
|
|
|
|
|
) -> dict[str, np.ndarray]:
|
|
|
|
|
"""
|
|
|
|
|
Convert segmentation masks to flow fields for training.
|
|
|
|
|
|
|
|
|
|
@ -1760,8 +1759,9 @@ class CellSegmentator:
|
|
|
|
|
true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
numpy array of concatenated [flow_vectors, renumbered_true_masks] per image.
|
|
|
|
|
renumbered_true_masks is labels, flow_vectors[0] is Y flow, flow_vectors[1] is X flow.
|
|
|
|
|
Dict with:
|
|
|
|
|
- "flow": (B, 2*C, H, W) per-class flow (flow_vectors[0] is Y flow, flow_vectors[1] is X flow)
|
|
|
|
|
- "labels": (B, C, H, W) per-class labels.
|
|
|
|
|
"""
|
|
|
|
|
# Move to CPU numpy
|
|
|
|
|
_true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16)
|
|
|
|
|
@ -1783,7 +1783,7 @@ class CellSegmentator:
|
|
|
|
|
flow_vectors = np.stack([self.__compute_flow_from_mask(_true_masks[i])
|
|
|
|
|
for i in range(batch_size)])
|
|
|
|
|
|
|
|
|
|
return np.concatenate((flow_vectors, _true_masks), axis=1).astype(np.float32)
|
|
|
|
|
return {"flow": flow_vectors.astype(np.float32), "labels": _true_masks.astype(np.float32)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __compute_flow_from_mask(
|
|
|
|
|
@ -2245,9 +2245,10 @@ class CellSegmentator:
|
|
|
|
|
|
|
|
|
|
# Follow the flow vectors to generate coordinate mappings
|
|
|
|
|
flow_coordinates = self.__follow_flows(
|
|
|
|
|
flow_field=channel_flow_vectors * channel_mask / 5.0,
|
|
|
|
|
flow_field=channel_flow_vectors * channel_mask,
|
|
|
|
|
initial_coords=nonzero_coords,
|
|
|
|
|
num_iters=num_iters
|
|
|
|
|
num_iters=num_iters,
|
|
|
|
|
step_scale=0.2
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not torch.is_tensor(flow_coordinates):
|
|
|
|
|
@ -2292,7 +2293,8 @@ class CellSegmentator:
|
|
|
|
|
self,
|
|
|
|
|
flow_field: np.ndarray,
|
|
|
|
|
initial_coords: np.ndarray,
|
|
|
|
|
num_iters: int = 200
|
|
|
|
|
num_iters: int = 200,
|
|
|
|
|
step_scale: float = 1.0
|
|
|
|
|
) -> np.ndarray | torch.Tensor:
|
|
|
|
|
"""
|
|
|
|
|
Trace pixel positions through a flow field via iterative interpolation.
|
|
|
|
|
@ -2301,6 +2303,7 @@ class CellSegmentator:
|
|
|
|
|
flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors.
|
|
|
|
|
initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions.
|
|
|
|
|
num_iters (int): Number of integration steps.
|
|
|
|
|
step_scale (float): Step length in "pixels per iteration".
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
(np.ndarray | torch.Tensor): Final (y, x) positions of each point.
|
|
|
|
|
@ -2311,50 +2314,40 @@ class CellSegmentator:
|
|
|
|
|
|
|
|
|
|
# Choose GPU/MPS path if available
|
|
|
|
|
if self._device.type in ("cuda", "mps"):
|
|
|
|
|
# Load initial positions and flow into tensors (flip order for grid_sample)
|
|
|
|
|
# Prepare point tensor: shape [1, 1, num_points, 2]
|
|
|
|
|
pts = torch.zeros((1, 1, initial_coords.shape[1], dims),
|
|
|
|
|
dtype=torch.float32, device=self._device)
|
|
|
|
|
pts = torch.empty((1, 1, initial_coords.shape[1], dims), dtype=torch.float32, device=self._device)
|
|
|
|
|
pts[..., 0] = torch.from_numpy(initial_coords[1]).to(self._device, torch.float32) # x
|
|
|
|
|
pts[..., 1] = torch.from_numpy(initial_coords[0]).to(self._device, torch.float32) # y
|
|
|
|
|
|
|
|
|
|
# Prepare flow volume: shape [1, 2, height, width]
|
|
|
|
|
flow_vol = torch.zeros((1, dims, height, width),
|
|
|
|
|
dtype=torch.float32, device=self._device)
|
|
|
|
|
|
|
|
|
|
# Load initial positions and flow into tensors (flip order for grid_sample)
|
|
|
|
|
# dim 0 = x
|
|
|
|
|
# dim 1 = y
|
|
|
|
|
for i in range(dims):
|
|
|
|
|
pts[0, 0, :, dims - i - 1] = (
|
|
|
|
|
torch.from_numpy(initial_coords[i])
|
|
|
|
|
.to(self._device, torch.float32)
|
|
|
|
|
)
|
|
|
|
|
flow_vol[0, dims - i - 1] = (
|
|
|
|
|
torch.from_numpy(flow_field[i])
|
|
|
|
|
.to(self._device, torch.float32)
|
|
|
|
|
)
|
|
|
|
|
flow_vol = torch.empty((1, dims, height, width), dtype=torch.float32, device=self._device)
|
|
|
|
|
flow_vol[:, 0] = torch.from_numpy(flow_field[1]).to(self._device, torch.float32) # x-component
|
|
|
|
|
flow_vol[:, 1] = torch.from_numpy(flow_field[0]).to(self._device, torch.float32) # y-component
|
|
|
|
|
|
|
|
|
|
# Prepare normalization factors for x and y (max index)
|
|
|
|
|
max_indices = torch.tensor([width - 1, height - 1],
|
|
|
|
|
dtype=torch.float32, device=self._device)
|
|
|
|
|
max_indices = torch.tensor([width - 1, height - 1], dtype=torch.float32, device=self._device) # (2,)
|
|
|
|
|
# Reshape for broadcasting to point tensor dims
|
|
|
|
|
max_idx_pt = max_indices.view(1, 1, 1, dims)
|
|
|
|
|
# Reshape for broadcasting to flow volume dims
|
|
|
|
|
max_idx_flow = max_indices.view(1, dims, 1, 1)
|
|
|
|
|
|
|
|
|
|
# Normalize flow values to [-1, 1] range
|
|
|
|
|
flow_vol = (flow_vol * 2) / max_idx_flow
|
|
|
|
|
flow_vol = (flow_vol * step_scale * 2.0) / max_idx_flow
|
|
|
|
|
# Normalize points to [-1, 1]
|
|
|
|
|
pts = (pts / max_idx_pt) * 2 - 1
|
|
|
|
|
pts = (pts / max_idx_pt) * 2.0 - 1.0
|
|
|
|
|
|
|
|
|
|
# Iterate: sample flow and update points
|
|
|
|
|
for _ in range(num_iters):
|
|
|
|
|
sampled = torch.nn.functional.grid_sample(
|
|
|
|
|
flow_vol, pts, align_corners=False
|
|
|
|
|
flow_vol, pts, align_corners=True
|
|
|
|
|
)
|
|
|
|
|
# Update each coordinate and clamp to valid range
|
|
|
|
|
for i in range(dims):
|
|
|
|
|
pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0)
|
|
|
|
|
|
|
|
|
|
# Denormalize back to original pixel coordinates
|
|
|
|
|
pts = (pts + 1) * 0.5 * max_idx_pt
|
|
|
|
|
pts = (pts + 1.0) * 0.5 * max_idx_pt
|
|
|
|
|
# Swap channels back to (y, x) and flatten
|
|
|
|
|
final_pts = pts[..., [1, 0]].squeeze()
|
|
|
|
|
# Convert from (num_points, 2) to (2, num_points)
|
|
|
|
|
@ -2368,8 +2361,8 @@ class CellSegmentator:
|
|
|
|
|
# Interpolate flow at current positions
|
|
|
|
|
self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta)
|
|
|
|
|
# Update positions and clamp to image bounds
|
|
|
|
|
current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1)
|
|
|
|
|
current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1)
|
|
|
|
|
current_pos[0] = np.clip(current_pos[0] + step_scale * temp_delta[0], 0, height - 1)
|
|
|
|
|
current_pos[1] = np.clip(current_pos[1] + step_scale * temp_delta[1], 0, width - 1)
|
|
|
|
|
|
|
|
|
|
return current_pos
|
|
|
|
|
|
|
|
|
|
|