added methods for calculating gradient flows

master
laynholt 1 week ago
parent 2413420620
commit 60aebe5921

@ -54,6 +54,13 @@ class ModelV(MAnet):
in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes
)
# self.gradflow_head = nn.ModuleList([
# DeepSegmentationHead(
# in_channels=params.decoder_channels[-1], out_channels=2
# ) for _ in range(params.out_classes)
# ])
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the network"""
# Ensure the input shape is correct
@ -67,8 +74,12 @@ class ModelV(MAnet):
cellprob_mask = self.cellprob_head(decoder_output)
gradflow_mask = self.gradflow_head(decoder_output)
# gradflow_masks = torch.cat(
# [head(decoder_output) for head in self.flow_heads], dim=1 # [B, 2*C, H, W]
# )
# Concatenate the masks for output
masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)
masks = torch.cat((gradflow_mask, cellprob_mask), dim=1)
return masks

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save