You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

40 lines
1.1 KiB

import abc
import torch
from pydantic import BaseModel
class BaseLoss(torch.nn.Module, abc.ABC):
    """
    Abstract interface for loss modules.

    Concrete losses subclass this class and must implement ``forward``,
    ``get_loss_metrics`` and ``reset_metrics``. Because those methods are
    abstract, ``BaseLoss`` itself cannot be instantiated.
    """

    def __init__(self, params: BaseModel | None = None) -> None:
        """
        Initializes the underlying ``torch.nn.Module``.

        Args:
            params (BaseModel | None): Optional configuration object. It is
                neither read nor stored by this base class; presumably it
                exists so subclasses share one constructor signature
                (confirm against the concrete losses).
        """
        super().__init__()

    @abc.abstractmethod
    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between prediction outputs and true labels.

        Args:
            outputs (torch.Tensor): Model predictions.
            target (torch.Tensor): Ground truth.

        Returns:
            torch.Tensor: The total loss value.
        """
        ...

    @abc.abstractmethod
    def get_loss_metrics(self) -> dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Returns:
            dict[str, float]: A dictionary containing the loss name and
                average loss value.
        """
        ...

    @abc.abstractmethod
    def reset_metrics(self) -> None:
        """Resets the stored loss metrics."""
        ...