You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
106 lines
3.6 KiB
106 lines
3.6 KiB
from .base import *
|
|
from .bce import BCELossParams
|
|
from .mse import MSELossParams
|
|
|
|
|
|
class BCE_MSE_Loss(BaseLoss):
|
|
"""
|
|
Custom loss function combining BCE (with or without logits) and MSE losses for cell recognition and distinction.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_classes: int,
|
|
bce_params: Optional[BCELossParams] = None,
|
|
mse_params: Optional[MSELossParams] = None,
|
|
bce_with_logits: bool = False,
|
|
):
|
|
"""
|
|
Initializes the loss function with optional BCE and MSE parameters.
|
|
|
|
Args:
|
|
num_classes (int): Number of output classes, used for target shifting.
|
|
bce_params (Optional[BCELossParams]): Parameters for BCEWithLogitsLoss or BCELoss (default: None).
|
|
mse_params (Optional[MSELossParams]): Parameters for MSELoss (default: None).
|
|
bce_with_logits (bool): If True, uses BCEWithLogitsLoss; otherwise, uses BCELoss.
|
|
"""
|
|
super().__init__()
|
|
|
|
self.num_classes = num_classes
|
|
|
|
# Process BCE parameters
|
|
_bce_params = bce_params.asdict(bce_with_logits) if bce_params is not None else {}
|
|
|
|
# Choose BCE loss function
|
|
self.bce_loss = (
|
|
nn.BCEWithLogitsLoss(**_bce_params) if bce_with_logits else nn.BCELoss(**_bce_params)
|
|
)
|
|
|
|
# Process MSE parameters
|
|
_mse_params = mse_params.asdict() if mse_params is not None else {}
|
|
|
|
# Initialize MSE loss
|
|
self.mse_loss = nn.MSELoss(**_mse_params)
|
|
|
|
# Using CumulativeAverage from MONAI to track loss metrics
|
|
self.loss_bce_metric = CumulativeAverage()
|
|
self.loss_mse_metric = CumulativeAverage()
|
|
|
|
def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Computes the loss between true labels and prediction outputs.
|
|
|
|
Args:
|
|
outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
|
|
target (torch.Tensor): Ground truth labels of shape (batch_size, channels, H, W).
|
|
|
|
Returns:
|
|
torch.Tensor: The total loss value.
|
|
"""
|
|
# Ensure target is on the same device as outputs
|
|
assert (
|
|
target.device == outputs.device
|
|
), (
|
|
"Target tensor must be moved to the same device as outputs "
|
|
"before calling forward()."
|
|
)
|
|
|
|
# Cell Recognition Loss
|
|
cellprob_loss = self.bce_loss(
|
|
outputs[:, -self.num_classes:], target[:, self.num_classes:2 * self.num_classes].float()
|
|
)
|
|
|
|
# Cell Distinction Loss
|
|
gradflow_loss = 0.5 * self.mse_loss(
|
|
outputs[:, :2 * self.num_classes], 5.0 * target[:, 2 * self.num_classes:]
|
|
)
|
|
|
|
# Total loss
|
|
total_loss = cellprob_loss + gradflow_loss
|
|
|
|
# Track individual losses
|
|
self.loss_bce_metric.append(cellprob_loss.item())
|
|
self.loss_mse_metric.append(gradflow_loss.item())
|
|
|
|
return total_loss
|
|
|
|
def get_loss_metrics(self) -> Dict[str, float]:
|
|
"""
|
|
Retrieves the tracked loss metrics.
|
|
|
|
Returns:
|
|
Dict[str, float]: A dictionary containing the average BCE and MSE loss.
|
|
"""
|
|
return {
|
|
"bce_loss": round(self.loss_bce_metric.aggregate().item(), 4),
|
|
"mse_loss": round(self.loss_mse_metric.aggregate().item(), 4),
|
|
"loss": round(
|
|
self.loss_bce_metric.aggregate().item() + self.loss_mse_metric.aggregate().item(), 4
|
|
),
|
|
}
|
|
|
|
def reset_metrics(self):
|
|
"""Resets the stored loss metrics."""
|
|
self.loss_bce_metric.reset()
|
|
self.loss_mse_metric.reset()
|