import torch
import torch.optim as optim
from pydantic import BaseModel
from typing import Any, Iterable, Optional


class BaseOptimizer:
    """Thin wrapper that delegates to an underlying ``torch.optim.Optimizer``.

    The wrapped optimizer lives in ``self.optim`` and starts out as ``None``.
    While it is ``None`` every method of this class is a silent no-op.
    Presumably concrete subclasses build the real optimizer and assign it to
    ``self.optim`` -- TODO confirm against the subclasses.
    """

    def __init__(
        self,
        model_params: Iterable[torch.nn.Parameter],
        optim_params: BaseModel,
    ):
        """
        Initializes the wrapper without creating an optimizer.

        Args:
            model_params (Iterable[torch.nn.Parameter]): Parameters to be
                optimized. Not used by this base implementation.
            optim_params (BaseModel): Optimizer configuration. Not used by
                this base implementation.
        """
        super().__init__()
        # No optimizer yet; zero_grad() and step() do nothing until one is set.
        self.optim: Optional[optim.Optimizer] = None

    def zero_grad(self, set_to_none: bool = True) -> None:
        """
        Clears the gradients of all optimized tensors.

        Does nothing when no underlying optimizer has been set.

        Args:
            set_to_none (bool): If True, sets gradients to None instead of
                zero. This can reduce memory usage and improve performance.
                (Introduced in PyTorch 1.7+)
        """
        if self.optim is None:
            return
        self.optim.zero_grad(set_to_none=set_to_none)

    def step(self, closure: Optional[Any] = None) -> Any:
        """
        Performs a single optimization step (parameter update).

        Args:
            closure (Optional[Callable]): A closure that reevaluates the model
                and returns the loss. This is required for optimizers like
                LBFGS that need multiple forward passes.

        Returns:
            Any: Whatever the wrapped optimizer's ``step`` returns, or ``None``
                when no underlying optimizer has been set.
        """
        if self.optim is None:
            return None
        return self.optim.step(closure=closure)