from typing import Any, Callable

import torch
from pydantic import BaseModel, ConfigDict
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation
from torch import nn

__all__ = ["MediarV"]


class MediarVParams(BaseModel):
    """Configuration for :class:`MediarV`."""

    model_config = ConfigDict(frozen=True)

    encoder_name: str = "mit_b5"  # Encoder backbone
    encoder_weights: str | None = "imagenet"  # Pre-trained weights (None = random init)
    decoder_channels: list[int] = [1024, 512, 256, 128, 64]  # Decoder channel widths
    decoder_pab_channels: int = 256  # Decoder Pyramid Attention Block channels
    in_channels: int = 3  # Number of input channels
    out_classes: int = 1  # Number of output classes

    prefer_gn: bool = True  # Prefer GroupNorm (stable for small batches); use BatchNorm2d when False
    upsample: int = 1  # Upsampling factor for the heads (keep 1 if the decoder already outputs the final spatial size)
    zero_init_heads: bool = False  # Zero-initialize the final conv in each head (stabilizes early training)

    def asdict(self) -> dict[str, Any]:
        """
        Return the keyword arguments for the `MAnet` base constructor.

        Head-specific options (`prefer_gn`, `upsample`, `zero_init_heads`) are
        consumed by `MediarV` itself and are excluded here, while `out_classes`
        is mapped to the `classes` argument expected by `MAnet`.
        `encoder_weights=None` is passed through and means random initialization.

        Returns:
            dict[str, Any]: Keyword arguments for `MAnet.__init__`.
        """
        kwargs = self.model_dump(
            exclude={"prefer_gn", "upsample", "zero_init_heads", "out_classes"}
        )
        kwargs["classes"] = self.out_classes
        return kwargs
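

# Illustrative sketch of the parameter mapping (assumes the defaults above):
#
#   >>> params = MediarVParams(out_classes=3)
#   >>> kwargs = params.asdict()
#   >>> kwargs["classes"], "out_classes" in kwargs, "prefer_gn" in kwargs
#   (3, False, False)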


class MediarV(MAnet):
    """MediarV model: an MAnet with Mish activations and separate per-class logits and flow heads."""

    def __init__(self, params: MediarVParams) -> None:
        # Initialize the MAnet backbone with the provided parameters
        super().__init__(**params.asdict())

        self.num_classes = params.out_classes
        in_ch: int = params.decoder_channels[-1]

        # Remove the default segmentation head; this architecture uses its own heads
        self.segmentation_head = None

        # Replace all ReLU activations in the encoder and decoder with Mish
        _convert_activations(self.encoder, nn.ReLU, lambda: nn.Mish(inplace=True))
        _convert_activations(self.decoder, nn.ReLU, lambda: nn.Mish(inplace=True))

        self.prefer_gn: bool = params.prefer_gn
        self.upsample: int = params.upsample
        self.zero_init_heads: bool = params.zero_init_heads

        # Custom segmentation heads for the two tasks.
        # Per-class logits head (C channels)
        self.logits_head = DeepSegmentationHead(
            in_channels=in_ch,
            out_channels=self.num_classes,
            kernel_size=3,
            activation=None,  # keep logits raw; apply a loss that works on logits
            upsampling=self.upsample,
            prefer_gn=self.prefer_gn,
        )

        # Per-class flow head (2 * C channels)
        flow_out = 2 * self.num_classes
        self.gradflow_head = DeepSegmentationHead(
            in_channels=in_ch,
            out_channels=flow_out,
            kernel_size=3,
            activation=None,  # apply tanh explicitly downstream (if enabled)
            upsampling=self.upsample,
            prefer_gn=self.prefer_gn,
        )

        if self.zero_init_heads:
            zero_init_last_conv(self.logits_head)
            zero_init_last_conv(self.gradflow_head)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x: Input tensor of shape (B, in_channels, H, W).

        Returns:
            Dict with:
                - "flow": (B, 2*C, H, W) per-class flow;
                - "logits": (B, C, H, W) per-class logits.
        """
        # Ensure the input shape is compatible with the encoder
        self.check_input_shape(x)

        # Encode the input, then decode the features
        features = self.encoder(x)
        decoder_output = self.decoder(features)

        # Produce per-class logits and gradient flows
        logits = self.logits_head(decoder_output)  # (B, C, H, W)
        flow = self.gradflow_head(decoder_output)  # (B, 2*C, H, W)

        return {"flow": flow, "logits": logits}
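

# Illustrative usage sketch (assumes segmentation_models_pytorch provides the
# "mit_b5" encoder; shapes are the expected ones for a 256x256 input):
#
#   >>> params = MediarVParams(encoder_weights=None, out_classes=3)
#   >>> model = MediarV(params).eval()
#   >>> with torch.no_grad():
#   ...     out = model(torch.randn(1, 3, 256, 256))
#   >>> tuple(out["logits"].shape), tuple(out["flow"].shape)
#   ((1, 3, 256, 256), (1, 6, 256, 256))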


class DeepSegmentationHead(nn.Sequential):
    """
    A robust segmentation head block:
        Conv(bias=False) -> Norm -> Mish -> Conv -> (Upsample) -> (Activation)

    Notes:
        * The first conv uses bias=False because a normalization layer follows it.
        * GroupNorm is preferred for small batch sizes; otherwise BatchNorm2d is used.
        * The intermediate 'mid' width is clamped to a minimum value to avoid an
          overly narrow bottleneck.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        activation: str | None = None,
        upsampling: int = 1,
        prefer_gn: bool = True,
        min_mid: int = 8,
        reduce_ratio: float = 0.5,
    ) -> None:
        mid = compute_mid(in_channels, r=reduce_ratio, min_mid=min_mid)
        norm_layer = make_norm(mid, prefer_gn=prefer_gn)

        layers: list[nn.Module] = [
            nn.Conv2d(
                in_channels,
                mid,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=False,
            ),
            norm_layer,
            nn.Mish(inplace=True),
            nn.Conv2d(
                mid,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=True,  # the final conv keeps its bias; it may be zero-initialized
            ),
            nn.Upsample(scale_factor=upsampling, mode="bilinear", align_corners=False)
            if upsampling > 1
            else nn.Identity(),
            Activation(activation) if activation else nn.Identity(),
        ]
        super().__init__(*layers)
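

# Illustrative head-only sketch (upsampling=1, so the spatial size is preserved):
#
#   >>> head = DeepSegmentationHead(in_channels=64, out_channels=3)
#   >>> head(torch.randn(2, 64, 128, 128)).shape
#   torch.Size([2, 3, 128, 128])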


def _convert_activations(
    module: nn.Module,
    from_activation: type[nn.Module],
    to_activation: Callable[[], nn.Module],
) -> None:
    """Recursively replace activation layers of type `from_activation` in a module."""
    for name, child in module.named_children():
        if isinstance(child, from_activation):
            setattr(module, name, to_activation())
        else:
            _convert_activations(child, from_activation, to_activation)
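

# Illustrative sketch: replace the ReLU in a small block with Mish.
#
#   >>> block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=True))
#   >>> _convert_activations(block, nn.ReLU, lambda: nn.Mish(inplace=True))
#   >>> isinstance(block[1], nn.Mish)
#   True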


def make_norm(num_channels: int, prefer_gn: bool = True) -> nn.Module:
    """
    Return a normalization layer that is resilient to small batch sizes.
    GroupNorm is independent of the batch dimension and thus stable when B is small.
    """
    if prefer_gn:
        for g in (32, 16, 8, 4, 2):
            if num_channels % g == 0:
                return nn.GroupNorm(g, num_channels)
        # Fallback: a single group, i.e. LayerNorm across channels (per spatial location)
        return nn.GroupNorm(1, num_channels)
    return nn.BatchNorm2d(num_channels)
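

# Illustrative sketch of the group selection (representative reprs):
#
#   >>> make_norm(48)                   # 48 % 16 == 0 -> 16 groups
#   GroupNorm(16, 48, eps=1e-05, affine=True)
#   >>> make_norm(7)                    # no divisor found -> single group
#   GroupNorm(1, 7, eps=1e-05, affine=True)
#   >>> make_norm(64, prefer_gn=False)
#   BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)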


def compute_mid(
    in_ch: int,
    r: float = 0.5,
    min_mid: int = 8,
    groups_hint: tuple[int, ...] = (32, 16, 8, 4, 2),
) -> int:
    """
    Compute the intermediate channel width for the head.

    The width is `in_ch * r`, rounded and clamped to at least `min_mid`.
    `groups_hint` lists group sizes the result is typically divisible by; the
    width is returned unchanged whether or not a divisor is found.
    """
    raw = max(min_mid, int(round(in_ch * r)))
    for g in groups_hint:
        if raw % g == 0:
            return raw
    return raw
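

# Illustrative sketch of the width computation (defaults r=0.5, min_mid=8):
#
#   >>> compute_mid(64)   # round(64 * 0.5) = 32
#   32
#   >>> compute_mid(10)   # round(10 * 0.5) = 5, clamped up to min_mid
#   8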


def zero_init_last_conv(module: nn.Module) -> None:
    """
    Zero-initialize the last Conv2d in a head so that its initial contribution is neutral.
    This often stabilizes early training for multi-task heads.
    """
    last_conv: nn.Conv2d | None = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last_conv = m
    if last_conv is not None:
        nn.init.zeros_(last_conv.weight)
        if last_conv.bias is not None:
            nn.init.zeros_(last_conv.bias)
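

# Illustrative sketch: after zeroing the last conv (weight and bias), the head
# initially outputs all zeros, so its contribution starts neutral.
#
#   >>> head = DeepSegmentationHead(in_channels=64, out_channels=3)
#   >>> zero_init_last_conv(head)
#   >>> head(torch.randn(1, 64, 32, 32)).abs().max().item()
#   0.0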