model outputs converted to dict

master
laynholt 3 weeks ago
parent cec2fcaf3f
commit 4bdd1ee872

@ -68,34 +68,34 @@ class BCE_MSE_Loss(BaseLoss):
self.loss_bce_metric = CumulativeAverage()
self.loss_mse_metric = CumulativeAverage()
def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def forward(self, outputs: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor: # type: ignore
"""
Computes the loss between true labels and prediction outputs.
Args:
outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W).
outputs (dict[str, torch.Tensor]): Model predictions:
- "flow": (B, 2*C, H, W) per-class flow (flow_vectors[0] is Y flow, flow_vectors[1] is X flow)
- "logits": (B, C, H, W) per-class logits.
target (dict[str, torch.Tensor]): Ground truth labels:
- "flow": (B, 2*C, H, W) per-class flow (flow_vectors[0] is Y flow, flow_vectors[1] is X flow)
- "labels": (B, C, H, W) per-class labels.
Returns:
torch.Tensor: The total loss value.
"""
# Ensure target is on the same device as outputs
assert (
target.device == outputs.device
), (
"Target tensor must be moved to the same device as outputs "
"before calling forward()."
)
logits: torch.Tensor = outputs["logits"] # (B,C,H,W)
flow_pred: torch.Tensor = outputs["flow"] # (B,2*C,H,W)
labels: torch.Tensor = target["labels"] # (B,C,H,W)
flow_tgt: torch.Tensor = target["flow"] # (B,2*C,H,W)
labels = labels.to(device=logits.device)
flow_tgt = flow_tgt.to(device=flow_pred.device)
# Cell Recognition Loss
cellprob_loss = self.bce_loss(
outputs[:, -self.num_classes:], (target[:, -self.num_classes:] > 0).float()
)
cellprob_loss = self.bce_loss(logits, (labels > 0).float())
# Cell Distinction Loss
gradflow_loss = 0.5 * self.mse_loss(
outputs[:, :2 * self.num_classes], 5.0 * target[:, :2 * self.num_classes]
)
gradflow_loss = 0.5 * self.mse_loss(flow_pred, 5.0 * flow_tgt)
# Total loss
total_loss = cellprob_loss + gradflow_loss

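For reference, a minimal sketch of exercising the new dict-based loss interface (sizes and the criterion construction are illustrative assumptions, not part of this commit):
import torch
# Hypothetical sizes: B=2 images, C=3 classes, 64x64 tiles.
outputs = {
    "flow": torch.randn(2, 6, 64, 64),    # (B, 2*C, H, W) predicted flows
    "logits": torch.randn(2, 3, 64, 64),  # (B, C, H, W) raw per-class logits
}
target = {
    "flow": torch.randn(2, 6, 64, 64),    # ground-truth flows
    "labels": torch.randint(0, 5, (2, 3, 64, 64)).float(),  # instance labels per class
}
# criterion = BCE_MSE_Loss(...)      # constructor signature not shown in this diff
# total = criterion(outputs, target) # BCE on logits + 0.5 * MSE(flow, 5 * flow_tgt)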
@ -1,7 +1,7 @@
import torch
from torch import nn
from typing import Any
from typing import Any, Callable
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation
@ -20,6 +20,10 @@ class ModelVParams(BaseModel):
in_channels: int = 3 # Number of input channels
out_classes: int = 1 # Number of output classes
prefer_gn: bool = True # Prefer GroupNorm for small batches; falls back to BatchNorm if prefer_gn=False
upsample: int = 1 # Upsampling factor for heads (keep 1 if decoder outputs final spatial size)
zero_init_heads: bool = False # Apply zero init to final conv in heads (stabilizes early training)
def asdict(self) -> dict[str, Any]:
"""
Returns a dictionary of valid parameters for `nn.ModelV`.
@ -39,25 +43,58 @@ class ModelV(MAnet):
super().__init__(**params.asdict())
self.num_classes = params.out_classes
in_ch: int = params.decoder_channels[-1]
# Remove the default segmentation head as it's not used in this architecture
self.segmentation_head = None
# Modify all activation functions in the encoder and decoder from ReLU to Mish
_convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True))
_convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True))
_convert_activations(self.encoder, nn.ReLU, lambda: nn.Mish(inplace=True))
_convert_activations(self.decoder, nn.ReLU, lambda: nn.Mish(inplace=True))
self.prefer_gn: bool = params.prefer_gn
self.upsample: int = params.upsample
self.zero_init_heads: bool = params.zero_init_heads
# Add custom segmentation heads for different segmentation tasks
self.cellprob_head = DeepSegmentationHead(
in_channels=params.decoder_channels[-1], out_channels=params.out_classes
# Per-class logits head (C channels)
self.logits_head = DeepSegmentationHead(
in_channels=in_ch,
out_channels=self.num_classes,
kernel_size=3,
activation=None, # keep logits raw; apply loss with logits
upsampling=self.upsample,
prefer_gn=self.prefer_gn,
)
# Flow head per-class (2*C)
flow_out = 2 * self.num_classes
self.gradflow_head = DeepSegmentationHead(
in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes
in_channels=in_ch,
out_channels=flow_out,
kernel_size=3,
activation=None, # we'll apply tanh explicitly (if enabled)
upsampling=self.upsample,
prefer_gn=self.prefer_gn,
)
if self.zero_init_heads:
zero_init_last_conv(self.logits_head)
zero_init_last_conv(self.gradflow_head)
def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Forward pass through the network.
Args:
x: Input tensor of shape (B, in_channels, H, W).
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the network"""
Returns:
Dict with:
- "flow": (B, 2*C, H, W) per-class flow;
- "logits": (B, C, H, W) per-class logits.
"""
# Ensure the input shape is correct
self.check_input_shape(x)
@ -66,17 +103,21 @@ class ModelV(MAnet):
decoder_output = self.decoder(features)
# Generate masks for cell probability and gradient flows
cellprob_mask = self.cellprob_head(decoder_output)
gradflow_mask = self.gradflow_head(decoder_output)
logits = self.logits_head(decoder_output) # (B, C, H, W)
flow = self.gradflow_head(decoder_output) # (B, 2*C, H, W)
# Concatenate the masks for output
masks = torch.cat((gradflow_mask, cellprob_mask), dim=1)
return masks
return {"flow": flow, "logits": logits}
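A sketch of consuming the new forward output, assuming the default ModelVParams fields suffice (illustrative, not part of this commit):
import torch
params = ModelVParams(in_channels=3, out_classes=3)  # other fields keep their defaults
model = ModelV(params)
out = model(torch.randn(1, 3, 256, 256))
flow, logits = out["flow"], out["logits"]
assert flow.shape[1] == 2 * logits.shape[1]  # 2*C flow channels per C classes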
class DeepSegmentationHead(nn.Sequential):
"""Custom segmentation head for generating specific masks"""
"""
A robust segmentation head block:
Conv(bias=False) -> Norm -> Mish -> Conv -> (Upsample) -> (Activation?)
Notes:
* The first conv uses bias=False since normalization follows it.
* GroupNorm is preferred for small batch sizes; BatchNorm2d is used when prefer_gn=False.
* The 'mid' width is clamped to a minimum value to avoid overly narrow bottlenecks.
"""
def __init__(
self,
@ -85,35 +126,88 @@ class DeepSegmentationHead(nn.Sequential):
kernel_size: int = 3,
activation: str | None = None,
upsampling: int = 1,
prefer_gn: bool = True,
min_mid: int = 8,
reduce_ratio: float = 0.5,
) -> None:
# Define a sequence of layers for the segmentation head
mid = compute_mid(in_channels, r=reduce_ratio, min_mid=min_mid)
norm_layer = make_norm(mid, prefer_gn=prefer_gn)
layers: list[nn.Module] = [
nn.Conv2d(
in_channels,
in_channels // 2,
mid,
kernel_size=kernel_size,
padding=kernel_size // 2,
bias=False,
),
norm_layer,
nn.Mish(inplace=True),
nn.BatchNorm2d(in_channels // 2),
nn.Conv2d(
in_channels // 2,
mid,
out_channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
bias=True, # final conv may keep bias; can be zero-initialized
),
nn.UpsamplingBilinear2d(scale_factor=upsampling)
if upsampling > 1
else nn.Identity(),
nn.Upsample(scale_factor=upsampling, mode="bilinear", align_corners=False)
if upsampling > 1 else nn.Identity(),
Activation(activation) if activation else nn.Identity(),
]
super().__init__(*layers)
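A quick shape check for the head (numbers are illustrative):
import torch
head = DeepSegmentationHead(in_channels=64, out_channels=3, upsampling=2)
y = head(torch.randn(1, 64, 32, 32))
# mid = compute_mid(64) = 32 with GroupNorm(32, 32); bilinear x2 upsampling
assert y.shape == (1, 3, 64, 64)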
def _convert_activations(module: nn.Module, from_activation: type, to_activation: nn.Module) -> None:
def _convert_activations(module: nn.Module, from_activation: type, to_activation: Callable[[], nn.Module]) -> None:
"""Recursively convert activation functions in a module"""
for name, child in module.named_children():
if isinstance(child, from_activation):
setattr(module, name, to_activation)
setattr(module, name, to_activation())
else:
_convert_activations(child, from_activation, to_activation)
def make_norm(num_channels: int, prefer_gn: bool = True) -> nn.Module:
"""
Return a normalization layer resilient to small batch sizes.
GroupNorm is independent of the batch dimension and thus stable when B is small.
"""
if prefer_gn:
for g in (32, 16, 8, 4, 2, 1):
if num_channels % g == 0:
return nn.GroupNorm(g, num_channels)
# Fallback: 1 group ≈ LayerNorm over all channels and spatial positions
return nn.GroupNorm(1, num_channels)
else:
return nn.BatchNorm2d(num_channels)
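The group count is the largest candidate that divides the channel width, e.g.:
make_norm(64)                   # GroupNorm(32, 64)
make_norm(48)                   # GroupNorm(16, 48); 32 does not divide 48
make_norm(64, prefer_gn=False)  # BatchNorm2d(64)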
def compute_mid(
in_ch: int,
r: float = 0.5,
min_mid: int = 8,
groups_hint: tuple[int, ...] = (32, 16, 8, 4, 2)
) -> int:
"""
Compute intermediate channel width for the head.
Ensures a minimum width and (optionally) tries to align to a group size.
"""
raw = max(min_mid, int(round(in_ch * r)))
for g in groups_hint:
if raw % g == 0:
return raw
# No hint divides raw; return it unchanged (make_norm still finds a valid group count)
return raw
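For example (values follow the defaults above):
compute_mid(64)  # 32: half of 64, divisible by the 32-group hint
compute_mid(12)  # 8: clamped up to min_mid=8, divisible by 8
compute_mid(18)  # 9: no hint divides 9, so it is returned unchanged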
def zero_init_last_conv(module: nn.Module) -> None:
"""
Zero-initialize the last Conv2d in a head to make its initial contribution neutral.
This often stabilizes early training for multi-task heads.
"""
last_conv: nn.Conv2d | None = None
for m in module.modules():
if isinstance(m, nn.Conv2d):
last_conv = m
if last_conv is not None:
nn.init.zeros_(last_conv.weight)
if last_conv.bias is not None:
nn.init.zeros_(last_conv.bias)
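A minimal check that zero-init makes the head's initial contribution neutral (illustrative):
import torch
head = DeepSegmentationHead(in_channels=64, out_channels=3)
zero_init_last_conv(head)
y = head(torch.randn(1, 64, 32, 32))
assert torch.count_nonzero(y) == 0  # final conv weight and bias are zero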

@ -973,7 +973,7 @@ class CellSegmentator:
# Compute loss for this batch
batch_loss = self._criterion(
raw_output,
torch.from_numpy(flow_targets).to(device=raw_output.device)
{k: torch.from_numpy(v) for k, v in flow_targets.items()}
)
# Post-process and compute F1 during validation and testing
@ -1027,7 +1027,7 @@ class CellSegmentator:
else:
epoch_metrics = {}
# Include F1 and mAP for validation and testing
# Include F1 and AP for validation and testing
if mode in ("valid", "test"):
# Concatenating by batch: shape (num_batches*B, C)
tp_array = np.vstack(all_tp)
@ -1046,13 +1046,13 @@ class CellSegmentator:
epoch_metrics[f"{mode}_f1_score_macro"] = self.__compute_f1_metric(
tp_array, fp_array, fn_array, reduction="macro"
)
epoch_metrics[f"{mode}_mAP_micro"] = self.__compute_average_precision_metric(
epoch_metrics[f"{mode}_AP_micro"] = self.__compute_average_precision_metric(
tp_array, fp_array, fn_array, reduction="micro"
)
epoch_metrics[f"{mode}_mAP_macro"] = self.__compute_average_precision_metric(
epoch_metrics[f"{mode}_AP_macro"] = self.__compute_average_precision_metric(
tp_array, fp_array, fn_array, reduction="macro"
)
epoch_metrics[f"{mode}_mAP_pc"] = self.__compute_average_precision_metric(
epoch_metrics[f"{mode}_AP_pc"] = self.__compute_average_precision_metric(
tp_array, fp_array, fn_array, reduction="per_class"
)
@ -1063,7 +1063,7 @@ class CellSegmentator:
self,
inputs: torch.Tensor,
mode: Literal["train", "valid", "test", "predict"] = "train"
) -> torch.Tensor:
) -> dict[str, torch.Tensor]:
"""
Perform model inference for different stages.
@ -1072,7 +1072,7 @@ class CellSegmentator:
mode (Literal[...]): One of "train", "valid", "test", "predict".
Returns:
torch.Tensor: Model outputs tensor.
dict[str, torch.Tensor]: Model outputs keyed by "flow" and "logits".
"""
if mode != "train":
# Use sliding window inference for non-training phases
@ -1093,14 +1093,16 @@ class CellSegmentator:
def __post_process_predictions(
self,
raw_outputs: torch.Tensor,
raw_outputs: dict[str, torch.Tensor],
ground_truth: torch.Tensor | None = None
) -> tuple[np.ndarray, np.ndarray | None]:
"""
Post-process raw network outputs to extract instance segmentation masks.
Args:
raw_outputs (torch.Tensor): Raw model outputs of shape (B, C, H, W).
raw_outputs (dict[str, torch.Tensor]): Raw model outputs:
- "flow": (B, 2*C, H, W) per-class flow;
- "logits": (B, C, H, W) per-class logits.
ground_truth (torch.Tensor | None): Ground truth masks of shape (B, C, H, W).
Returns:
@ -1109,11 +1111,9 @@ class CellSegmentator:
- labels_np: Converted ground truth of shape (B, C, H, W) or None if
ground_truth was not provided.
"""
# Move outputs to CPU and convert to numpy
outputs_np = raw_outputs.cpu().numpy()
# Split channels: gradient flows then class logits
gradflow = outputs_np[:, :2 * self._model.num_classes]
logits = outputs_np[:, -self._model.num_classes :]
gradflow = raw_outputs["flow"].cpu().numpy()
logits = raw_outputs["logits"].cpu().numpy()
# Apply sigmoid to logits to get probabilities
probabilities = self.__sigmoid(logits)
@ -1271,12 +1271,11 @@ class CellSegmentator:
if reduction == "macro":
f1_per_class = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
_, _, f1_c = compute_f1_score(
_, _, f1_per_class[c] = compute_f1_score(
tp_per_class[c],
fp_per_class[c],
fn_per_class[c]
)
f1_per_class[c] = f1_c
return float(f1_per_class.mean())
# 6) Weighted: class-wise F1 weighted by support = TP + FN
@ -1381,7 +1380,7 @@ class CellSegmentator:
fp_per_class = false_positives.sum(axis=0).astype(int)
fn_per_class = false_negatives.sum(axis=0).astype(int)
# 4) Per-class: compute F1 for each class and return vector
# 4) Per-class: compute AP for each class and return vector
if reduction == "per_class":
ap_per_class = np.zeros(num_classes, dtype=float)
for c in range(num_classes):
@ -1752,7 +1751,7 @@ class CellSegmentator:
def __compute_flows_from_masks(
self,
true_masks: Tensor
) -> np.ndarray:
) -> dict[str, np.ndarray]:
"""
Convert segmentation masks to flow fields for training.
@ -1760,8 +1759,9 @@ class CellSegmentator:
true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks.
Returns:
numpy array of concatenated [flow_vectors, renumbered_true_masks] per image.
renumbered_true_masks is labels, flow_vectors[0] is Y flow, flow_vectors[1] is X flow.
Dict with:
- "flow": (B, 2*C, H, W) per-class flow (flow_vectors[0] is Y flow, flow_vectors[1] is X flow)
- "labels": (B, C, H, W) per-class labels.
"""
# Move to CPU numpy
_true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16)
@ -1783,7 +1783,7 @@ class CellSegmentator:
flow_vectors = np.stack([self.__compute_flow_from_mask(_true_masks[i])
for i in range(batch_size)])
return np.concatenate((flow_vectors, _true_masks), axis=1).astype(np.float32)
return {"flow": flow_vectors.astype(np.float32), "labels": _true_masks.astype(np.float32)}
def __compute_flow_from_mask(
@ -2245,9 +2245,10 @@ class CellSegmentator:
# Follow the flow vectors to generate coordinate mappings
flow_coordinates = self.__follow_flows(
flow_field=channel_flow_vectors * channel_mask / 5.0,
flow_field=channel_flow_vectors * channel_mask,
initial_coords=nonzero_coords,
num_iters=num_iters
num_iters=num_iters,
step_scale=0.2  # replaces the former /5.0 scaling of the flow field
)
if not torch.is_tensor(flow_coordinates):
@ -2292,7 +2293,8 @@ class CellSegmentator:
self,
flow_field: np.ndarray,
initial_coords: np.ndarray,
num_iters: int = 200
num_iters: int = 200,
step_scale: float = 1.0
) -> np.ndarray | torch.Tensor:
"""
Trace pixel positions through a flow field via iterative interpolation.
@ -2301,6 +2303,7 @@ class CellSegmentator:
flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors.
initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions.
num_iters (int): Number of integration steps.
step_scale (float): Step length in "pixels per iteration".
Returns:
(np.ndarray | torch.Tensor): Final (y, x) positions of each point.
@ -2311,50 +2314,40 @@ class CellSegmentator:
# Choose GPU/MPS path if available
if self._device.type in ("cuda", "mps"):
# Load initial positions and flow into tensors (flip order for grid_sample)
# Prepare point tensor: shape [1, 1, num_points, 2]
pts = torch.zeros((1, 1, initial_coords.shape[1], dims),
dtype=torch.float32, device=self._device)
pts = torch.empty((1, 1, initial_coords.shape[1], dims), dtype=torch.float32, device=self._device)
pts[..., 0] = torch.from_numpy(initial_coords[1]).to(self._device, torch.float32) # x
pts[..., 1] = torch.from_numpy(initial_coords[0]).to(self._device, torch.float32) # y
# Prepare flow volume: shape [1, 2, height, width]
flow_vol = torch.zeros((1, dims, height, width),
dtype=torch.float32, device=self._device)
# Load initial positions and flow into tensors (flip order for grid_sample)
# dim 0 = x
# dim 1 = y
for i in range(dims):
pts[0, 0, :, dims - i - 1] = (
torch.from_numpy(initial_coords[i])
.to(self._device, torch.float32)
)
flow_vol[0, dims - i - 1] = (
torch.from_numpy(flow_field[i])
.to(self._device, torch.float32)
)
flow_vol = torch.empty((1, dims, height, width), dtype=torch.float32, device=self._device)
flow_vol[:, 0] = torch.from_numpy(flow_field[1]).to(self._device, torch.float32) # x-component
flow_vol[:, 1] = torch.from_numpy(flow_field[0]).to(self._device, torch.float32) # y-component
# Prepare normalization factors for x and y (max index)
max_indices = torch.tensor([width - 1, height - 1],
dtype=torch.float32, device=self._device)
max_indices = torch.tensor([width - 1, height - 1], dtype=torch.float32, device=self._device) # (2,)
# Reshape for broadcasting to point tensor dims
max_idx_pt = max_indices.view(1, 1, 1, dims)
# Reshape for broadcasting to flow volume dims
max_idx_flow = max_indices.view(1, dims, 1, 1)
# Normalize flow values to [-1, 1] range
flow_vol = (flow_vol * 2) / max_idx_flow
flow_vol = (flow_vol * step_scale * 2.0) / max_idx_flow
# Normalize points to [-1, 1]
pts = (pts / max_idx_pt) * 2 - 1
pts = (pts / max_idx_pt) * 2.0 - 1.0
# Iterate: sample flow and update points
for _ in range(num_iters):
sampled = torch.nn.functional.grid_sample(
flow_vol, pts, align_corners=False
flow_vol, pts, align_corners=True
)
# Update each coordinate and clamp to valid range
for i in range(dims):
pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0)
# Denormalize back to original pixel coordinates
pts = (pts + 1) * 0.5 * max_idx_pt
pts = (pts + 1.0) * 0.5 * max_idx_pt
# Swap channels back to (y, x) and flatten
final_pts = pts[..., [1, 0]].squeeze()
# Convert from (num_points, 2) to (2, num_points)
@ -2368,8 +2361,8 @@ class CellSegmentator:
# Interpolate flow at current positions
self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta)
# Update positions and clamp to image bounds
current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1)
current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1)
current_pos[0] = np.clip(current_pos[0] + step_scale * temp_delta[0], 0, height - 1)
current_pos[1] = np.clip(current_pos[1] + step_scale * temp_delta[1], 0, width - 1)
return current_pos
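For intuition on the align_corners=True mapping used in the GPU path above (a small worked example, not part of the code):
# Pixel x in [0, W-1] maps to u = x / (W-1) * 2 - 1, so a displacement of
# f pixels corresponds to du = 2 * f / (W-1) in normalized units.
W = 5
x, f = 2.0, 1.0           # centre pixel of 0..4, move one pixel right
u = x / (W - 1) * 2 - 1   # 0.0
du = 2 * f / (W - 1)      # 0.5
assert u + du == 0.5      # i.e. pixel 3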

@ -13,42 +13,57 @@ from core.data import (
from core.segmentator import CellSegmentator
def main() -> None:
parser = argparse.ArgumentParser(
description="Train or predict cell segmentator with specified config file."
)
parser.add_argument(
'-c', '--config',
type=str,
help='Path to the JSON config file'
)
parser.add_argument(
'-m', '--mode',
choices=['train', 'test', 'predict'],
default='train',
help='Run mode: train, test or predict'
)
parser.add_argument(
'--no-save-masks',
action='store_false',
dest='save_masks',
help='If set, do NOT save predicted masks (saving is enabled by default)'
)
parser.add_argument(
'--only-masks',
action='store_true',
help=('If set and save-masks set, save only the raw predicted'
' masks without additional visualizations')
)
def main(
manual: bool = False,
config_path: str | None = None,
mode: str | None = None,
save_masks: bool = True,
only_masks: bool = False
) -> None:
if not manual:
parser = argparse.ArgumentParser(
description="Train or predict cell segmentator with specified config file."
)
parser.add_argument(
'-c', '--config',
type=str,
help='Path to the JSON config file'
)
parser.add_argument(
'-m', '--mode',
choices=['train', 'test', 'predict'],
default='train',
help='Run mode: train, test or predict'
)
parser.add_argument(
'--no-save-masks',
action='store_false',
dest='save_masks',
help='If set, do NOT save predicted masks (saving is enabled by default)'
)
parser.add_argument(
'--only-masks',
action='store_true',
help=('If set (and mask saving is enabled), save only the raw predicted'
' masks without additional visualizations')
)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(0)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(0)
args = parser.parse_args()
args = parser.parse_args()
mode = args.mode
config_path = args.config
mode = args.mode
config_path = args.config
save_masks = args.save_masks
only_masks = args.only_masks
else:
if mode is None or config_path is None:
raise ValueError("In manual mode, you must specify the path to the config and mode!")
config = Config.load_json(config_path)
if mode == 'train' and not config.dataset_config.is_training:
@ -62,7 +77,7 @@ def main() -> None:
if config.wandb_config.use_wandb:
# Initialize W&B
wandb.init(config=config.asdict(), **config.wandb_config.asdict())
wandb.init(config=config.asdict(), reinit="finish_previous", **config.wandb_config.asdict())
# How many batches to wait before logging training status
wandb.config.log_interval = 10
@ -84,9 +99,9 @@ def main() -> None:
wandb.watch(segmentator._model, log="all", log_graph=True)
try:
segmentator.run(save_results=args.save_masks, only_masks=args.only_masks)
except Exception:
raise
segmentator.run(save_results=save_masks, only_masks=only_masks)
except Exception as e:
raise e
finally:
if config.dataset_config.is_training:
# Prepare saving path
@ -101,8 +116,52 @@ def main() -> None:
if config.wandb_config.use_wandb:
wandb.save(saving_path)
wandb.finish()
if __name__ == "__main__":
main()
train_configs = [
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cA.json",
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cB.json",
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cSoma.json",
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cAB.json",
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cABSoma.json",
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cytoCell.json",
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cytoNuc.json",
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cytoCellNuc.json",
]
predict_configs = [
"/workspace/model-v/config/templates/predict/ModelV_cA.json",
"/workspace/model-v/config/templates/predict/ModelV_cB.json",
"/workspace/model-v/config/templates/predict/ModelV_cSoma.json",
"/workspace/model-v/config/templates/predict/ModelV_cAB.json",
"/workspace/model-v/config/templates/predict/ModelV_cABSoma.json",
"/workspace/model-v/config/templates/predict/ModelV_cytoCell.json",
"/workspace/model-v/config/templates/predict/ModelV_cytoNuc.json",
"/workspace/model-v/config/templates/predict/ModelV_cytoCellNuc.json",
]
for config in train_configs:
print(f"Hande config {config}")
main(
manual=True,
config_path=config,
mode="train",
save_masks=True,
only_masks=False
)
for config in predict_configs:
print(f"Hande config {config}")
main(
manual=True,
config_path=config,
mode="predict",
save_masks=True,
only_masks=False
)

@ -17,6 +17,7 @@ dependencies = [
"pillow>=11.2.1",
"pydantic>=2.11.4",
"scikit-image>=0.25.2",
"scikit-learn>=1.7.2",
"scipy>=1.15.3",
"segmentation-models-pytorch>=0.5.0",
"tabulate>=0.9.0",

@ -411,6 +411,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 },
]
[[package]]
name = "joblib"
version = "1.5.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396 },
]
[[package]]
name = "kiwisolver"
version = "1.4.8"
@ -620,6 +629,7 @@ dependencies = [
{ name = "pillow" },
{ name = "pydantic" },
{ name = "scikit-image" },
{ name = "scikit-learn" },
{ name = "scipy" },
{ name = "segmentation-models-pytorch" },
{ name = "tabulate" },
@ -645,6 +655,7 @@ requires-dist = [
{ name = "pillow", specifier = ">=11.2.1" },
{ name = "pydantic", specifier = ">=2.11.4" },
{ name = "scikit-image", specifier = ">=0.25.2" },
{ name = "scikit-learn", specifier = ">=1.7.2" },
{ name = "scipy", specifier = ">=1.15.3" },
{ name = "segmentation-models-pytorch", specifier = ">=0.5.0" },
{ name = "tabulate", specifier = ">=0.9.0" },
@ -1161,6 +1172,45 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/cc/75e9f17e3670b5ed93c32456fda823333c6279b144cd93e2c03aa06aa472/scikit_image-0.25.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:330d061bd107d12f8d68f1d611ae27b3b813b8cdb0300a71d07b1379178dd4cd", size = 13862801 },
]
[[package]]
name = "scikit-learn"
version = "1.7.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "joblib" },
{ name = "numpy" },
{ name = "scipy" },
{ name = "threadpoolctl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/98/c2/a7855e41c9d285dfe86dc50b250978105dce513d6e459ea66a6aeb0e1e0c/scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda", size = 7193136 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/43/83/564e141eef908a5863a54da8ca342a137f45a0bfb71d1d79704c9894c9d1/scikit_learn-1.7.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7509693451651cd7361d30ce4e86a1347493554f172b1c72a39300fa2aea79e", size = 9331967 },
{ url = "https://files.pythonhosted.org/packages/18/d6/ba863a4171ac9d7314c4d3fc251f015704a2caeee41ced89f321c049ed83/scikit_learn-1.7.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:0486c8f827c2e7b64837c731c8feff72c0bd2b998067a8a9cbc10643c31f0fe1", size = 8648645 },
{ url = "https://files.pythonhosted.org/packages/ef/0e/97dbca66347b8cf0ea8b529e6bb9367e337ba2e8be0ef5c1a545232abfde/scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89877e19a80c7b11a2891a27c21c4894fb18e2c2e077815bcade10d34287b20d", size = 9715424 },
{ url = "https://files.pythonhosted.org/packages/f7/32/1f3b22e3207e1d2c883a7e09abb956362e7d1bd2f14458c7de258a26ac15/scikit_learn-1.7.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8da8bf89d4d79aaec192d2bda62f9b56ae4e5b4ef93b6a56b5de4977e375c1f1", size = 9509234 },
{ url = "https://files.pythonhosted.org/packages/9f/71/34ddbd21f1da67c7a768146968b4d0220ee6831e4bcbad3e03dd3eae88b6/scikit_learn-1.7.2-cp311-cp311-win_amd64.whl", hash = "sha256:9b7ed8d58725030568523e937c43e56bc01cadb478fc43c042a9aca1dacb3ba1", size = 8894244 },
{ url = "https://files.pythonhosted.org/packages/a7/aa/3996e2196075689afb9fce0410ebdb4a09099d7964d061d7213700204409/scikit_learn-1.7.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8d91a97fa2b706943822398ab943cde71858a50245e31bc71dba62aab1d60a96", size = 9259818 },
{ url = "https://files.pythonhosted.org/packages/43/5d/779320063e88af9c4a7c2cf463ff11c21ac9c8bd730c4a294b0000b666c9/scikit_learn-1.7.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:acbc0f5fd2edd3432a22c69bed78e837c70cf896cd7993d71d51ba6708507476", size = 8636997 },
{ url = "https://files.pythonhosted.org/packages/5c/d0/0c577d9325b05594fdd33aa970bf53fb673f051a45496842caee13cfd7fe/scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e5bf3d930aee75a65478df91ac1225ff89cd28e9ac7bd1196853a9229b6adb0b", size = 9478381 },
{ url = "https://files.pythonhosted.org/packages/82/70/8bf44b933837ba8494ca0fc9a9ab60f1c13b062ad0197f60a56e2fc4c43e/scikit_learn-1.7.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4d6e9deed1a47aca9fe2f267ab8e8fe82ee20b4526b2c0cd9e135cea10feb44", size = 9300296 },
{ url = "https://files.pythonhosted.org/packages/c6/99/ed35197a158f1fdc2fe7c3680e9c70d0128f662e1fee4ed495f4b5e13db0/scikit_learn-1.7.2-cp312-cp312-win_amd64.whl", hash = "sha256:6088aa475f0785e01bcf8529f55280a3d7d298679f50c0bb70a2364a82d0b290", size = 8731256 },
{ url = "https://files.pythonhosted.org/packages/ae/93/a3038cb0293037fd335f77f31fe053b89c72f17b1c8908c576c29d953e84/scikit_learn-1.7.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b7dacaa05e5d76759fb071558a8b5130f4845166d88654a0f9bdf3eb57851b7", size = 9212382 },
{ url = "https://files.pythonhosted.org/packages/40/dd/9a88879b0c1104259136146e4742026b52df8540c39fec21a6383f8292c7/scikit_learn-1.7.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:abebbd61ad9e1deed54cca45caea8ad5f79e1b93173dece40bb8e0c658dbe6fe", size = 8592042 },
{ url = "https://files.pythonhosted.org/packages/46/af/c5e286471b7d10871b811b72ae794ac5fe2989c0a2df07f0ec723030f5f5/scikit_learn-1.7.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:502c18e39849c0ea1a5d681af1dbcf15f6cce601aebb657aabbfe84133c1907f", size = 9434180 },
{ url = "https://files.pythonhosted.org/packages/f1/fd/df59faa53312d585023b2da27e866524ffb8faf87a68516c23896c718320/scikit_learn-1.7.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a4c328a71785382fe3fe676a9ecf2c86189249beff90bf85e22bdb7efaf9ae0", size = 9283660 },
{ url = "https://files.pythonhosted.org/packages/a7/c7/03000262759d7b6f38c836ff9d512f438a70d8a8ddae68ee80de72dcfb63/scikit_learn-1.7.2-cp313-cp313-win_amd64.whl", hash = "sha256:63a9afd6f7b229aad94618c01c252ce9e6fa97918c5ca19c9a17a087d819440c", size = 8702057 },
{ url = "https://files.pythonhosted.org/packages/55/87/ef5eb1f267084532c8e4aef98a28b6ffe7425acbfd64b5e2f2e066bc29b3/scikit_learn-1.7.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9acb6c5e867447b4e1390930e3944a005e2cb115922e693c08a323421a6966e8", size = 9558731 },
{ url = "https://files.pythonhosted.org/packages/93/f8/6c1e3fc14b10118068d7938878a9f3f4e6d7b74a8ddb1e5bed65159ccda8/scikit_learn-1.7.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:2a41e2a0ef45063e654152ec9d8bcfc39f7afce35b08902bfe290c2498a67a6a", size = 9038852 },
{ url = "https://files.pythonhosted.org/packages/83/87/066cafc896ee540c34becf95d30375fe5cbe93c3b75a0ee9aa852cd60021/scikit_learn-1.7.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98335fb98509b73385b3ab2bd0639b1f610541d3988ee675c670371d6a87aa7c", size = 9527094 },
{ url = "https://files.pythonhosted.org/packages/9c/2b/4903e1ccafa1f6453b1ab78413938c8800633988c838aa0be386cbb33072/scikit_learn-1.7.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:191e5550980d45449126e23ed1d5e9e24b2c68329ee1f691a3987476e115e09c", size = 9367436 },
{ url = "https://files.pythonhosted.org/packages/b5/aa/8444be3cfb10451617ff9d177b3c190288f4563e6c50ff02728be67ad094/scikit_learn-1.7.2-cp313-cp313t-win_amd64.whl", hash = "sha256:57dc4deb1d3762c75d685507fbd0bc17160144b2f2ba4ccea5dc285ab0d0e973", size = 9275749 },
{ url = "https://files.pythonhosted.org/packages/d9/82/dee5acf66837852e8e68df6d8d3a6cb22d3df997b733b032f513d95205b7/scikit_learn-1.7.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:fa8f63940e29c82d1e67a45d5297bdebbcb585f5a5a50c4914cc2e852ab77f33", size = 9208906 },
{ url = "https://files.pythonhosted.org/packages/3c/30/9029e54e17b87cb7d50d51a5926429c683d5b4c1732f0507a6c3bed9bf65/scikit_learn-1.7.2-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:f95dc55b7902b91331fa4e5845dd5bde0580c9cd9612b1b2791b7e80c3d32615", size = 8627836 },
{ url = "https://files.pythonhosted.org/packages/60/18/4a52c635c71b536879f4b971c2cedf32c35ee78f48367885ed8025d1f7ee/scikit_learn-1.7.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9656e4a53e54578ad10a434dc1f993330568cfee176dff07112b8785fb413106", size = 9426236 },
{ url = "https://files.pythonhosted.org/packages/99/7e/290362f6ab582128c53445458a5befd471ed1ea37953d5bcf80604619250/scikit_learn-1.7.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96dc05a854add0e50d3f47a1ef21a10a595016da5b007c7d9cd9d0bffd1fcc61", size = 9312593 },
{ url = "https://files.pythonhosted.org/packages/8e/87/24f541b6d62b1794939ae6422f8023703bbf6900378b2b34e0b4384dfefd/scikit_learn-1.7.2-cp314-cp314-win_amd64.whl", hash = "sha256:bb24510ed3f9f61476181e4db51ce801e2ba37541def12dc9333b946fc7a9cf8", size = 8820007 },
]
[[package]]
name = "scipy"
version = "1.15.3"
@ -1344,6 +1394,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 },
]
[[package]]
name = "threadpoolctl"
version = "3.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638 },
]
[[package]]
name = "tifffile"
version = "2025.3.30"
