2) added wrappers for optimizer and scheduler; 3) changed parameters for losses.

master
parent 78f97a72a2
commit 5db2220917
@@ -0,0 +1,40 @@
import torch
import torch.optim as optim
from pydantic import BaseModel
from typing import Any, Iterable, Optional


class BaseOptimizer:
    """
    Abstract base class for optimizers.
    Wraps a PyTorch optimizer and provides a unified interface.
    """

    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: BaseModel):
        super().__init__()
        # Concrete subclasses are expected to construct the wrapped optimizer
        # from model_params and optim_params and assign it here.
        self.optim: Optional[optim.Optimizer] = None

    def zero_grad(self, set_to_none: bool = True) -> None:
        """
        Clears the gradients of all optimized tensors.

        Args:
            set_to_none (bool): If True, sets gradients to None instead of zero.
                This can reduce memory usage and improve performance.
                (Introduced in PyTorch 1.7+)
        """
        if self.optim is not None:
            self.optim.zero_grad(set_to_none=set_to_none)

    def step(self, closure: Optional[Any] = None) -> Any:
        """
        Performs a single optimization step (parameter update).

        Args:
            closure (Optional[Callable]): A closure that reevaluates the model and returns the loss.
                This is required for optimizers like LBFGS that need multiple forward passes.

        Returns:
            Any: The return value depends on the specific optimizer implementation.
        """
        if self.optim is not None:
            return self.optim.step(closure=closure)
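For illustration, a concrete subclass might wrap torch.optim.Adam as in the sketch below, reusing the imports from the file above. The names AdamParams and AdamOptimizer, and the lr/weight_decay fields, are assumptions for this example and not part of this commit.

class AdamParams(BaseModel):
    # Hypothetical parameter model; field names are assumptions.
    lr: float = 1e-3
    weight_decay: float = 0.0


class AdamOptimizer(BaseOptimizer):
    """Hypothetical concrete wrapper around torch.optim.Adam."""

    def __init__(self, model_params: Iterable[torch.nn.Parameter], optim_params: AdamParams):
        super().__init__(model_params, optim_params)
        # The base class leaves self.optim as None; the subclass assigns the real optimizer.
        self.optim = optim.Adam(model_params, lr=optim_params.lr, weight_decay=optim_params.weight_decay)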
@@ -0,0 +1,27 @@
import torch.optim as optim
from pydantic import BaseModel
from typing import List, Optional


class BaseScheduler:
    """
    Abstract base class for learning rate schedulers.
    Wraps a PyTorch LR scheduler and provides a unified interface.
    """

    def __init__(self, optimizer: optim.Optimizer, params: BaseModel):
        # Concrete subclasses are expected to construct the wrapped scheduler
        # from the optimizer and params and assign it here.
        self.scheduler: Optional[optim.lr_scheduler.LRScheduler] = None

    def step(self) -> None:
        """
        Performs a single scheduler step. This typically updates the learning rate
        based on the current epoch or step count.
        """
        if self.scheduler is not None:
            self.scheduler.step()

    def get_last_lr(self) -> List[float]:
        """
        Returns the most recent learning rate(s).
        """
        return self.scheduler.get_last_lr() if self.scheduler else []
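A concrete subclass could wrap torch.optim.lr_scheduler.StepLR as sketched below; StepLRParams, StepLRScheduler, and their fields are hypothetical names for illustration only, reusing the imports from the file above.

class StepLRParams(BaseModel):
    # Hypothetical parameter model; field names are assumptions.
    step_size: int = 10
    gamma: float = 0.1


class StepLRScheduler(BaseScheduler):
    """Hypothetical concrete wrapper around torch.optim.lr_scheduler.StepLR."""

    def __init__(self, optimizer: optim.Optimizer, params: StepLRParams):
        super().__init__(optimizer, params)
        # The base class leaves self.scheduler as None; the subclass assigns the real scheduler.
        self.scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=params.step_size, gamma=params.gamma)

In a training loop the two wrappers would typically be used together: call zero_grad() and step() on the optimizer wrapper each batch, then step() on the scheduler wrapper once per epoch (or per step, depending on the schedule), reading get_last_lr() for logging.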