You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

40 lines
1.4 KiB

import torch
from torch import optim
from pydantic import BaseModel
from typing import Any, Iterable
class BaseOptimizer:
    """Thin wrapper that delegates to an underlying ``torch.optim.Optimizer``.

    The wrapped optimizer is stored in ``self.optim``. This base class leaves
    it as ``None``; while it is ``None``, ``zero_grad`` and ``step`` are no-ops.
    """

    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel) -> None:
        """
        Initializes the wrapper without an underlying optimizer.
        Args:
            model_params (Iterable[torch.nn.Parameter]): Not used by this base class.
            optim_params (BaseModel): Not used by this base class.
        NOTE(review): presumably subclasses build the concrete optimizer from
        these two arguments and assign it to ``self.optim`` - confirm against
        the subclasses.
        """
        super().__init__()
        # No concrete optimizer yet; zero_grad/step check for None before delegating.
        self.optim: optim.Optimizer | None = None

    def zero_grad(self, set_to_none: bool = True) -> None:
        """
        Clears the gradients of all optimized tensors.
        Does nothing when no underlying optimizer has been assigned.
        Args:
            set_to_none (bool): If True, sets gradients to None instead of zero.
                This can reduce memory usage and improve performance.
                (Introduced in PyTorch 1.7+)
        """
        if self.optim is not None:
            self.optim.zero_grad(set_to_none=set_to_none)

    def step(self, closure: Any | None = None) -> Any:
        """
        Performs a single optimization step (parameter update).
        Args:
            closure (Any | None): A closure that reevaluates the model and returns the loss.
                This is required for optimizers like LBFGS that need multiple forward passes.
        Returns:
            Any: Whatever the underlying optimizer's ``step`` returns, or None
                when no underlying optimizer has been assigned.
        """
        if self.optim is None:
            # Explicit no-op result so every path of this function returns a value.
            return None
        return self.optim.step(closure=closure)