You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							108 lines
						
					
					
						
							3.8 KiB
						
					
					
				
			
		
		
	
	
							108 lines
						
					
					
						
							3.8 KiB
						
					
					
				| from .base import *
 | |
| from typing import List, Literal, Union
 | |
| from pydantic import BaseModel, ConfigDict
 | |
| 
 | |
| 
 | |
class BCELossParams(BaseModel):
    """
    Immutable parameter container shared by `nn.BCELoss` and `nn.BCEWithLogitsLoss`.

    The fields mirror the keyword arguments of those losses; `asdict` turns them
    into a kwargs dictionary suited to whichever of the two is being built.
    """
    model_config = ConfigDict(frozen=True)  # instances cannot be modified after creation

    # Forwarded as the `weight` keyword (converted to a tensor by `asdict`).
    weight: Optional[List[Union[int, float]]] = None
    # Forwarded as the `reduction` keyword.
    reduction: Literal["none", "mean", "sum"] = "mean"
    # Forwarded as `pos_weight`, and only when `asdict(with_logits=True)` is used.
    pos_weight: Optional[List[Union[int, float]]] = None

    def asdict(self, with_logits: bool = False) -> Dict[str, Any]:
        """
        Build the keyword arguments for the selected BCE loss.

        Fields whose value is `None` are left out entirely. `weight` and
        `pos_weight` are converted from lists to tensors via `torch.Tensor`.
        `pos_weight` is skipped unless `with_logits` is true, since it is
        intended for `nn.BCEWithLogitsLoss` only.

        Args:
            with_logits (bool): `True` when the kwargs are destined for
                `nn.BCEWithLogitsLoss` (keeps `pos_weight`); `False` when they
                are destined for `nn.BCELoss` (drops `pos_weight`).

        Returns:
            Dict[str, Any]: The filtered keyword arguments.
        """
        tensor_fields = ("weight", "pos_weight")
        kwargs: Dict[str, Any] = {}

        for name, value in self.model_dump().items():
            # pos_weight is only meaningful for the with-logits variant.
            if name == "pos_weight" and not with_logits:
                continue
            # Unset fields are omitted so the loss falls back to its own defaults.
            if value is None:
                continue
            kwargs[name] = torch.Tensor(value) if name in tensor_fields else value

        return kwargs
 | |
| 
 | |
| 
 | |
class BCELoss(BaseLoss):
    """
    Loss wrapper around `nn.BCELoss` / `nn.BCEWithLogitsLoss` that also keeps a
    running average of the loss values it has produced.
    """

    def __init__(self, bce_params: Optional[BCELossParams] = None, with_logits: bool = False):
        """
        Initializes the underlying BCE loss and the metric tracker.

        Args:
            bce_params (Optional[BCELossParams]): Parameters for the underlying
                loss. When `None`, the loss is built with PyTorch defaults.
            with_logits (bool): If `True`, `nn.BCEWithLogitsLoss` is used;
                otherwise `nn.BCELoss` is used (default: False).
        """
        super().__init__()
        _bce_params = bce_params.asdict(with_logits=with_logits) if bce_params is not None else {}

        # Initialize the loss with user-provided parameters or PyTorch defaults
        self.bce_loss = nn.BCEWithLogitsLoss(**_bce_params) if with_logits else nn.BCELoss(**_bce_params)

        # Accumulator for the tracked loss values (CumulativeAverage, from MONAI
        # according to the original comment)
        self.loss_bce_metric = CumulativeAverage()

    def forward(self, outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss between prediction outputs and true labels and
        records it in the running metric.

        Args:
            outputs (torch.Tensor): Model predictions (described by the original
                documentation as shape (batch_size, channels, H, W)).
            target (torch.Tensor): Ground truth labels (described by the original
                documentation as one-hot encoded). Must live on the same device
                as `outputs`.

        Returns:
            torch.Tensor: The loss exactly as returned by the underlying loss
            module (unreduced when `reduction="none"`).

        Raises:
            AssertionError: If `target` and `outputs` are on different devices.
        """
        # NOTE: this is an `assert`, so the check is skipped when Python runs with -O.
        assert (
            target.device == outputs.device
        ), (
            "Target tensor must be moved to the same device as outputs "
            "before calling forward()."
        )

        loss = self.bce_loss(outputs, target)

        # With reduction="none" the loss is a per-element tensor and a bare
        # `.item()` would raise; average it for tracking purposes only. For a
        # scalar loss ("mean"/"sum") the mean is the value itself.
        self.loss_bce_metric.append(loss.detach().mean().item())

        return loss

    def get_loss_metrics(self) -> Dict[str, float]:
        """
        Retrieves the tracked loss metrics.

        Returns:
            Dict[str, float]: A dictionary with key "loss" holding the
            aggregated tracked loss, rounded to 4 decimal places.
        """
        return {
            "loss": round(self.loss_bce_metric.aggregate().item(), 4),
        }

    def reset_metrics(self) -> None:
        """Resets the stored loss metrics."""
        self.loss_bce_metric.reset()
 |