import torch
from torch import nn
from typing import Any, Callable

from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation
from pydantic import BaseModel, ConfigDict

__all__ = ["MediarV", "MediarVParams"]


class MediarVParams(BaseModel):
    """Frozen configuration for MediarV."""

    model_config = ConfigDict(frozen=True)

    encoder_name: str = "mit_b5"  # Encoder backbone
    encoder_weights: str | None = "imagenet"  # Pre-trained encoder weights
    decoder_channels: list[int] = [1024, 512, 256, 128, 64]  # Decoder channel widths
    decoder_pab_channels: int = 256  # Decoder Pyramid Attention Block channels
    in_channels: int = 3  # Number of input channels
    out_classes: int = 1  # Number of output classes
    prefer_gn: bool = True  # Prefer GroupNorm in the heads; use BatchNorm2d when False
    upsample: int = 1  # Upsampling factor for the heads (keep 1 if the decoder already outputs the final spatial size)
    zero_init_heads: bool = False  # Zero-initialize the final conv in each head (stabilizes early training)
    def asdict(self) -> dict[str, Any]:
        """
        Return a dictionary of valid keyword arguments for the `MAnet` base class.

        Head-only options (`prefer_gn`, `upsample`, `zero_init_heads`) are
        excluded, `out_classes` is mapped to the `classes` argument expected
        by `MAnet`, and `None` values are dropped.

        Returns:
            dict[str, Any]: Keyword arguments for `MAnet.__init__`.
        """
        kwargs = self.model_dump(exclude={"out_classes", "prefer_gn", "upsample", "zero_init_heads"})
        kwargs["classes"] = self.out_classes
        return {k: v for k, v in kwargs.items() if v is not None}  # Remove None values
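

# Illustrative sketch (not executed): building a frozen config and the kwargs
# it forwards to MAnet. Values shown follow from the defaults above.
#
#   params = MediarVParams(out_classes=3)
#   params.asdict()
#   # -> {"encoder_name": "mit_b5", "encoder_weights": "imagenet",
#   #     "decoder_channels": [1024, 512, 256, 128, 64],
#   #     "decoder_pab_channels": 256, "in_channels": 3, "classes": 3}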


class MediarV(MAnet):
    """MAnet variant with Mish activations and dual heads for per-class logits and gradient flows."""

    def __init__(self, params: MediarVParams) -> None:
        # Initialize the MAnet base with the encoder/decoder parameters
        super().__init__(**params.asdict())
        self.num_classes = params.out_classes
        in_ch: int = params.decoder_channels[-1]

        # Remove the default segmentation head; this architecture uses custom heads instead
        self.segmentation_head = None

        # Replace every ReLU in the encoder and decoder with Mish
        _convert_activations(self.encoder, nn.ReLU, lambda: nn.Mish(inplace=True))
        _convert_activations(self.decoder, nn.ReLU, lambda: nn.Mish(inplace=True))

        self.prefer_gn: bool = params.prefer_gn
        self.upsample: int = params.upsample
        self.zero_init_heads: bool = params.zero_init_heads

        # Custom segmentation heads for the two tasks
        # Per-class logits head (C channels)
        self.logits_head = DeepSegmentationHead(
            in_channels=in_ch,
            out_channels=self.num_classes,
            kernel_size=3,
            activation=None,  # keep logits raw; pair with a with-logits loss
            upsampling=self.upsample,
            prefer_gn=self.prefer_gn,
        )
        # Per-class flow head (2*C channels: one 2D flow field per class)
        flow_out = 2 * self.num_classes
        self.gradflow_head = DeepSegmentationHead(
            in_channels=in_ch,
            out_channels=flow_out,
            kernel_size=3,
            activation=None,  # apply tanh explicitly downstream if desired
            upsampling=self.upsample,
            prefer_gn=self.prefer_gn,
        )

        if self.zero_init_heads:
            zero_init_last_conv(self.logits_head)
            zero_init_last_conv(self.gradflow_head)
    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x: Input tensor of shape (B, in_channels, H, W).

        Returns:
            Dict with:
                - "flow": (B, 2*C, H, W) per-class flow field;
                - "logits": (B, C, H, W) per-class logits.
        """
        # Ensure the input shape is compatible with the encoder strides
        self.check_input_shape(x)

        # Encode the input, then decode the multi-scale features
        features = self.encoder(x)
        decoder_output = self.decoder(features)

        # Predict per-class logits and gradient flows from the shared decoder output
        logits = self.logits_head(decoder_output)  # (B, C, H, W)
        flow = self.gradflow_head(decoder_output)  # (B, 2*C, H, W)

        return {"flow": flow, "logits": logits}


class DeepSegmentationHead(nn.Sequential):
    """
    A robust segmentation head block:

        Conv(bias=False) -> Norm -> Mish -> Conv -> (Upsample) -> (Activation)

    Notes:
        * The first conv uses bias=False because a normalization layer follows it.
        * GroupNorm is preferred for small batch sizes; BatchNorm2d is used when prefer_gn=False.
        * The intermediate width is clamped to a minimum to avoid overly narrow bottlenecks.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        activation: str | None = None,
        upsampling: int = 1,
        prefer_gn: bool = True,
        min_mid: int = 8,
        reduce_ratio: float = 0.5,
    ) -> None:
        mid = compute_mid(in_channels, r=reduce_ratio, min_mid=min_mid)
        norm_layer = make_norm(mid, prefer_gn=prefer_gn)
        layers: list[nn.Module] = [
            nn.Conv2d(
                in_channels,
                mid,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=False,
            ),
            norm_layer,
            nn.Mish(inplace=True),
            nn.Conv2d(
                mid,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=True,  # the final conv keeps its bias; it may be zero-initialized
            ),
            nn.Upsample(scale_factor=upsampling, mode="bilinear", align_corners=False)
            if upsampling > 1
            else nn.Identity(),
            Activation(activation) if activation else nn.Identity(),
        ]
        super().__init__(*layers)
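

# Shape sketch: the head halves the incoming width (clamped to min_mid) before
# projecting to the requested number of channels; padding preserves H and W.
#
#   head = DeepSegmentationHead(in_channels=64, out_channels=2)  # mid = 32
#   head(torch.randn(1, 64, 128, 128)).shape  # torch.Size([1, 2, 128, 128])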


def _convert_activations(
    module: nn.Module,
    from_activation: type[nn.Module],
    to_activation: Callable[[], nn.Module],
) -> None:
    """Recursively replace activation modules of one type with freshly constructed ones."""
    for name, child in module.named_children():
        if isinstance(child, from_activation):
            setattr(module, name, to_activation())
        else:
            _convert_activations(child, from_activation, to_activation)
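

# Example: the replacement recurses into nested containers.
#
#   block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Sequential(nn.ReLU()))
#   _convert_activations(block, nn.ReLU, lambda: nn.Mish(inplace=True))
#   # both ReLU instances, including the nested one, are now nn.Mish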


def make_norm(num_channels: int, prefer_gn: bool = True) -> nn.Module:
    """
    Return a normalization layer resilient to small batch sizes.

    GroupNorm is independent of the batch dimension and thus stable when B is small.
    The divisor search always succeeds: GroupNorm(1, C), the final candidate,
    behaves like LayerNorm across channels at each spatial location.
    """
    if prefer_gn:
        for g in (32, 16, 8, 4, 2, 1):
            if num_channels % g == 0:
                return nn.GroupNorm(g, num_channels)
    return nn.BatchNorm2d(num_channels)
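

# Example: the search picks the largest listed group size dividing the channels.
#
#   make_norm(64)                   # nn.GroupNorm(32, 64)
#   make_norm(30)                   # nn.GroupNorm(2, 30)
#   make_norm(64, prefer_gn=False)  # nn.BatchNorm2d(64)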


def compute_mid(
    in_ch: int,
    r: float = 0.5,
    min_mid: int = 8,
    groups_hint: tuple[int, ...] = (32, 16, 8, 4, 2),
) -> int:
    """
    Compute the intermediate channel width for the head.

    The width is `in_ch * r`, rounded and clamped to at least `min_mid`.
    If no hinted group size divides it, the width is rounded up to a multiple
    of the smallest hint so GroupNorm can form non-trivial groups.
    """
    raw = max(min_mid, int(round(in_ch * r)))
    for g in groups_hint:
        if raw % g == 0:
            return raw
    smallest = groups_hint[-1]
    return ((raw + smallest - 1) // smallest) * smallest  # align up to the smallest hint
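

# Example: widths are halved, clamped, and aligned only when necessary.
#
#   compute_mid(64)  # 32 (already divisible by 32)
#   compute_mid(10)  # 8  (clamped to min_mid)
#   compute_mid(30)  # 16 (15 is rounded up to a multiple of 2)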


def zero_init_last_conv(module: nn.Module) -> None:
    """
    Zero-initialize the last Conv2d in a head to make its initial contribution neutral.

    This often stabilizes early training for multi-task heads.
    """
    last_conv: nn.Conv2d | None = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last_conv = m
    if last_conv is not None:
        nn.init.zeros_(last_conv.weight)
        if last_conv.bias is not None:
            nn.init.zeros_(last_conv.bias)
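

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: it assumes segmentation_models_pytorch
    # and pydantic are installed and will download the pretrained mit_b5
    # encoder weights on first use.
    params = MediarVParams(out_classes=1, zero_init_heads=True)
    model = MediarV(params).eval()

    x = torch.randn(1, params.in_channels, 512, 512)  # H and W must suit the encoder stride
    with torch.no_grad():
        out = model(x)

    # Expected shapes: logits (B, C, H, W) and flow (B, 2*C, H, W)
    print({k: tuple(v.shape) for k, v in out.items()})
    # With zero_init_heads=True both heads start neutral (all-zero outputs)
    assert float(out["logits"].abs().max()) == 0.0
    assert float(out["flow"].abs().max()) == 0.0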