|
|
|
@ -1,4 +1,3 @@
|
|
|
|
|
import time
|
|
|
|
|
import random
|
|
|
|
|
import numpy as np
|
|
|
|
|
from numba import njit, prange
|
|
|
|
@ -26,7 +25,9 @@ import matplotlib.colors as mcolors
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import glob
|
|
|
|
|
import csv
|
|
|
|
|
import copy
|
|
|
|
|
import time
|
|
|
|
|
import tifffile as tiff
|
|
|
|
|
|
|
|
|
|
from pprint import pformat
|
|
|
|
@ -374,6 +375,10 @@ class CellSegmentator:
|
|
|
|
|
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
|
|
|
|
|
self.__print_with_logging(test_metrics, 0)
|
|
|
|
|
|
|
|
|
|
save_path = self._dataset_setup.common.predictions_dir
|
|
|
|
|
os.makedirs(save_path, exist_ok=True)
|
|
|
|
|
self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None:
|
|
|
|
|
"""
|
|
|
|
@ -387,6 +392,10 @@ class CellSegmentator:
|
|
|
|
|
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
|
|
|
|
|
self.__print_with_logging(test_metrics, 0)
|
|
|
|
|
|
|
|
|
|
save_path = self._dataset_setup.common.predictions_dir
|
|
|
|
|
os.makedirs(save_path, exist_ok=True)
|
|
|
|
|
self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, only_masks: bool = False) -> None:
|
|
|
|
|
"""
|
|
|
|
@ -446,6 +455,12 @@ class CellSegmentator:
|
|
|
|
|
"""
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Masks saving: {'enabled' if save_results else 'disabled'}; "
|
|
|
|
|
f"Additional visualizations: "
|
|
|
|
|
f"{'enabled' if save_results and not only_masks else 'disabled'}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 1) TRAINING PATH
|
|
|
|
|
if self._dataset_setup.is_training:
|
|
|
|
|
# Launch the full training loop (with validation, scheduler steps, etc.)
|
|
|
|
@ -778,17 +793,17 @@ class CellSegmentator:
|
|
|
|
|
return Dataset(data, transforms)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __print_with_logging(self, results: Dict[str, Union[float, np.ndarray]], step: int) -> None:
|
|
|
|
|
def __print_with_logging(self, metrics: Dict[str, Union[float, np.ndarray]], step: int) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Print metrics in a tabular format and log to W&B.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
results (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
|
|
|
|
|
metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
|
|
|
|
|
to either a float or a ND numpy array.
|
|
|
|
|
step (int): epoch index.
|
|
|
|
|
"""
|
|
|
|
|
rows: list[tuple[str, str]] = []
|
|
|
|
|
for key, val in results.items():
|
|
|
|
|
for key, val in metrics.items():
|
|
|
|
|
if isinstance(val, np.ndarray):
|
|
|
|
|
# Convert array to string, e.g. '[0.2, 0.8, 0.5]'
|
|
|
|
|
val_str = np.array2string(val, separator=', ')
|
|
|
|
@ -808,7 +823,7 @@ class CellSegmentator:
|
|
|
|
|
if self._wandb_config.use_wandb:
|
|
|
|
|
# Keep only scalar values
|
|
|
|
|
scalar_results: dict[str, float] = {}
|
|
|
|
|
for key, val in results.items():
|
|
|
|
|
for key, val in metrics.items():
|
|
|
|
|
if isinstance(val, np.ndarray):
|
|
|
|
|
continue
|
|
|
|
|
# Ensure float type
|
|
|
|
@ -816,6 +831,34 @@ class CellSegmentator:
|
|
|
|
|
wandb.log(scalar_results, step=step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __save_metrics_to_csv(
|
|
|
|
|
self,
|
|
|
|
|
metrics: Dict[str, Union[float, np.ndarray]],
|
|
|
|
|
output_path: str
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
metrics (Dict[str, Union[float, np.ndarray]]):
|
|
|
|
|
Mapping from metric names to scalar values or numpy arrays.
|
|
|
|
|
output_path (str):
|
|
|
|
|
Path to the output CSV file.
|
|
|
|
|
"""
|
|
|
|
|
with open(output_path, mode='w', newline='') as csv_file:
|
|
|
|
|
writer = csv.writer(csv_file)
|
|
|
|
|
writer.writerow(['Metric', 'Value'])
|
|
|
|
|
for name, value in metrics.items():
|
|
|
|
|
# Convert numpy arrays to string representation
|
|
|
|
|
if isinstance(value, np.ndarray):
|
|
|
|
|
# Flatten and join with commas
|
|
|
|
|
flat = value.flatten()
|
|
|
|
|
val_str = ','.join([f"{v}" for v in flat])
|
|
|
|
|
else:
|
|
|
|
|
val_str = f"{value}"
|
|
|
|
|
writer.writerow([name, val_str])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __run_epoch(self,
|
|
|
|
|
mode: Literal["train", "valid", "test"],
|
|
|
|
|
epoch: Optional[int] = None,
|
|
|
|
@ -900,7 +943,7 @@ class CellSegmentator:
|
|
|
|
|
predicted_masks=preds,
|
|
|
|
|
ground_truth_masks=labels_post, # type: ignore
|
|
|
|
|
iou_threshold=0.5,
|
|
|
|
|
return_error_masks=(mode == "test")
|
|
|
|
|
return_error_masks=(mode == "test") and save_results is True
|
|
|
|
|
)
|
|
|
|
|
all_tp.append(tp)
|
|
|
|
|
all_fp.append(fp)
|
|
|
|
|