From 4bdd1ee872fa3f79703f325ef6f9ccc19579d12b Mon Sep 17 00:00:00 2001
From: laynholt
Date: Tue, 25 Nov 2025 13:12:50 +0000
Subject: [PATCH] model outputs converted to dict

---
 core/losses/mse_with_bce.py |  34 ++++-----
 core/models/model_v.py      | 142 ++++++++++++++++++++++++++++++------
 core/segmentator.py         |  89 +++++++++++-----------
 main.py                     | 135 ++++++++++++++++++++++++----------
 pyproject.toml              |   1 +
 uv.lock                     |  59 +++++++++++++++
 6 files changed, 333 insertions(+), 127 deletions(-)

diff --git a/core/losses/mse_with_bce.py b/core/losses/mse_with_bce.py
index f5f89ba..542cfbc 100644
--- a/core/losses/mse_with_bce.py
+++ b/core/losses/mse_with_bce.py
@@ -68,34 +68,34 @@ class BCE_MSE_Loss(BaseLoss):
         self.loss_bce_metric = CumulativeAverage()
         self.loss_mse_metric = CumulativeAverage()
 
-    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    def forward(self, outputs: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor:  # type: ignore
         """
         Computes the loss between true labels and prediction outputs.
 
         Args:
-            outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
-            target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W).
+            outputs (dict[str, torch.Tensor]): Model predictions:
+                - "flow": (B, 2*C, H, W) per-class flow (flow_vectors[0] is Y flow, flow_vectors[1] is X flow)
+                - "logits": (B, C, H, W) per-class logits.
+            target (dict[str, torch.Tensor]): Ground truth labels:
+                - "flow": (B, 2*C, H, W) per-class flow (flow_vectors[0] is Y flow, flow_vectors[1] is X flow)
+                - "labels": (B, C, H, W) per-class labels.
 
         Returns:
             torch.Tensor: The total loss value.
         """
-        # Ensure target is on the same device as outputs
-        assert (
-            target.device == outputs.device
-        ), (
-            "Target tensor must be moved to the same device as outputs "
-            "before calling forward()."
-        )
-
+        logits: torch.Tensor = outputs["logits"]    # (B,C,H,W)
+        flow_pred: torch.Tensor = outputs["flow"]   # (B,2*C,H,W)
+        labels: torch.Tensor = target["labels"]     # (B,C,H,W)
+        flow_tgt: torch.Tensor = target["flow"]     # (B,2*C,H,W)
+
+        labels = labels.to(device=logits.device)
+        flow_tgt = flow_tgt.to(device=flow_pred.device)
+
         # Cell Recognition Loss
-        cellprob_loss = self.bce_loss(
-            outputs[:, -self.num_classes:], (target[:, -self.num_classes:] > 0).float()
-        )
+        cellprob_loss = self.bce_loss(logits, (labels > 0).float())
 
         # Cell Distinction Loss
-        gradflow_loss = 0.5 * self.mse_loss(
-            outputs[:, :2 * self.num_classes], 5.0 * target[:, :2 * self.num_classes]
-        )
+        gradflow_loss = 0.5 * self.mse_loss(flow_pred, 5.0 * flow_tgt)
 
         # Total loss
         total_loss = cellprob_loss + gradflow_loss
diff --git a/core/models/model_v.py b/core/models/model_v.py
index 35418ea..8bd109d 100644
--- a/core/models/model_v.py
+++ b/core/models/model_v.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
 
-from typing import Any
+from typing import Any, Callable
 
 from segmentation_models_pytorch import MAnet
 from segmentation_models_pytorch.base.modules import Activation
@@ -20,6 +20,10 @@ class ModelVParams(BaseModel):
     in_channels: int = 3   # Number of input channels
     out_classes: int = 1   # Number of output classes
 
+    prefer_gn: bool = True         # Prefer GroupNorm for small batches; falls back to BatchNorm if prefer_gn=False
+    upsample: int = 1              # Upsampling factor for heads (keep 1 if decoder outputs final spatial size)
+    zero_init_heads: bool = False  # Apply zero init to final conv in heads (stabilizes early training)
+
     def asdict(self) -> dict[str, Any]:
         """
         Returns a dictionary of valid parameters for `nn.ModelV`.
@@ -39,25 +43,58 @@ class ModelV(MAnet):
         super().__init__(**params.asdict())
 
         self.num_classes = params.out_classes
+        in_ch: int = params.decoder_channels[-1]
 
         # Remove the default segmentation head as it's not used in this architecture
         self.segmentation_head = None
 
         # Modify all activation functions in the encoder and decoder from ReLU to Mish
-        _convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True))
-        _convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True))
+        _convert_activations(self.encoder, nn.ReLU, lambda: nn.Mish(inplace=True))
+        _convert_activations(self.decoder, nn.ReLU, lambda: nn.Mish(inplace=True))
+
+        self.prefer_gn: bool = params.prefer_gn
+        self.upsample: int = params.upsample
+        self.zero_init_heads: bool = params.zero_init_heads
 
         # Add custom segmentation heads for different segmentation tasks
-        self.cellprob_head = DeepSegmentationHead(
-            in_channels=params.decoder_channels[-1], out_channels=params.out_classes
+        # Per-class logits head (C channels)
+        self.logits_head = DeepSegmentationHead(
+            in_channels=in_ch,
+            out_channels=self.num_classes,
+            kernel_size=3,
+            activation=None,  # keep logits raw; the loss applies BCE on logits
+            upsampling=self.upsample,
+            prefer_gn=self.prefer_gn,
         )
+
+        # Flow head per-class (2*C)
+        flow_out = 2 * self.num_classes
         self.gradflow_head = DeepSegmentationHead(
-            in_channels=params.decoder_channels[-1], out_channels=2 * params.out_classes
+            in_channels=in_ch,
+            out_channels=flow_out,
+            kernel_size=3,
+            activation=None,  # we'll apply tanh explicitly (if enabled)
+            upsampling=self.upsample,
+            prefer_gn=self.prefer_gn,
         )
+
+        if self.zero_init_heads:
+            zero_init_last_conv(self.logits_head)
+            zero_init_last_conv(self.gradflow_head)
+
+
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        """
+        Forward pass through the network.
+
+        Args:
+            x: Input tensor of shape (B, in_channels, H, W).
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass through the network"""
+        Returns:
+            Dict with:
+                - "flow": (B, 2*C, H, W) per-class flow;
+                - "logits": (B, C, H, W) per-class logits.
+        """
         # Ensure the input shape is correct
         self.check_input_shape(x)
 
@@ -66,17 +103,21 @@ class ModelV(MAnet):
         decoder_output = self.decoder(features)
 
         # Generate masks for cell probability and gradient flows
-        cellprob_mask = self.cellprob_head(decoder_output)
-        gradflow_mask = self.gradflow_head(decoder_output)
+        logits = self.logits_head(decoder_output)  # (B, C, H, W)
+        flow = self.gradflow_head(decoder_output)  # (B, 2*C, H, W)
 
-        # Concatenate the masks for output
-        masks = torch.cat((gradflow_mask, cellprob_mask), dim=1)
-
-        return masks
+        return {"flow": flow, "logits": logits}
 
 
 class DeepSegmentationHead(nn.Sequential):
-    """Custom segmentation head for generating specific masks"""
+    """
+    A robust segmentation head block:
+        Conv(bias=False) -> Norm -> Mish -> Conv -> (Upsample) -> (Activation?)
+    Notes:
+        * The first conv uses bias=False because a normalization layer follows it.
+        * GroupNorm is preferred for small batch sizes; BatchNorm2d is used when prefer_gn=False.
+        * The 'mid' width is clamped to a minimum value to avoid overly narrow bottlenecks.
+    """
 
     def __init__(
         self,
@@ -85,35 +126,88 @@ class DeepSegmentationHead(nn.Sequential):
         kernel_size: int = 3,
         activation: str | None = None,
         upsampling: int = 1,
+        prefer_gn: bool = True,
+        min_mid: int = 8,
+        reduce_ratio: float = 0.5,
     ) -> None:
-        # Define a sequence of layers for the segmentation head
+        mid = compute_mid(in_channels, r=reduce_ratio, min_mid=min_mid)
+        norm_layer = make_norm(mid, prefer_gn=prefer_gn)
+
         layers: list[nn.Module] = [
             nn.Conv2d(
                 in_channels,
-                in_channels // 2,
+                mid,
                 kernel_size=kernel_size,
                 padding=kernel_size // 2,
+                bias=False,
             ),
+            norm_layer,
             nn.Mish(inplace=True),
-            nn.BatchNorm2d(in_channels // 2),
             nn.Conv2d(
-                in_channels // 2,
+                mid,
                 out_channels,
                 kernel_size=kernel_size,
                 padding=kernel_size // 2,
+                bias=True,  # final conv may keep bias; can be zero-initialized
             ),
-            nn.UpsamplingBilinear2d(scale_factor=upsampling)
-            if upsampling > 1
-            else nn.Identity(),
+            nn.Upsample(scale_factor=upsampling, mode="bilinear", align_corners=False)
+            if upsampling > 1 else nn.Identity(),
             Activation(activation) if activation else nn.Identity(),
         ]
         super().__init__(*layers)
 
 
-def _convert_activations(module: nn.Module, from_activation: type, to_activation: nn.Module) -> None:
+def _convert_activations(module: nn.Module, from_activation: type, to_activation: Callable[[], nn.Module]) -> None:
     """Recursively convert activation functions in a module"""
     for name, child in module.named_children():
         if isinstance(child, from_activation):
-            setattr(module, name, to_activation)
+            setattr(module, name, to_activation())
         else:
             _convert_activations(child, from_activation, to_activation)
+
+
+def make_norm(num_channels: int, prefer_gn: bool = True) -> nn.Module:
+    """
+    Return a normalization layer resilient to small batch sizes.
+    GroupNorm is independent of the batch dimension and thus stable when B is small.
+    """
+    if prefer_gn:
+        # The search ends at g=1, so GroupNorm(1, C) (~ LayerNorm across channels,
+        # per spatial location) is the guaranteed fallback.
+        for g in (32, 16, 8, 4, 2, 1):
+            if num_channels % g == 0:
+                return nn.GroupNorm(g, num_channels)
+    return nn.BatchNorm2d(num_channels)
+
+
+def compute_mid(
+    in_ch: int,
+    r: float = 0.5,
+    min_mid: int = 8,
+    groups_hint: tuple[int, ...] = (32, 16, 8, 4, 2)
+) -> int:
+    """
+    Compute the intermediate channel width for the head.
+    Ensures a minimum width; the width is returned as computed regardless of
+    groups_hint (make_norm() later selects a group count that divides it).
+    """
+    raw = max(min_mid, int(round(in_ch * r)))
+    return raw
+
+
+def zero_init_last_conv(module: nn.Module) -> None:
+    """
+    Zero-initialize the last Conv2d in a head to make its initial contribution neutral.
+    This often stabilizes early training for multi-task heads.
+    """
+    last_conv: nn.Conv2d | None = None
+    for m in module.modules():
+        if isinstance(m, nn.Conv2d):
+            last_conv = m
+    if last_conv is not None:
+        nn.init.zeros_(last_conv.weight)
+        if last_conv.bias is not None:
+            nn.init.zeros_(last_conv.bias)
\ No newline at end of file
diff --git a/core/segmentator.py b/core/segmentator.py
index 747646d..9ce7aca 100644
--- a/core/segmentator.py
+++ b/core/segmentator.py
@@ -973,7 +973,7 @@ class CellSegmentator:
                 # Compute loss for this batch
                 batch_loss = self._criterion(
                     raw_output,
-                    torch.from_numpy(flow_targets).to(device=raw_output.device)
+                    {k: torch.from_numpy(v) for k, v in flow_targets.items()}
                 )
 
                 # Post-process and compute F1 during validation and testing
@@ -1027,7 +1027,7 @@ class CellSegmentator:
         else:
             epoch_metrics = {}
 
-        # Include F1 and mAP for validation and testing
+        # Include F1 and AP for validation and testing
         if mode in ("valid", "test"):
             # Concatenating by batch: shape (num_batches*B, C)
             tp_array = np.vstack(all_tp)
@@ -1046,13 +1046,13 @@ class CellSegmentator:
             epoch_metrics[f"{mode}_f1_score_macro"] = self.__compute_f1_metric(
                 tp_array, fp_array, fn_array, reduction="macro"
             )
-            epoch_metrics[f"{mode}_mAP_micro"] = self.__compute_average_precision_metric(
+            epoch_metrics[f"{mode}_AP_micro"] = self.__compute_average_precision_metric(
                 tp_array, fp_array, fn_array, reduction="micro"
             )
-            epoch_metrics[f"{mode}_mAP_macro"] = self.__compute_average_precision_metric(
+            epoch_metrics[f"{mode}_AP_macro"] = self.__compute_average_precision_metric(
                 tp_array, fp_array, fn_array, reduction="macro"
            )
-            epoch_metrics[f"{mode}_mAP_pc"] = self.__compute_average_precision_metric(
+            epoch_metrics[f"{mode}_AP_pc"] = self.__compute_average_precision_metric(
                 tp_array, fp_array, fn_array, reduction="per_class"
             )
 
@@ -1063,7 +1063,7 @@ class CellSegmentator:
         self,
         inputs: torch.Tensor,
         mode: Literal["train", "valid", "test", "predict"] = "train"
-    ) -> torch.Tensor:
+    ) -> dict[str, torch.Tensor]:
         """
         Perform model inference for different stages.
 
@@ -1072,7 +1072,7 @@ class CellSegmentator:
         Args:
             inputs (torch.Tensor): Input tensor.
             stage (Literal[...]): One of "train", "valid", "test", "predict".
 
         Returns:
-            torch.Tensor: Model outputs tensor.
+            dict[str, torch.Tensor]: Model outputs with "flow" and "logits" tensors.
         """
         if mode != "train":
             # Use sliding window inference for non-training phases
@@ -1093,14 +1093,16 @@ class CellSegmentator:
 
     def __post_process_predictions(
         self,
-        raw_outputs: torch.Tensor,
+        raw_outputs: dict[str, torch.Tensor],
         ground_truth: torch.Tensor | None = None
     ) -> tuple[np.ndarray, np.ndarray | None]:
         """
         Post-process raw network outputs to extract instance segmentation masks.
 
         Args:
-            raw_outputs (torch.Tensor): Raw model outputs of shape (B, С, H, W).
+            raw_outputs (dict[str, torch.Tensor]): Raw model outputs:
+                - "flow": (B, 2*C, H, W) per-class flow;
+                - "logits": (B, C, H, W) per-class logits.
             ground_truth (torch.Tensor | None): Ground truth masks of shape (B, С, H, W).
 
         Returns:
@@ -1109,11 +1111,9 @@ class CellSegmentator:
             - instance_masks: Instance-wise masks array of shape (B, С, H, W).
             - labels_np: Converted ground truth of shape (B, С, H, W)
               or None if ground_truth was not provided.
         """
-        # Move outputs to CPU and convert to numpy
-        outputs_np = raw_outputs.cpu().numpy()
-        # Split channels: gradient flows then class logits
-        gradflow = outputs_np[:, :2 * self._model.num_classes]
-        logits = outputs_np[:, -self._model.num_classes :]
+        # Pull per-class gradient flows and logits from the output dict
+        gradflow = raw_outputs["flow"].cpu().numpy()
+        logits = raw_outputs["logits"].cpu().numpy()
 
         # Apply sigmoid to logits to get probabilities
         probabilities = self.__sigmoid(logits)
@@ -1271,12 +1271,11 @@ class CellSegmentator:
         if reduction == "macro":
             f1_per_class = np.zeros(num_classes, dtype=float)
             for c in range(num_classes):
-                _, _, f1_c = compute_f1_score(
+                _, _, f1_per_class[c] = compute_f1_score(
                     tp_per_class[c],
                     fp_per_class[c],
                     fn_per_class[c]
                 )
-                f1_per_class[c] = f1_c
             return float(f1_per_class.mean())
 
         # 6) Weighted: class-wise F1 weighted by support = TP + FN
@@ -1381,7 +1380,7 @@ class CellSegmentator:
         fp_per_class = false_positives.sum(axis=0).astype(int)
         fn_per_class = false_negatives.sum(axis=0).astype(int)
 
-        # 4) Per-class: compute F1 for each class and return vector
+        # 4) Per-class: compute AP for each class and return vector
         if reduction == "per_class":
             ap_per_class = np.zeros(num_classes, dtype=float)
             for c in range(num_classes):
@@ -1752,7 +1751,7 @@ class CellSegmentator:
     def __compute_flows_from_masks(
         self,
         true_masks: Tensor
-    ) -> np.ndarray:
+    ) -> dict[str, np.ndarray]:
         """
         Convert segmentation masks to flow fields for training.
 
@@ -1760,8 +1759,9 @@ class CellSegmentator:
             true_masks: Torch tensor of shape (batch, C, H, W) containing integer masks.
 
         Returns:
-            numpy array of concatenated [flow_vectors, renumbered_true_masks] per image.
-            renumbered_true_masks is labels, flow_vectors[0] is Y flow, flow_vectors[1] is X flow.
+            Dict with:
+                - "flow": (B, 2*C, H, W) per-class flow (flow_vectors[0] is Y flow, flow_vectors[1] is X flow)
+                - "labels": (B, C, H, W) per-class labels.
         """
         # Move to CPU numpy
         _true_masks: np.ndarray = true_masks.cpu().numpy().astype(np.int16)
@@ -1783,7 +1783,7 @@ class CellSegmentator:
         flow_vectors = np.stack([self.__compute_flow_from_mask(_true_masks[i])
                                  for i in range(batch_size)])
 
-        return np.concatenate((flow_vectors, _true_masks), axis=1).astype(np.float32)
+        return {"flow": flow_vectors.astype(np.float32), "labels": _true_masks.astype(np.float32)}
 
 
     def __compute_flow_from_mask(
@@ -2245,9 +2245,10 @@ class CellSegmentator:
 
             # Follow the flow vectors to generate coordinate mappings
             flow_coordinates = self.__follow_flows(
-                flow_field=channel_flow_vectors * channel_mask / 5.0,
+                flow_field=channel_flow_vectors * channel_mask,
                 initial_coords=nonzero_coords,
-                num_iters=num_iters
+                num_iters=num_iters,
+                step_scale=0.2
             )
 
             if not torch.is_tensor(flow_coordinates):
@@ -2292,7 +2293,8 @@ class CellSegmentator:
         self,
         flow_field: np.ndarray,
         initial_coords: np.ndarray,
-        num_iters: int = 200
+        num_iters: int = 200,
+        step_scale: float = 1.0
     ) -> np.ndarray | torch.Tensor:
         """
         Trace pixel positions through a flow field via iterative interpolation.
 
@@ -2301,6 +2303,7 @@ class CellSegmentator:
         Args:
             flow_field (np.ndarray): Array of shape (2, H, W) containing flow vectors.
             initial_coords (np.ndarray): Array of shape (2, num_points) with starting (y, x) positions.
             num_iters (int): Number of integration steps.
+            step_scale (float): Step length in "pixels per iteration".
 
         Returns:
             (np.ndarray | torch.Tensor): Final (y, x) positions of each point.
@@ -2311,50 +2314,40 @@ class CellSegmentator:
 
         # Choose GPU/MPS path if available
         if self._device.type in ("cuda", "mps"):
+            # Load initial positions and flow into tensors (flip order for grid_sample)
             # Prepare point tensor: shape [1, 1, num_points, 2]
-            pts = torch.zeros((1, 1, initial_coords.shape[1], dims),
-                              dtype=torch.float32, device=self._device)
+            pts = torch.empty((1, 1, initial_coords.shape[1], dims), dtype=torch.float32, device=self._device)
+            pts[..., 0] = torch.from_numpy(initial_coords[1]).to(self._device, torch.float32)  # x
+            pts[..., 1] = torch.from_numpy(initial_coords[0]).to(self._device, torch.float32)  # y
+
             # Prepare flow volume: shape [1, 2, height, width]
-            flow_vol = torch.zeros((1, dims, height, width),
-                                   dtype=torch.float32, device=self._device)
-
-            # Load initial positions and flow into tensors (flip order for grid_sample)
-            # dim 0 = x
-            # dim 1 = y
-            for i in range(dims):
-                pts[0, 0, :, dims - i - 1] = (
-                    torch.from_numpy(initial_coords[i])
-                    .to(self._device, torch.float32)
-                )
-                flow_vol[0, dims - i - 1] = (
-                    torch.from_numpy(flow_field[i])
-                    .to(self._device, torch.float32)
-                )
+            flow_vol = torch.empty((1, dims, height, width), dtype=torch.float32, device=self._device)
+            flow_vol[:, 0] = torch.from_numpy(flow_field[1]).to(self._device, torch.float32)  # x-component
+            flow_vol[:, 1] = torch.from_numpy(flow_field[0]).to(self._device, torch.float32)  # y-component
 
             # Prepare normalization factors for x and y (max index)
-            max_indices = torch.tensor([width - 1, height - 1],
-                                       dtype=torch.float32, device=self._device)
+            max_indices = torch.tensor([width - 1, height - 1], dtype=torch.float32, device=self._device)  # (2,)
             # Reshape for broadcasting to point tensor dims
             max_idx_pt = max_indices.view(1, 1, 1, dims)
             # Reshape for broadcasting to flow volume dims
             max_idx_flow = max_indices.view(1, dims, 1, 1)
 
             # Normalize flow values to [-1, 1] range
-            flow_vol = (flow_vol * 2) / max_idx_flow
+            flow_vol = (flow_vol * step_scale * 2.0) / max_idx_flow
             # Normalize points to [-1, 1]
-            pts = (pts / max_idx_pt) * 2 - 1
+            pts = (pts / max_idx_pt) * 2.0 - 1.0
 
             # Iterate: sample flow and update points
             for _ in range(num_iters):
                 sampled = torch.nn.functional.grid_sample(
-                    flow_vol, pts, align_corners=False
+                    flow_vol, pts, align_corners=True
                 )
                 # Update each coordinate and clamp to valid range
                 for i in range(dims):
                     pts[..., i] = torch.clamp(pts[..., i] + sampled[:, i], -1.0, 1.0)
 
             # Denormalize back to original pixel coordinates
-            pts = (pts + 1) * 0.5 * max_idx_pt
+            pts = (pts + 1.0) * 0.5 * max_idx_pt
             # Swap channels back to (y, x) and flatten
             final_pts = pts[..., [1, 0]].squeeze()
             # Convert from (num_points, 2) to (2, num_points)
@@ -2368,8 +2361,8 @@ class CellSegmentator:
             # Interpolate flow at current positions
             self.__map_coordinates(flow_field, current_pos[0], current_pos[1], temp_delta)
             # Update positions and clamp to image bounds
-            current_pos[0] = np.clip(current_pos[0] + temp_delta[0], 0, height - 1)
-            current_pos[1] = np.clip(current_pos[1] + temp_delta[1], 0, width - 1)
+            current_pos[0] = np.clip(current_pos[0] + step_scale * temp_delta[0], 0, height - 1)
+            current_pos[1] = np.clip(current_pos[1] + step_scale * temp_delta[1], 0, width - 1)
 
         return current_pos
 
diff --git a/main.py b/main.py
index 4ddb2dd..e72da38 100644
--- a/main.py
+++ b/main.py
@@ -13,42 +13,57 @@ from core.data import (
 from core.segmentator import CellSegmentator
 
 
-def main() -> None:
-    parser = argparse.ArgumentParser(
-        description="Train or predict cell segmentator with specified config file."
-    )
-    parser.add_argument(
-        '-c', '--config',
-        type=str,
-        help='Path to the JSON config file'
-    )
-    parser.add_argument(
-        '-m', '--mode',
-        choices=['train', 'test', 'predict'],
-        default='train',
-        help='Run mode: train, test or predict'
-    )
-    parser.add_argument(
-        '--no-save-masks',
-        action='store_false',
-        dest='save_masks',
-        help='If set, do NOT save predicted masks (saving is enabled by default)'
-    )
-    parser.add_argument(
-        '--only-masks',
-        action='store_true',
-        help=('If set and save-masks set, save only the raw predicted'
-              ' masks without additional visualizations')
-    )
+def main(
+    manual: bool = False,
+    config_path: str | None = None,
+    mode: str | None = None,
+    save_masks: bool = True,
+    only_masks: bool = False
+) -> None:
+
+    if not manual:
+        parser = argparse.ArgumentParser(
+            description="Train or predict cell segmentator with specified config file."
+        )
+        parser.add_argument(
+            '-c', '--config',
+            type=str,
+            help='Path to the JSON config file'
+        )
+        parser.add_argument(
+            '-m', '--mode',
+            choices=['train', 'test', 'predict'],
+            default='train',
+            help='Run mode: train, test or predict'
+        )
+        parser.add_argument(
+            '--no-save-masks',
+            action='store_false',
+            dest='save_masks',
+            help='If set, do NOT save predicted masks (saving is enabled by default)'
+        )
+        parser.add_argument(
+            '--only-masks',
+            action='store_true',
+            help=('If set and mask saving is enabled, save only the raw predicted'
+                  ' masks without additional visualizations')
+        )
 
-    if len(sys.argv) == 1:
-        parser.print_help()
-        sys.exit(0)
+        if len(sys.argv) == 1:
+            parser.print_help()
+            sys.exit(0)
 
-    args = parser.parse_args()
+        args = parser.parse_args()
 
-    mode = args.mode
-    config_path = args.config
+        mode = args.mode
+        config_path = args.config
+        save_masks = args.save_masks
+        only_masks = args.only_masks
+
+    else:
+        if mode is None or config_path is None:
+            raise ValueError("In manual mode, you must specify the path to the config and mode!")
 
     config = Config.load_json(config_path)
     if mode == 'train' and not config.dataset_config.is_training:
@@ -62,7 +77,7 @@
 
     if config.wandb_config.use_wandb:
         # Initialize W&B
-        wandb.init(config=config.asdict(), **config.wandb_config.asdict())
+        wandb.init(config=config.asdict(), reinit="finish_previous", **config.wandb_config.asdict())
 
         # How many batches to wait before logging training status
         wandb.config.log_interval = 10
@@ -84,9 +99,9 @@
         wandb.watch(segmentator._model, log="all", log_graph=True)
 
     try:
-        segmentator.run(save_results=args.save_masks, only_masks=args.only_masks)
-    except Exception:
-        raise
+        segmentator.run(save_results=save_masks, only_masks=only_masks)
+    except Exception as e:
+        raise e
     finally:
         if config.dataset_config.is_training:
             # Prepare saving path
@@ -101,8 +116,52 @@
 
             if config.wandb_config.use_wandb:
                 wandb.save(saving_path)
+                wandb.finish()
 
 
 if __name__ == "__main__":
-    main()
+    train_configs = [
+        "/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cA.json",
+        "/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cB.json",
+        "/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cSoma.json",
+        "/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cAB.json",
+        "/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cABSoma.json",
+
+        "/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cytoCell.json",
"/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cytoNuc.json", + "/workspace/model-v/config/templates/train/ModelV_BCE_MSE_Loss_AdamW_CosineAnnealing_cytoCellNuc.json", + ] + + predict_configs = [ + "/workspace/model-v/config/templates/predict/ModelV_cA.json", + "/workspace/model-v/config/templates/predict/ModelV_cB.json", + "/workspace/model-v/config/templates/predict/ModelV_cSoma.json", + "/workspace/model-v/config/templates/predict/ModelV_cAB.json", + "/workspace/model-v/config/templates/predict/ModelV_cABSoma.json", + + "/workspace/model-v/config/templates/predict/ModelV_cytoCell.json", + "/workspace/model-v/config/templates/predict/ModelV_cytoNuc.json", + "/workspace/model-v/config/templates/predict/ModelV_cytoCellNuc.json", + ] + + for config in train_configs: + print(f"Hande config {config}") + main( + manual=True, + config_path=config, + mode="train", + save_masks=True, + only_masks=False + ) + + for config in predict_configs: + print(f"Hande config {config}") + main( + manual=True, + config_path=config, + mode="predict", + save_masks=True, + only_masks=False + ) + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index c611c56..7619d45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "pillow>=11.2.1", "pydantic>=2.11.4", "scikit-image>=0.25.2", + "scikit-learn>=1.7.2", "scipy>=1.15.3", "segmentation-models-pytorch>=0.5.0", "tabulate>=0.9.0", diff --git a/uv.lock b/uv.lock index 9f849e7..fdb9f13 100644 --- a/uv.lock +++ b/uv.lock @@ -411,6 +411,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 }, ] +[[package]] +name = "joblib" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396 }, +] + [[package]] name = "kiwisolver" version = "1.4.8" @@ -620,6 +629,7 @@ dependencies = [ { name = "pillow" }, { name = "pydantic" }, { name = "scikit-image" }, + { name = "scikit-learn" }, { name = "scipy" }, { name = "segmentation-models-pytorch" }, { name = "tabulate" }, @@ -645,6 +655,7 @@ requires-dist = [ { name = "pillow", specifier = ">=11.2.1" }, { name = "pydantic", specifier = ">=2.11.4" }, { name = "scikit-image", specifier = ">=0.25.2" }, + { name = "scikit-learn", specifier = ">=1.7.2" }, { name = "scipy", specifier = ">=1.15.3" }, { name = "segmentation-models-pytorch", specifier = ">=0.5.0" }, { name = "tabulate", specifier = ">=0.9.0" }, @@ -1161,6 +1172,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/cc/75e9f17e3670b5ed93c32456fda823333c6279b144cd93e2c03aa06aa472/scikit_image-0.25.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:330d061bd107d12f8d68f1d611ae27b3b813b8cdb0300a71d07b1379178dd4cd", size = 13862801 }, ] +[[package]] +name = "scikit-learn" +version = "1.7.2" +source = { registry = "https://pypi.org/simple" } 
+dependencies = [
+    { name = "joblib" },
+    { name = "numpy" },
+    { name = "scipy" },
+    { name = "threadpoolctl" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/98/c2/a7855e41c9d285dfe86dc50b250978105dce513d6e459ea66a6aeb0e1e0c/scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda", size = 7193136 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/43/83/564e141eef908a5863a54da8ca342a137f45a0bfb71d1d79704c9894c9d1/scikit_learn-1.7.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7509693451651cd7361d30ce4e86a1347493554f172b1c72a39300fa2aea79e", size = 9331967 },
+    { url = "https://files.pythonhosted.org/packages/18/d6/ba863a4171ac9d7314c4d3fc251f015704a2caeee41ced89f321c049ed83/scikit_learn-1.7.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:0486c8f827c2e7b64837c731c8feff72c0bd2b998067a8a9cbc10643c31f0fe1", size = 8648645 },
+    { url = "https://files.pythonhosted.org/packages/ef/0e/97dbca66347b8cf0ea8b529e6bb9367e337ba2e8be0ef5c1a545232abfde/scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89877e19a80c7b11a2891a27c21c4894fb18e2c2e077815bcade10d34287b20d", size = 9715424 },
+    { url = "https://files.pythonhosted.org/packages/f7/32/1f3b22e3207e1d2c883a7e09abb956362e7d1bd2f14458c7de258a26ac15/scikit_learn-1.7.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8da8bf89d4d79aaec192d2bda62f9b56ae4e5b4ef93b6a56b5de4977e375c1f1", size = 9509234 },
+    { url = "https://files.pythonhosted.org/packages/9f/71/34ddbd21f1da67c7a768146968b4d0220ee6831e4bcbad3e03dd3eae88b6/scikit_learn-1.7.2-cp311-cp311-win_amd64.whl", hash = "sha256:9b7ed8d58725030568523e937c43e56bc01cadb478fc43c042a9aca1dacb3ba1", size = 8894244 },
+    { url = "https://files.pythonhosted.org/packages/a7/aa/3996e2196075689afb9fce0410ebdb4a09099d7964d061d7213700204409/scikit_learn-1.7.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8d91a97fa2b706943822398ab943cde71858a50245e31bc71dba62aab1d60a96", size = 9259818 },
+    { url = "https://files.pythonhosted.org/packages/43/5d/779320063e88af9c4a7c2cf463ff11c21ac9c8bd730c4a294b0000b666c9/scikit_learn-1.7.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:acbc0f5fd2edd3432a22c69bed78e837c70cf896cd7993d71d51ba6708507476", size = 8636997 },
+    { url = "https://files.pythonhosted.org/packages/5c/d0/0c577d9325b05594fdd33aa970bf53fb673f051a45496842caee13cfd7fe/scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e5bf3d930aee75a65478df91ac1225ff89cd28e9ac7bd1196853a9229b6adb0b", size = 9478381 },
+    { url = "https://files.pythonhosted.org/packages/82/70/8bf44b933837ba8494ca0fc9a9ab60f1c13b062ad0197f60a56e2fc4c43e/scikit_learn-1.7.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4d6e9deed1a47aca9fe2f267ab8e8fe82ee20b4526b2c0cd9e135cea10feb44", size = 9300296 },
+    { url = "https://files.pythonhosted.org/packages/c6/99/ed35197a158f1fdc2fe7c3680e9c70d0128f662e1fee4ed495f4b5e13db0/scikit_learn-1.7.2-cp312-cp312-win_amd64.whl", hash = "sha256:6088aa475f0785e01bcf8529f55280a3d7d298679f50c0bb70a2364a82d0b290", size = 8731256 },
+    { url = "https://files.pythonhosted.org/packages/ae/93/a3038cb0293037fd335f77f31fe053b89c72f17b1c8908c576c29d953e84/scikit_learn-1.7.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b7dacaa05e5d76759fb071558a8b5130f4845166d88654a0f9bdf3eb57851b7", size = 9212382 },
+    { url = "https://files.pythonhosted.org/packages/40/dd/9a88879b0c1104259136146e4742026b52df8540c39fec21a6383f8292c7/scikit_learn-1.7.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:abebbd61ad9e1deed54cca45caea8ad5f79e1b93173dece40bb8e0c658dbe6fe", size = 8592042 },
"https://files.pythonhosted.org/packages/40/dd/9a88879b0c1104259136146e4742026b52df8540c39fec21a6383f8292c7/scikit_learn-1.7.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:abebbd61ad9e1deed54cca45caea8ad5f79e1b93173dece40bb8e0c658dbe6fe", size = 8592042 }, + { url = "https://files.pythonhosted.org/packages/46/af/c5e286471b7d10871b811b72ae794ac5fe2989c0a2df07f0ec723030f5f5/scikit_learn-1.7.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:502c18e39849c0ea1a5d681af1dbcf15f6cce601aebb657aabbfe84133c1907f", size = 9434180 }, + { url = "https://files.pythonhosted.org/packages/f1/fd/df59faa53312d585023b2da27e866524ffb8faf87a68516c23896c718320/scikit_learn-1.7.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a4c328a71785382fe3fe676a9ecf2c86189249beff90bf85e22bdb7efaf9ae0", size = 9283660 }, + { url = "https://files.pythonhosted.org/packages/a7/c7/03000262759d7b6f38c836ff9d512f438a70d8a8ddae68ee80de72dcfb63/scikit_learn-1.7.2-cp313-cp313-win_amd64.whl", hash = "sha256:63a9afd6f7b229aad94618c01c252ce9e6fa97918c5ca19c9a17a087d819440c", size = 8702057 }, + { url = "https://files.pythonhosted.org/packages/55/87/ef5eb1f267084532c8e4aef98a28b6ffe7425acbfd64b5e2f2e066bc29b3/scikit_learn-1.7.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9acb6c5e867447b4e1390930e3944a005e2cb115922e693c08a323421a6966e8", size = 9558731 }, + { url = "https://files.pythonhosted.org/packages/93/f8/6c1e3fc14b10118068d7938878a9f3f4e6d7b74a8ddb1e5bed65159ccda8/scikit_learn-1.7.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:2a41e2a0ef45063e654152ec9d8bcfc39f7afce35b08902bfe290c2498a67a6a", size = 9038852 }, + { url = "https://files.pythonhosted.org/packages/83/87/066cafc896ee540c34becf95d30375fe5cbe93c3b75a0ee9aa852cd60021/scikit_learn-1.7.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98335fb98509b73385b3ab2bd0639b1f610541d3988ee675c670371d6a87aa7c", size = 9527094 }, + { url = "https://files.pythonhosted.org/packages/9c/2b/4903e1ccafa1f6453b1ab78413938c8800633988c838aa0be386cbb33072/scikit_learn-1.7.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:191e5550980d45449126e23ed1d5e9e24b2c68329ee1f691a3987476e115e09c", size = 9367436 }, + { url = "https://files.pythonhosted.org/packages/b5/aa/8444be3cfb10451617ff9d177b3c190288f4563e6c50ff02728be67ad094/scikit_learn-1.7.2-cp313-cp313t-win_amd64.whl", hash = "sha256:57dc4deb1d3762c75d685507fbd0bc17160144b2f2ba4ccea5dc285ab0d0e973", size = 9275749 }, + { url = "https://files.pythonhosted.org/packages/d9/82/dee5acf66837852e8e68df6d8d3a6cb22d3df997b733b032f513d95205b7/scikit_learn-1.7.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:fa8f63940e29c82d1e67a45d5297bdebbcb585f5a5a50c4914cc2e852ab77f33", size = 9208906 }, + { url = "https://files.pythonhosted.org/packages/3c/30/9029e54e17b87cb7d50d51a5926429c683d5b4c1732f0507a6c3bed9bf65/scikit_learn-1.7.2-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:f95dc55b7902b91331fa4e5845dd5bde0580c9cd9612b1b2791b7e80c3d32615", size = 8627836 }, + { url = "https://files.pythonhosted.org/packages/60/18/4a52c635c71b536879f4b971c2cedf32c35ee78f48367885ed8025d1f7ee/scikit_learn-1.7.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9656e4a53e54578ad10a434dc1f993330568cfee176dff07112b8785fb413106", size = 9426236 }, + { url = 
"https://files.pythonhosted.org/packages/99/7e/290362f6ab582128c53445458a5befd471ed1ea37953d5bcf80604619250/scikit_learn-1.7.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96dc05a854add0e50d3f47a1ef21a10a595016da5b007c7d9cd9d0bffd1fcc61", size = 9312593 }, + { url = "https://files.pythonhosted.org/packages/8e/87/24f541b6d62b1794939ae6422f8023703bbf6900378b2b34e0b4384dfefd/scikit_learn-1.7.2-cp314-cp314-win_amd64.whl", hash = "sha256:bb24510ed3f9f61476181e4db51ce801e2ba37541def12dc9333b946fc7a9cf8", size = 8820007 }, +] + [[package]] name = "scipy" version = "1.15.3" @@ -1344,6 +1394,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 }, ] +[[package]] +name = "threadpoolctl" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638 }, +] + [[package]] name = "tifffile" version = "2025.3.30"