loss.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import torch
  2. from bicubic import BicubicDownSample
  3. class LossBuilder(torch.nn.Module):
  4. def __init__(self, ref_im, loss_str, eps):
  5. super(LossBuilder, self).__init__()
  6. assert ref_im.shape[2]==ref_im.shape[3]
  7. im_size = ref_im.shape[2]
  8. factor=1024//im_size
  9. assert im_size*factor==1024
  10. self.D = BicubicDownSample(factor=factor)
  11. self.ref_im = ref_im
  12. self.parsed_loss = [loss_term.split('*') for loss_term in loss_str.split('+')]
  13. self.eps = eps
  14. # Takes a list of tensors, flattens them, and concatenates them into a vector
  15. # Used to calculate euclidian distance between lists of tensors
  16. def flatcat(self, l):
  17. l = l if(isinstance(l, list)) else [l]
  18. return torch.cat([x.flatten() for x in l], dim=0)
  19. def _loss_l2(self, gen_im_lr, ref_im, **kwargs):
  20. return ((gen_im_lr - ref_im).pow(2).mean((1, 2, 3)).clamp(min=self.eps).sum())
  21. def _loss_l1(self, gen_im_lr, ref_im, **kwargs):
  22. return 10*((gen_im_lr - ref_im).abs().mean((1, 2, 3)).clamp(min=self.eps).sum())
  23. # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors
  24. def _loss_geocross(self, latent, **kwargs):
  25. if(latent.shape[1] == 1):
  26. return 0
  27. else:
  28. X = latent.view(-1, 1, 18, 512)
  29. Y = latent.view(-1, 18, 1, 512)
  30. A = ((X-Y).pow(2).sum(-1)+1e-9).sqrt()
  31. B = ((X+Y).pow(2).sum(-1)+1e-9).sqrt()
  32. D = 2*torch.atan2(A, B)
  33. D = ((D.pow(2)*512).mean((1, 2))/8.).sum()
  34. return D
  35. def forward(self, latent, gen_im):
  36. var_dict = {'latent': latent,
  37. 'gen_im_lr': self.D(gen_im),
  38. 'ref_im': self.ref_im,
  39. }
  40. loss = 0
  41. loss_fun_dict = {
  42. 'L2': self._loss_l2,
  43. 'L1': self._loss_l1,
  44. 'GEOCROSS': self._loss_geocross,
  45. }
  46. losses = {}
  47. for weight, loss_type in self.parsed_loss:
  48. tmp_loss = loss_fun_dict[loss_type](**var_dict)
  49. losses[loss_type] = tmp_loss
  50. loss += float(weight)*tmp_loss
  51. return loss, losses