1234567891011121314151617181920212223242526 |
- import math
- import torch
- from torch.optim import Optimizer
- # Spherical Optimizer Class
- # Treats the first two dimensions of each parameter as batch dimensions and keeps
- # every remaining sub-tensor on the sphere of its initial L2 radius throughout
- #
- # Example Usage:
- # opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)
class SphericalOptimizer(Optimizer):
    """Wrap a torch optimizer so each parameter stays on a fixed-radius sphere.

    After every step of the wrapped optimizer, each parameter is rescaled so
    that its L2 norm over dimensions ``2..ndim-1`` equals the norm it had when
    this wrapper was constructed. The norm is computed separately for every
    index pair of the first two dimensions. In effect, the first two
    dimensions act as batch dimensions, and each remaining sub-tensor is
    projected back onto the sphere of its initial radius.

    Example:
        opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)

    Args:
        optimizer: Optimizer class (e.g. ``torch.optim.SGD``), instantiated
            as ``optimizer(params, **kwargs)``.
        params: Iterable of tensors to optimize. It is materialized into a
            list, so one-shot iterators such as ``model.parameters()`` work.
        **kwargs: Forwarded unchanged to ``optimizer``.

    ``Optimizer.__init__`` is not called. ``zero_grad`` and ``param_groups``
    are delegated to the wrapped optimizer (``self.opt``) instead.

    NOTE(review): parameters are presumably expected to have ndim >= 3. With
    ndim == 2 the reduction-dimension tuple is empty, and the result then
    depends on how torch treats ``sum(dim=())``. Confirm before relying on it.
    """

    # Added under the square root so an all-zero slice never divides by zero.
    _EPS = 1e-9

    def __init__(self, optimizer, params, **kwargs):
        # Materialize once: both the inner optimizer and the radius table
        # iterate ``params``, and a generator would be exhausted by the first.
        params = list(params)
        self.opt = optimizer(params, **kwargs)
        self.params = params
        with torch.no_grad():
            # Fixed target radius per slice, captured from the initial values.
            self.radii = {param: self._norm(param) for param in params}

    @classmethod
    def _norm(cls, tensor):
        """Return sqrt(sum of squares over dims 2..ndim-1 + eps), keeping those dims as size 1."""
        dims = tuple(range(2, tensor.ndim))
        return (tensor.pow(2).sum(dims, keepdim=True) + cls._EPS).sqrt()

    @torch.no_grad()
    def step(self, closure=None):
        """Step the wrapped optimizer, then rescale each parameter to its initial radius.

        Args:
            closure: Passed through unchanged to the wrapped optimizer's ``step``.

        Returns:
            Whatever the wrapped optimizer's ``step`` returned.
        """
        loss = self.opt.step(closure)
        for param in self.params:
            # Normalize each slice to unit norm, then scale back to the stored radius.
            param.data.div_(self._norm(param))
            param.mul_(self.radii[param])
        return loss

    def zero_grad(self, *args, **kwargs):
        """Delegate to the wrapped optimizer's ``zero_grad``.

        The inherited ``Optimizer.zero_grad`` depends on state created by
        ``Optimizer.__init__``, which this wrapper never runs.
        """
        return self.opt.zero_grad(*args, **kwargs)

    @property
    def param_groups(self):
        """The wrapped optimizer's param groups (the same list object, not a copy)."""
        return self.opt.param_groups
|