Removed tta argument from common;

Fix issue with save-masks arg in argparse;
Added metric saving to cvs file.
master
laynholt 4 months ago
parent 28f978956c
commit 0f2befc4e5

@ -9,7 +9,6 @@ class DatasetCommonConfig(BaseModel):
"""
seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_tta: bool = False # Flag to use Test-Time Augmentation (TTA)
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars'
predictions_dir: str = "." # Directory to save predictions

@ -1,4 +1,3 @@
import time
import random
import numpy as np
from numba import njit, prange
@ -26,7 +25,9 @@ import matplotlib.colors as mcolors
import os
import glob
import csv
import copy
import time
import tifffile as tiff
from pprint import pformat
@ -373,6 +374,10 @@ class CellSegmentator:
if self._test_dataloader is not None:
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
self.__print_with_logging(test_metrics, 0)
save_path = self._dataset_setup.common.predictions_dir
os.makedirs(save_path, exist_ok=True)
self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None:
@ -386,6 +391,10 @@ class CellSegmentator:
"""
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
self.__print_with_logging(test_metrics, 0)
save_path = self._dataset_setup.common.predictions_dir
os.makedirs(save_path, exist_ok=True)
self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
def predict(self, only_masks: bool = False) -> None:
@ -446,6 +455,12 @@ class CellSegmentator:
"""
start_time = time.time()
logger.info(
f"Masks saving: {'enabled' if save_results else 'disabled'}; "
f"Additional visualizations: "
f"{'enabled' if save_results and not only_masks else 'disabled'}"
)
# 1) TRAINING PATH
if self._dataset_setup.is_training:
# Launch the full training loop (with validation, scheduler steps, etc.)
@ -778,17 +793,17 @@ class CellSegmentator:
return Dataset(data, transforms)
def __print_with_logging(self, results: Dict[str, Union[float, np.ndarray]], step: int) -> None:
def __print_with_logging(self, metrics: Dict[str, Union[float, np.ndarray]], step: int) -> None:
"""
Print metrics in a tabular format and log to W&B.
Args:
results (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
to either a float or a ND numpy array.
step (int): epoch index.
"""
rows: list[tuple[str, str]] = []
for key, val in results.items():
for key, val in metrics.items():
if isinstance(val, np.ndarray):
# Convert array to string, e.g. '[0.2, 0.8, 0.5]'
val_str = np.array2string(val, separator=', ')
@ -808,7 +823,7 @@ class CellSegmentator:
if self._wandb_config.use_wandb:
# Keep only scalar values
scalar_results: dict[str, float] = {}
for key, val in results.items():
for key, val in metrics.items():
if isinstance(val, np.ndarray):
continue
# Ensure float type
@ -816,6 +831,34 @@ class CellSegmentator:
wandb.log(scalar_results, step=step)
def __save_metrics_to_csv(
self,
metrics: Dict[str, Union[float, np.ndarray]],
output_path: str
) -> None:
"""
Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'.
Args:
metrics (Dict[str, Union[float, np.ndarray]]):
Mapping from metric names to scalar values or numpy arrays.
output_path (str):
Path to the output CSV file.
"""
with open(output_path, mode='w', newline='') as csv_file:
writer = csv.writer(csv_file)
writer.writerow(['Metric', 'Value'])
for name, value in metrics.items():
# Convert numpy arrays to string representation
if isinstance(value, np.ndarray):
# Flatten and join with commas
flat = value.flatten()
val_str = ','.join([f"{v}" for v in flat])
else:
val_str = f"{value}"
writer.writerow([name, val_str])
def __run_epoch(self,
mode: Literal["train", "valid", "test"],
epoch: Optional[int] = None,
@ -900,7 +943,7 @@ class CellSegmentator:
predicted_masks=preds,
ground_truth_masks=labels_post, # type: ignore
iou_threshold=0.5,
return_error_masks=(mode == "test")
return_error_masks=(mode == "test") and save_results is True
)
all_tp.append(tp)
all_fp.append(fp)

@ -24,17 +24,16 @@ def main():
help='Run mode: train, test or predict'
)
parser.add_argument(
'-s', '--save-masks',
action='store_true',
default=True,
help='If set to False, do not save predicted masks; by default, saving is enabled'
'--no-save-masks',
action='store_false',
dest='save_masks',
help='If set, do NOT save predicted masks (saving is enabled by default)'
)
parser.add_argument(
'--only-masks',
action='store_true',
default=False,
help=('If set and save-masks set, save only the raw predicted'
' masks without additional visualizations or metrics')
' masks without additional visualizations')
)
args = parser.parse_args()

Loading…
Cancel
Save