You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.4 KiB
40 lines
1.4 KiB
import torch
|
|
import torch.optim as optim
|
|
from pydantic import BaseModel
|
|
from typing import Any, Iterable, Optional
|
|
|
|
|
|
class BaseOptimizer:
    """Thin wrapper that forwards calls to an optional ``torch.optim.Optimizer``.

    This base class does not build an optimizer itself: it only stores one in
    ``self.optim`` (initially ``None``) and forwards ``zero_grad`` and ``step``
    to it. While ``self.optim`` is ``None`` both methods are no-ops.
    Presumably subclasses construct ``self.optim`` from ``model_params`` and
    ``optim_params`` -- TODO confirm against the concrete subclasses.
    """

    def __init__(
        self,
        model_params: Iterable[torch.nn.Parameter],
        optim_params: "BaseModel",
    ) -> None:
        """Initialise the wrapper with no underlying optimizer.

        Args:
            model_params: Parameters to optimise. Not used by this base class.
            optim_params: Pydantic model, presumably holding optimizer
                hyperparameters. Not used by this base class.
        """
        super().__init__()
        # Deliberately left unset here; something else must assign a concrete
        # torch optimizer before zero_grad/step have any effect.
        self.optim: Optional[optim.Optimizer] = None

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Clear the gradients of all tensors managed by the wrapped optimizer.

        Does nothing if no optimizer has been assigned to ``self.optim``.

        Args:
            set_to_none: If True, set gradients to ``None`` instead of zeroing
                them, which can reduce memory usage and improve performance.
                (The ``set_to_none`` argument is available since PyTorch 1.7.)
        """
        if self.optim is None:
            return
        self.optim.zero_grad(set_to_none=set_to_none)

    def step(self, closure: Optional[Any] = None) -> Any:
        """Perform a single optimization step (parameter update).

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. Required by optimizers such as LBFGS that need
                several forward passes per step.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns (for most torch
            optimizers, the closure's loss or ``None``). Returns ``None``
            without doing anything if no optimizer has been assigned.
        """
        if self.optim is None:
            # Explicit rather than falling off the end of the function.
            return None
        return self.optim.step(closure=closure)