from typing import List, Optional

import torch
import torch.nn as nn
from pydantic import BaseModel, ConfigDict
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation

__all__ = ["ModelV"]


class ModelVParams(BaseModel):
    model_config = ConfigDict(frozen=True)

    encoder_name: str = "mit_b5"  # Default encoder
    encoder_weights: Optional[str] = "imagenet"  # Pre-trained weights
    decoder_channels: List[int] = [1024, 512, 256, 128, 64]  # Decoder configuration
    decoder_pab_channels: int = 256  # Decoder Pyramid Attention Block channels
    in_channels: int = 3  # Number of input channels
    out_classes: int = 3  # Number of output classes

    def asdict(self):
        """
        Returns a dictionary of valid keyword arguments for `MAnet.__init__`.

        `out_classes` is renamed to `classes`, the name `MAnet` expects, and
        None values are dropped.

        Returns:
            Dict[str, Any]: Keyword arguments for the `MAnet` base class.
        """
        kwargs = self.model_dump()
        kwargs["classes"] = kwargs.pop("out_classes")  # MAnet expects `classes`
        return {k: v for k, v in kwargs.items() if v is not None}  # Remove None values
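

# Illustrative usage (a minimal sketch): ModelVParams is frozen, so
# configuration variants are created by constructing new instances rather
# than mutating fields, e.g.
#   params = ModelVParams(encoder_name="mit_b5", out_classes=3)
#   params.asdict()  # keyword arguments accepted by MAnet.__init__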


class ModelV(MAnet):
    """MAnet-based segmentation model with Mish activations and two custom
    heads: one for cell probability and one for gradient flow."""

    def __init__(self, params: ModelVParams) -> None:
        # Initialize the MAnet base model with the provided parameters
        super().__init__(**params.asdict())

        # Remove the default segmentation head; this architecture uses two
        # custom heads instead
        self.segmentation_head = None

        # Replace ReLU activations in the encoder and decoder with Mish
        _convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True))
        _convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True))

        # Custom segmentation heads for the two prediction tasks
        self.cellprob_head = DeepSegmentationHead(
            in_channels=params.decoder_channels[-1], out_channels=params.out_classes
        )
        self.gradflow_head = DeepSegmentationHead(
            in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network"""
        # Ensure the input shape is compatible with the encoder
        self.check_input_shape(x)

        # Encode the input, then decode the multi-scale features
        features = self.encoder(x)
        decoder_output = self.decoder(*features)

        # Generate masks for cell probability and gradient flows
        cellprob_mask = self.cellprob_head(decoder_output)
        gradflow_mask = self.gradflow_head(decoder_output)

        # Concatenate the masks: gradient flows first, then cell probability
        masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)
        return masks


class DeepSegmentationHead(nn.Sequential):
    """Custom segmentation head for generating specific masks"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        activation: Optional[str] = None,
        upsampling: int = 1,
    ) -> None:
        # Define a sequence of layers for the segmentation head
        layers: List[nn.Module] = [
            nn.Conv2d(
                in_channels,
                in_channels // 2,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            nn.Mish(inplace=True),
            nn.BatchNorm2d(in_channels // 2),
            nn.Conv2d(
                in_channels // 2,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            nn.UpsamplingBilinear2d(scale_factor=upsampling)
            if upsampling > 1
            else nn.Identity(),
            Activation(activation) if activation else nn.Identity(),
        ]
        super().__init__(*layers)


def _convert_activations(
    module: nn.Module, from_activation: type, to_activation: nn.Module
) -> None:
    """Recursively replace activations of type `from_activation` with
    `to_activation` throughout a module tree.

    Note: the same `to_activation` instance is shared across all replacement
    sites, which is safe for parameter-free activations such as `nn.Mish`.
    """
    for name, child in module.named_children():
        if isinstance(child, from_activation):
            setattr(module, name, to_activation)
        else:
            _convert_activations(child, from_activation, to_activation)
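

if __name__ == "__main__":
    # Minimal smoke test (a sketch): assumes segmentation_models_pytorch can
    # provide the "mit_b5" encoder with ImageNet weights (downloaded on first
    # use) and that the spatial size is divisible by the encoder's output
    # stride, e.g. 224x224.
    params = ModelVParams()
    model = ModelV(params).eval()

    dummy = torch.randn(1, params.in_channels, 224, 224)
    with torch.no_grad():
        out = model(dummy)

    # The output stacks 2 * out_classes gradient-flow channels followed by
    # out_classes cell-probability channels, i.e. 3 * out_classes in total.
    print(out.shape)  # expected: torch.Size([1, 9, 224, 224]) with defaults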