You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123 lines
3.8 KiB

from .base import *
from .bce import BCELossParams
from .mse import MSELossParams
from pydantic import BaseModel, ConfigDict
class BCE_MSE_LossParams(BaseModel):
    """
    Immutable configuration for the combined `nn.BCELoss`/`nn.BCEWithLogitsLoss`
    and `nn.MSELoss` objective.

    Attributes:
        num_classes: Size of each channel group sliced by the combined loss.
        bce_params: Settings forwarded to the BCE criterion.
        mse_params: Settings forwarded to the MSE criterion.
    """

    # frozen=True makes instances immutable (and hashable) once created.
    model_config = ConfigDict(frozen=True)

    # Number of classes; drives the channel slicing in the combined loss.
    num_classes: int = 1
    # Parameters for the BCE part of the loss.
    bce_params: BCELossParams = BCELossParams()
    # Parameters for the MSE part of the loss.
    mse_params: MSELossParams = MSELossParams()

    def asdict(self) -> Dict[str, Any]:
        """
        Export this configuration as a plain nested dictionary.

        Returns:
            Dict[str, Any]: ``num_classes`` plus the dictionaries produced by the
            nested BCE and MSE parameter objects' own ``asdict()`` methods.
        """
        exported: Dict[str, Any] = {"num_classes": self.num_classes}
        exported["bce_params"] = self.bce_params.asdict()
        exported["mse_params"] = self.mse_params.asdict()
        return exported
class BCE_MSE_Loss(BaseLoss):
    """
    Custom loss function combining BCE (with or without logits) and MSE losses
    for cell recognition and distinction.

    Channel layout consumed by `forward`, with C = `num_classes`:
        - outputs[:, :2C] is compared with MSE against 5.0 * target[:, 2C:]
        - outputs[:, -C:] is compared with BCE against target[:, C:2C]
        - target[:, :C] is not used by this loss

    Per-term loss values are accumulated in MONAI `CumulativeAverage` trackers.
    Read them with `get_loss_metrics()` and clear them with `reset_metrics()`.
    """

    def __init__(self, params: Optional[BCE_MSE_LossParams] = None):
        """
        Initializes the loss function with optional BCE and MSE parameters.

        Args:
            params (Optional[BCE_MSE_LossParams]): Loss configuration. When
                None, a default `BCE_MSE_LossParams()` is used.

        Raises:
            ValueError: If `num_classes` is smaller than 1. The channel slicing
                in `forward` would otherwise pick the wrong channels; for
                example, `outputs[:, -0:]` is the whole tensor.
        """
        super().__init__(params=params)
        _params = params if params is not None else BCE_MSE_LossParams()
        if _params.num_classes < 1:
            raise ValueError(
                f"num_classes must be a positive integer, got {_params.num_classes}."
            )
        self.num_classes = _params.num_classes

        # BCE criterion: the logits variant applies the sigmoid internally;
        # plain BCE expects probabilities. Selected by the `with_logits` flag.
        _bce_params = _params.bce_params.asdict()
        self.bce_loss = (
            nn.BCEWithLogitsLoss(**_bce_params)
            if _params.bce_params.with_logits
            else nn.BCELoss(**_bce_params)
        )

        # MSE criterion built from its own parameter set.
        _mse_params = _params.mse_params.asdict()
        self.mse_loss = nn.MSELoss(**_mse_params)

        # Running averages of each loss term (MONAI CumulativeAverage).
        self.loss_bce_metric = CumulativeAverage()
        self.loss_mse_metric = CumulativeAverage()

    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between true labels and prediction outputs.

        Args:
            outputs (torch.Tensor): Model predictions of shape
                (batch_size, channels, H, W).
            target (torch.Tensor): Ground truth labels of shape
                (batch_size, channels, H, W).

        Returns:
            torch.Tensor: The total loss value, i.e. the BCE term plus the
            weighted MSE term.

        Raises:
            ValueError: If `target` is not on the same device as `outputs`.
        """
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if target.device != outputs.device:
            raise ValueError(
                "Target tensor must be moved to the same device as outputs "
                "before calling forward()."
            )

        c = self.num_classes

        # Cell recognition loss: BCE on the last C output channels.
        cellprob_loss = self.bce_loss(outputs[:, -c:], target[:, c:2 * c].float())

        # Cell distinction loss: MSE on the first 2C output channels.
        # NOTE(review): the 5.0 target scaling and 0.5 weight appear to follow
        # the Cellpose flow loss; confirm against the training recipe.
        gradflow_loss = 0.5 * self.mse_loss(
            outputs[:, :2 * c], 5.0 * target[:, 2 * c:]
        )

        total_loss = cellprob_loss + gradflow_loss

        # Store plain Python scalars (`.item()`) so trackers hold no tensors.
        self.loss_bce_metric.append(cellprob_loss.item())
        self.loss_mse_metric.append(gradflow_loss.item())
        return total_loss

    @staticmethod
    def _to_float(value: Any) -> float:
        """Convert a CumulativeAverage aggregate (array, tensor or number) to float."""
        # aggregate() normally returns an array/tensor, but can return a plain
        # Python number (which has no `.item()`) before anything was appended.
        return float(value.item()) if hasattr(value, "item") else float(value)

    def get_loss_metrics(self) -> Dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Each tracker is aggregated once. An empty tracker is reported as 0.0
        instead of raising.

        Returns:
            Dict[str, float]: A dictionary with three entries, each rounded to
            4 decimals:
                - "bce_loss": the average BCE loss.
                - "mse_loss": the average MSE loss.
                - "loss": the unrounded sum of the two averages, then rounded.
        """
        bce_avg = self._to_float(self.loss_bce_metric.aggregate())
        mse_avg = self._to_float(self.loss_mse_metric.aggregate())
        return {
            "bce_loss": round(bce_avg, 4),
            "mse_loss": round(mse_avg, 4),
            "loss": round(bce_avg + mse_avg, 4),
        }

    def reset_metrics(self) -> None:
        """Resets the stored loss metrics."""
        self.loss_bce_metric.reset()
        self.loss_mse_metric.reset()