Compare commits

...

2 Commits

@ -1,16 +1,19 @@
# Cell Segmentator # Mediar-V
--- ---
## Overview ## Overview
Mediar-V is a multi-head extension of [MEDIAR-Former](https://github.com/Lee-Gihun/MEDIAR) for instance segmentation of overlapping cell structures in microscopy images (e.g., cytoplasm and nucleus of the same cell). Classical flow-field based methods typically support only a single object class and therefore require a separate model for each class. Mediar-V keeps a single backbone and flow representation, but adds dedicated segmentation heads for every target class and trains them jointly, which yields richer supervision and feature sharing between object types. On phase-contrast Glioma C6 and histological CytoNuke datasets with multiple cell classes, this unified architecture outperforms a set of separate single-class models by ≈3 p.p. F1 and ≈4 p.p. AP on average, while using a shared post-processing pipeline that reduces memory usage and FLOPs.
This repository provides two main scripts to configure and run a cell segmentation workflow: This repository provides two main scripts to configure and run a cell segmentation workflow:
* **generate\_config.py**: Interactive script to create JSON configuration files for training or prediction. * **generate_config.py**: Interactive script to create JSON configuration files for training or prediction.
* **main.py**: Entry point to train, test, or predict using the generated configuration. * **main.py**: Entry point to train, test, or predict using the generated configuration.
--- ---
## Installation ## Installation
0. **Install uv**: 0. **Install uv**:
@ -217,6 +220,20 @@ python main.py -c config/templates/predict/YourConfig.json -m predict
> Unlike prediction testing, it is not necessary that the specified test directory contains a folder with true masks. > Unlike prediction testing, it is not necessary that the specified test directory contains a folder with true masks.
### Run multiple configs from Python
You are not limited to CLI arguments: if you have many configs, you can specify them directly in `main.py` (or another script) and call `main()` in manual mode:
```python
from main import main
for cfg in [
"config/templates/train/YourConfigA.json",
"config/templates/train/YourConfigB.json",
]:
main(manual=True, config_path=cfg, mode="train")
```
--- ---
## Acknowledgments ## Acknowledgments

@ -2,13 +2,13 @@ from torch import nn
from typing import Final, Type, Any from typing import Final, Type, Any
from pydantic import BaseModel from pydantic import BaseModel
from .model_v import ModelV, ModelVParams from .mediar_v import MediarV, MediarVParams
__all__ = [ __all__ = [
"ModelRegistry", "ModelRegistry",
"ModelV", "MediarV",
"ModelVParams" "MediarVParams"
] ]
@ -17,9 +17,9 @@ class ModelRegistry:
# Single dictionary storing both model classes and parameter classes. # Single dictionary storing both model classes and parameter classes.
__MODELS: Final[dict[str, dict[str, Type[Any]]]] = { __MODELS: Final[dict[str, dict[str, Type[Any]]]] = {
"ModelV": { "MediarV": {
"class": ModelV, "class": MediarV,
"params": ModelVParams, "params": MediarVParams,
}, },
} }

@ -7,10 +7,10 @@ from segmentation_models_pytorch.base.modules import Activation
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
__all__ = ["ModelV"] __all__ = ["MediarV"]
class ModelVParams(BaseModel): class MediarVParams(BaseModel):
model_config = ConfigDict(frozen=True) model_config = ConfigDict(frozen=True)
encoder_name: str = "mit_b5" # Default encoder encoder_name: str = "mit_b5" # Default encoder
@ -26,19 +26,19 @@ class ModelVParams(BaseModel):
def asdict(self) -> dict[str, Any]: def asdict(self) -> dict[str, Any]:
""" """
Returns a dictionary of valid parameters for `nn.ModelV`. Returns a dictionary of valid parameters for `nn.MediarV`.
Returns: Returns:
dict(str, Any): Dictionary of parameters for nn.ModelV. dict(str, Any): Dictionary of parameters for nn.MediarV.
""" """
loss_kwargs = self.model_dump() loss_kwargs = self.model_dump()
return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values return {k: v for k, v in loss_kwargs.items() if v is not None} # Remove None values
class ModelV(MAnet): class MediarV(MAnet):
"""ModelV model""" """MediarV model"""
def __init__(self, params: ModelVParams) -> None: def __init__(self, params: MediarVParams) -> None:
# Initialize the MAnet model with provided parameters # Initialize the MAnet model with provided parameters
super().__init__(**params.asdict()) super().__init__(**params.asdict())
Loading…
Cancel
Save