from .base import *  # provides BaseLoss

from typing import Any, Dict, List, Literal, Optional, Union

import torch
from monai.metrics import CumulativeAverage
from pydantic import BaseModel, ConfigDict
from torch import nn


class BCELossParams(BaseModel):
    """
    Container for the parameters shared by `nn.BCELoss` and `nn.BCEWithLogitsLoss`.
    """

    model_config = ConfigDict(frozen=True)

    weight: Optional[List[Union[int, float]]] = None  # Manual rescaling weight applied to each batch element
    reduction: Literal["none", "mean", "sum"] = "mean"  # Reduction applied to the output
    pos_weight: Optional[List[Union[int, float]]] = None  # Weight of positive examples; used only by nn.BCEWithLogitsLoss

    def asdict(self, with_logits: bool = False) -> Dict[str, Any]:
        """
        Returns a dictionary of valid parameters for `nn.BCEWithLogitsLoss` or `nn.BCELoss`.

        - If `with_logits=False`, `pos_weight` is **removed**, since `nn.BCELoss` does not accept it.
        - Ensures that only parameters valid for the chosen loss function are passed.

        Args:
            with_logits (bool): If `True`, keeps `pos_weight` (for `nn.BCEWithLogitsLoss`).
                If `False`, removes `pos_weight` (for `nn.BCELoss`).

        Returns:
            Dict[str, Any]: Filtered dictionary of parameters.
        """
        loss_kwargs = self.model_dump()
        if not with_logits:
            loss_kwargs.pop("pos_weight", None)  # nn.BCELoss has no pos_weight argument

        weight = loss_kwargs.get("weight")
        pos_weight = loss_kwargs.get("pos_weight")

        # PyTorch expects floating-point tensors rather than plain Python lists
        if weight is not None:
            loss_kwargs["weight"] = torch.tensor(weight, dtype=torch.float)

        if pos_weight is not None:
            loss_kwargs["pos_weight"] = torch.tensor(pos_weight, dtype=torch.float)

        return {k: v for k, v in loss_kwargs.items() if v is not None}  # Drop unset values
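
# Illustrative example (not part of the original module): `asdict` keeps only
# the keys the chosen constructor accepts and converts lists to tensors.
#
#   params = BCELossParams(weight=[0.3, 0.7], pos_weight=[2.0])
#   params.asdict(with_logits=True)   # {'weight': tensor(...), 'reduction': 'mean', 'pos_weight': tensor([2.])}
#   params.asdict(with_logits=False)  # {'weight': tensor(...), 'reduction': 'mean'}
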
class BCELoss(BaseLoss):
    """
    Loss wrapper around `nn.BCELoss` / `nn.BCEWithLogitsLoss` that also tracks loss metrics.
    """

    def __init__(self, bce_params: Optional[BCELossParams] = None, with_logits: bool = False):
        """
        Initializes the loss function with optional BCE parameters.

        Args:
            bce_params (Optional[BCELossParams]): Parameters for the underlying loss (default: None).
            with_logits (bool): If `True`, wraps `nn.BCEWithLogitsLoss`; otherwise `nn.BCELoss`.
        """
        super().__init__()
        _bce_params = bce_params.asdict(with_logits=with_logits) if bce_params is not None else {}

        # Initialize the loss function with user-provided parameters or PyTorch defaults
        self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params)

        # Use CumulativeAverage from MONAI to track the running loss
        self.loss_bce_metric = CumulativeAverage()

    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between true labels and prediction outputs.

        Args:
            outputs (torch.Tensor): Model predictions of shape (batch_size, channels, H, W).
            target (torch.Tensor): Ground-truth labels in one-hot format.

        Returns:
            torch.Tensor: The total loss value.
        """
        # Ensure target is on the same device as outputs
        assert target.device == outputs.device, (
            "Target tensor must be moved to the same device as outputs before calling forward()."
        )

        loss = self.bce_loss(outputs, target)
        self.loss_bce_metric.append(loss.item())

        return loss
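
    # Typical call pattern (illustrative sketch; `model`, `images`, and `labels`
    # are hypothetical names, not part of this module):
    #
    #   criterion = BCELoss(with_logits=True)
    #   outputs = model(images)
    #   loss = criterion(outputs, labels.to(outputs.device))  # satisfies the device assert
    #   loss.backward()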

    def get_loss_metrics(self) -> Dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Returns:
            Dict[str, float]: A dictionary containing the average BCE loss.
        """
        return {
            "loss": round(self.loss_bce_metric.aggregate().item(), 4),
        }

    def reset_metrics(self):
        """Resets the stored loss metrics."""
        self.loss_bce_metric.reset()
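
# Minimal smoke test, a sketch under assumptions: `BaseLoss` (from `.base`) is
# taken to be an `nn.Module` subclass so the instance is callable; names below
# are illustrative only. Runs only when this file is executed directly.
if __name__ == "__main__":
    params = BCELossParams(pos_weight=[2.0])
    criterion = BCELoss(bce_params=params, with_logits=True)

    logits = torch.randn(4, 1, 8, 8)                     # fake predictions
    targets = torch.randint(0, 2, (4, 1, 8, 8)).float()  # fake binary targets

    loss = criterion(logits, targets)
    print(f"loss={loss.item():.4f}", criterion.get_loss_metrics())
    criterion.reset_metrics()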