Removed tta argument from common;

Fix issue with save-masks arg in argparse;
Added metric saving to cvs file.
master
laynholt 4 months ago
parent 28f978956c
commit 0f2befc4e5

@ -9,7 +9,6 @@ class DatasetCommonConfig(BaseModel):
""" """
seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations) seed: Optional[int] = 0 # Seed for splitting if data is not pre-split (and all random operations)
device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda') device: str = "cuda:0" # Device used for training/testing (e.g., 'cpu' or 'cuda')
use_tta: bool = False # Flag to use Test-Time Augmentation (TTA)
use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP) use_amp: bool = False # Flag to use Automatic Mixed Precision (AMP)
masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars' masks_subdir: str = "" # Subdirectory where the required masks are located, e.g. 'masks/cars'
predictions_dir: str = "." # Directory to save predictions predictions_dir: str = "." # Directory to save predictions

@ -1,4 +1,3 @@
import time
import random import random
import numpy as np import numpy as np
from numba import njit, prange from numba import njit, prange
@ -26,7 +25,9 @@ import matplotlib.colors as mcolors
import os import os
import glob import glob
import csv
import copy import copy
import time
import tifffile as tiff import tifffile as tiff
from pprint import pformat from pprint import pformat
@ -373,6 +374,10 @@ class CellSegmentator:
if self._test_dataloader is not None: if self._test_dataloader is not None:
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks) test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
self.__print_with_logging(test_metrics, 0) self.__print_with_logging(test_metrics, 0)
save_path = self._dataset_setup.common.predictions_dir
os.makedirs(save_path, exist_ok=True)
self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None: def evaluate(self, save_results: bool = True, only_masks: bool = False) -> None:
@ -386,6 +391,10 @@ class CellSegmentator:
""" """
test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks) test_metrics = self.__run_epoch("test", save_results=save_results, only_masks=only_masks)
self.__print_with_logging(test_metrics, 0) self.__print_with_logging(test_metrics, 0)
save_path = self._dataset_setup.common.predictions_dir
os.makedirs(save_path, exist_ok=True)
self.__save_metrics_to_csv(test_metrics, os.path.join(save_path, 'metrics.csv'))
def predict(self, only_masks: bool = False) -> None: def predict(self, only_masks: bool = False) -> None:
@ -446,6 +455,12 @@ class CellSegmentator:
""" """
start_time = time.time() start_time = time.time()
logger.info(
f"Masks saving: {'enabled' if save_results else 'disabled'}; "
f"Additional visualizations: "
f"{'enabled' if save_results and not only_masks else 'disabled'}"
)
# 1) TRAINING PATH # 1) TRAINING PATH
if self._dataset_setup.is_training: if self._dataset_setup.is_training:
# Launch the full training loop (with validation, scheduler steps, etc.) # Launch the full training loop (with validation, scheduler steps, etc.)
@ -778,17 +793,17 @@ class CellSegmentator:
return Dataset(data, transforms) return Dataset(data, transforms)
def __print_with_logging(self, results: Dict[str, Union[float, np.ndarray]], step: int) -> None: def __print_with_logging(self, metrics: Dict[str, Union[float, np.ndarray]], step: int) -> None:
""" """
Print metrics in a tabular format and log to W&B. Print metrics in a tabular format and log to W&B.
Args: Args:
results (Dict[str, Union[float, np.ndarray]]): Mapping from metric names metrics (Dict[str, Union[float, np.ndarray]]): Mapping from metric names
to either a float or a ND numpy array. to either a float or a ND numpy array.
step (int): epoch index. step (int): epoch index.
""" """
rows: list[tuple[str, str]] = [] rows: list[tuple[str, str]] = []
for key, val in results.items(): for key, val in metrics.items():
if isinstance(val, np.ndarray): if isinstance(val, np.ndarray):
# Convert array to string, e.g. '[0.2, 0.8, 0.5]' # Convert array to string, e.g. '[0.2, 0.8, 0.5]'
val_str = np.array2string(val, separator=', ') val_str = np.array2string(val, separator=', ')
@ -808,7 +823,7 @@ class CellSegmentator:
if self._wandb_config.use_wandb: if self._wandb_config.use_wandb:
# Keep only scalar values # Keep only scalar values
scalar_results: dict[str, float] = {} scalar_results: dict[str, float] = {}
for key, val in results.items(): for key, val in metrics.items():
if isinstance(val, np.ndarray): if isinstance(val, np.ndarray):
continue continue
# Ensure float type # Ensure float type
@ -816,6 +831,34 @@ class CellSegmentator:
wandb.log(scalar_results, step=step) wandb.log(scalar_results, step=step)
def __save_metrics_to_csv(
self,
metrics: Dict[str, Union[float, np.ndarray]],
output_path: str
) -> None:
"""
Saves a dictionary of metrics to a CSV file with columns 'Metric' and 'Value'.
Args:
metrics (Dict[str, Union[float, np.ndarray]]):
Mapping from metric names to scalar values or numpy arrays.
output_path (str):
Path to the output CSV file.
"""
with open(output_path, mode='w', newline='') as csv_file:
writer = csv.writer(csv_file)
writer.writerow(['Metric', 'Value'])
for name, value in metrics.items():
# Convert numpy arrays to string representation
if isinstance(value, np.ndarray):
# Flatten and join with commas
flat = value.flatten()
val_str = ','.join([f"{v}" for v in flat])
else:
val_str = f"{value}"
writer.writerow([name, val_str])
def __run_epoch(self, def __run_epoch(self,
mode: Literal["train", "valid", "test"], mode: Literal["train", "valid", "test"],
epoch: Optional[int] = None, epoch: Optional[int] = None,
@ -900,7 +943,7 @@ class CellSegmentator:
predicted_masks=preds, predicted_masks=preds,
ground_truth_masks=labels_post, # type: ignore ground_truth_masks=labels_post, # type: ignore
iou_threshold=0.5, iou_threshold=0.5,
return_error_masks=(mode == "test") return_error_masks=(mode == "test") and save_results is True
) )
all_tp.append(tp) all_tp.append(tp)
all_fp.append(fp) all_fp.append(fp)

@ -24,17 +24,16 @@ def main():
help='Run mode: train, test or predict' help='Run mode: train, test or predict'
) )
parser.add_argument( parser.add_argument(
'-s', '--save-masks', '--no-save-masks',
action='store_true', action='store_false',
default=True, dest='save_masks',
help='If set to False, do not save predicted masks; by default, saving is enabled' help='If set, do NOT save predicted masks (saving is enabled by default)'
) )
parser.add_argument( parser.add_argument(
'--only-masks', '--only-masks',
action='store_true', action='store_true',
default=False,
help=('If set and save-masks set, save only the raw predicted' help=('If set and save-masks set, save only the raw predicted'
' masks without additional visualizations or metrics') ' masks without additional visualizations')
) )
args = parser.parse_args() args = parser.parse_args()

Loading…
Cancel
Save