|
|
|
@ -53,6 +53,13 @@ class ModelV(MAnet):
|
|
|
|
|
self.gradflow_head = DeepSegmentationHead(
|
|
|
|
|
in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# self.gradflow_head = nn.ModuleList([
|
|
|
|
|
# DeepSegmentationHead(
|
|
|
|
|
# in_channels=params.decoder_channels[-1], out_channels=2
|
|
|
|
|
# ) for _ in range(params.out_classes)
|
|
|
|
|
# ])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
"""Forward pass through the network"""
|
|
|
|
@ -66,9 +73,13 @@ class ModelV(MAnet):
|
|
|
|
|
# Generate masks for cell probability and gradient flows
|
|
|
|
|
cellprob_mask = self.cellprob_head(decoder_output)
|
|
|
|
|
gradflow_mask = self.gradflow_head(decoder_output)
|
|
|
|
|
|
|
|
|
|
# gradflow_masks = torch.cat(
|
|
|
|
|
# [head(decoder_output) for head in self.flow_heads], dim=1 # [B, 2*C, H, W]
|
|
|
|
|
# )
|
|
|
|
|
|
|
|
|
|
# Concatenate the masks for output
|
|
|
|
|
masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)
|
|
|
|
|
masks = torch.cat((gradflow_mask, cellprob_mask), dim=1)
|
|
|
|
|
|
|
|
|
|
return masks
|
|
|
|
|
|
|
|
|
|