from stylegan import G_synthesis, G_mapping
from dataclasses import dataclass
from SphericalOptimizer import SphericalOptimizer
from pathlib import Path
import numpy as np
import time
import torch
from loss import LossBuilder
from functools import partial
from drive import open_url


class PULSE(torch.nn.Module):
    def __init__(self, cache_dir, verbose=True):
        super(PULSE, self).__init__()

        self.synthesis = G_synthesis().cuda()
        self.verbose = verbose

        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        if self.verbose:
            print("Loading Synthesis Network")
        with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f:
            self.synthesis.load_state_dict(torch.load(f))

        # The synthesis network stays frozen; only the latent and noise inputs are optimized.
        for param in self.synthesis.parameters():
            param.requires_grad = False

        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)

        # Fit a per-dimension Gaussian to the mapping network's output distribution so
        # that latents sampled from N(0, I) can be transformed to match it. The fit is
        # cached in "gaussian_fit.pt" and reused on subsequent runs.
        if Path("gaussian_fit.pt").exists():
            self.gaussian_fit = torch.load("gaussian_fit.pt")
        else:
            if self.verbose:
                print("\tLoading Mapping Network")
            mapping = G_mapping().cuda()

            with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f:
                mapping.load_state_dict(torch.load(f))

            if self.verbose:
                print("\tRunning Mapping Network")
            with torch.no_grad():
                torch.manual_seed(0)
                latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda")
                # LeakyReLU with slope 5 is the inverse of LeakyReLU with slope 0.2,
                # mapping W-space samples back through the nonlinearity before fitting.
                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
                self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}
                torch.save(self.gaussian_fit, "gaussian_fit.pt")
                if self.verbose:
                    print("\tSaved \"gaussian_fit.pt\"")

    def forward(self, ref_im,
                seed,
                loss_str,
                eps,
                noise_type,
                num_trainable_noise_layers,
                tile_latent,
                bad_noise_layers,
                opt_name,
                learning_rate,
                steps,
                lr_schedule,
                save_intermediate,
                **kwargs):

        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor: a single 512-dim latent shared across all 18 style
        # layers if tile_latent is set, otherwise one latent per layer
        if tile_latent:
            latent = torch.randn(
                (batch_size, 1, 512), dtype=torch.float, requires_grad=True, device='cuda')
        else:
            latent = torch.randn(
                (batch_size, 18, 512), dtype=torch.float, requires_grad=True, device='cuda')

        # Generate list of noise tensors
        noise = []  # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        for i in range(18):
            # dimension of the ith noise tensor: 4x4 at the lowest resolution,
            # doubling every two layers
            res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2))

            if noise_type == 'zero' or i in [int(layer) for layer in bad_noise_layers.split('.')]:
                new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'fixed':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'trainable':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise ValueError(f"unknown noise type: {noise_type}")

            noise.append(new_noise)

        var_list = [latent] + noise_vars

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
        opt_func = opt_dict[opt_name]
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)

        schedule_dict = {
            'fixed': lambda x: 1,
            'linear1cycle': lambda x: (9*(1-np.abs(x/steps-1/2)*2)+1)/10,
            'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*steps)-1/2)*2)+1)/10 if x < 0.9*steps else 1/10 + (x-0.9*steps)/(0.1*steps)*(1/1000-1/10),
        }
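        # Note on the schedules above: each lambda is a multiplier on the base learning
        # rate. 'linear1cycle' ramps linearly from 0.1 up to 1.0 at the midpoint of
        # optimization and back down to 0.1; 'linear1cycledrop' peaks at 45% of the run,
        # returns to 0.1 by 90%, then decays linearly to ~0.001 over the final 10% of steps.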
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            if tile_latent:
                latent_in = latent.expand(-1, 18, -1)
            else:
                latent_in = latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in*self.gaussian_fit["std"] + self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise)+1)/2

            # Calculate losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                best_im = gen_im.clone()

            loss_l2 = loss_dict['L2']

            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield (best_im.cpu().detach().clamp(0, 1),
                       loss_builder.D(best_im).cpu().detach().clamp(0, 1))

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
        if self.verbose:
            print(best_summary + current_info)
        if min_l2 <= eps:
            yield (gen_im.clone().cpu().detach().clamp(0, 1),
                   loss_builder.D(best_im).cpu().detach().clamp(0, 1))
        else:
            print("Could not find a face that downscales correctly within epsilon")
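
# A minimal usage sketch (an illustrative assumption, not part of the original
# module): in the PULSE repo this class is driven by a separate script that
# supplies low-resolution reference images and CLI hyperparameters. The input
# shape and every hyperparameter value below are placeholders for illustration.
if __name__ == "__main__":
    model = PULSE(cache_dir="cache")
    ref_im = torch.rand((1, 3, 32, 32), device="cuda")  # hypothetical 32x32 LR batch in [0,1]
    # forward() is a generator, so iterate over the module call to collect results
    for hr, lr in model(ref_im, seed=0, loss_str="100*L2+0.05*GEOCROSS", eps=2e-3,
                        noise_type="trainable", num_trainable_noise_layers=5,
                        tile_latent=False, bad_noise_layers="17", opt_name="adam",
                        learning_rate=0.4, steps=100, lr_schedule="linear1cycledrop",
                        save_intermediate=False):
        print(hr.shape, lr.shape)  # HR output and its downscaled counterpart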