You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							120 lines
						
					
					
						
							4.2 KiB
						
					
					
				
			
		
		
	
	
							120 lines
						
					
					
						
							4.2 KiB
						
					
					
				import torch
 | 
						|
from torch import nn
 | 
						|
 | 
						|
from typing import Any
 | 
						|
from segmentation_models_pytorch import MAnet
 | 
						|
from segmentation_models_pytorch.base.modules import Activation
 | 
						|
 | 
						|
from pydantic import BaseModel, ConfigDict
 | 
						|
 | 
						|
__all__ = ["ModelV"]
 | 
						|
 | 
						|
 | 
						|
class ModelVParams(BaseModel):
    """Frozen (immutable) hyper-parameter set for :class:`ModelV`.

    The fields mirror the keyword arguments forwarded to the ``MAnet``
    constructor via :meth:`asdict`.
    """

    model_config = ConfigDict(frozen=True)

    encoder_name: str = "mit_b5"                               # Default encoder backbone
    encoder_weights: str | None = "imagenet"                   # Pre-trained weights (None = random init)
    decoder_channels: list[int] = [1024, 512, 256, 128, 64]    # Per-stage decoder widths
    decoder_pab_channels: int = 256                            # Decoder Pyramid Attention Block channels
    in_channels: int = 3                                       # Number of input channels
    out_classes: int = 1                                       # Number of output classes
    # NOTE(review): MAnet's constructor argument is conventionally named
    # `classes`; confirm it accepts `out_classes` (e.g. via **kwargs) in the
    # pinned segmentation_models_pytorch version.

    def asdict(self) -> dict[str, Any]:
        """Return the parameters to forward to the ``MAnet`` constructor.

        Returns:
            dict[str, Any]: All model fields, with ``None``-valued entries
            removed so the superclass defaults apply for them instead.
        """
        params = self.model_dump()
        return {k: v for k, v in params.items() if v is not None}  # Drop None values
 | 
						|
class ModelV(MAnet):
    """MAnet-based model emitting gradient-flow and cell-probability masks."""

    def __init__(self, params: ModelVParams) -> None:
        """Build the network from a frozen :class:`ModelVParams` set."""
        # Construct the underlying MAnet encoder/decoder.
        super().__init__(**params.asdict())

        self.num_classes = params.out_classes

        # The stock segmentation head is unused; the custom heads below
        # replace it in this architecture.
        self.segmentation_head = None

        # Swap every ReLU in the encoder and decoder for Mish.
        _convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True))
        _convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True))

        # Task-specific heads fed by the last decoder stage: cell probability
        # (out_classes channels) and gradient flow (2 channels per class).
        head_in_channels = params.decoder_channels[-1]
        self.cellprob_head = DeepSegmentationHead(
            in_channels=head_in_channels, out_channels=params.out_classes
        )
        self.gradflow_head = DeepSegmentationHead(
            in_channels=head_in_channels, out_channels=2 * params.out_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: gradient-flow then cell-prob masks, concatenated on dim 1."""
        # Validate the spatial dimensions before encoding.
        self.check_input_shape(x)

        # Encoder features feed straight into the decoder.
        decoded = self.decoder(self.encoder(x))

        # Channel order: gradient flows first, then cell probability.
        return torch.cat(
            (self.gradflow_head(decoded), self.cellprob_head(decoded)), dim=1
        )
 | 
						|
class DeepSegmentationHead(nn.Sequential):
 | 
						|
    """Custom segmentation head for generating specific masks"""
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        in_channels: int,
 | 
						|
        out_channels: int,
 | 
						|
        kernel_size: int = 3,
 | 
						|
        activation: str | None = None,
 | 
						|
        upsampling: int = 1,
 | 
						|
    ) -> None:
 | 
						|
        # Define a sequence of layers for the segmentation head
 | 
						|
        layers: list[nn.Module] = [
 | 
						|
            nn.Conv2d(
 | 
						|
                in_channels,
 | 
						|
                in_channels // 2,
 | 
						|
                kernel_size=kernel_size,
 | 
						|
                padding=kernel_size // 2,
 | 
						|
            ),
 | 
						|
            nn.Mish(inplace=True),
 | 
						|
            nn.BatchNorm2d(in_channels // 2),
 | 
						|
            nn.Conv2d(
 | 
						|
                in_channels // 2,
 | 
						|
                out_channels,
 | 
						|
                kernel_size=kernel_size,
 | 
						|
                padding=kernel_size // 2,
 | 
						|
            ),
 | 
						|
            nn.UpsamplingBilinear2d(scale_factor=upsampling)
 | 
						|
            if upsampling > 1
 | 
						|
            else nn.Identity(),
 | 
						|
            Activation(activation) if activation else nn.Identity(),
 | 
						|
        ]
 | 
						|
        super().__init__(*layers)
 | 
						|
 | 
						|
 | 
						|
def _convert_activations(module: nn.Module, from_activation: type, to_activation: nn.Module) -> None:
 | 
						|
    """Recursively convert activation functions in a module"""
 | 
						|
    for name, child in module.named_children():
 | 
						|
        if isinstance(child, from_activation):
 | 
						|
            setattr(module, name, to_activation)
 | 
						|
        else:
 | 
						|
            _convert_activations(child, from_activation, to_activation)
 |