You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

44 lines
1.1 KiB

import abc
import torch
import torch.nn as nn
from typing import Dict, Any, Optional
from monai.metrics.cumulative_average import CumulativeAverage
class BaseLoss(abc.ABC):
    """Abstract interface for loss functions that also track their own metrics.

    Concrete subclasses must implement all three abstract methods:

    * ``forward`` -- compute the loss for a batch of predictions and targets;
    * ``get_loss_metrics`` -- report the loss metrics tracked so far;
    * ``reset_metrics`` -- clear the tracked loss metrics.

    ``BaseLoss`` cannot be instantiated directly (it is an ``abc.ABC`` with
    abstract methods). Its only base is ``abc.ABC`` and it defines no
    ``__call__``, so instances are not callable the way ``torch.nn.Module``
    objects are unless a subclass also inherits from ``nn.Module``.
    """

    def __init__(self) -> None:
        """Initialise the base class.

        Forwards to ``super().__init__()`` so that, under cooperative multiple
        inheritance (e.g. ``class MyLoss(BaseLoss, nn.Module)``), the next
        class in the MRO is initialised as well.
        """
        super().__init__()

    @abc.abstractmethod
    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss between model predictions and ground truth.

        Args:
            outputs (torch.Tensor): Model predictions.
            target (torch.Tensor): Ground-truth values.

        Returns:
            torch.Tensor: The total loss value.
        """

    @abc.abstractmethod
    def get_loss_metrics(self) -> Dict[str, float]:
        """Return the tracked loss metrics.

        Returns:
            Dict[str, float]: Mapping from loss name to its average value.
        """

    @abc.abstractmethod
    def reset_metrics(self) -> None:
        """Clear the stored loss metrics."""