Spaces:
Running
on
Zero
Running
on
Zero
| import copy | |
| import functools | |
| import json | |
| import os | |
| from pathlib import Path | |
| from pdb import set_trace as st | |
| import blobfile as bf | |
| import imageio | |
| import numpy as np | |
| import torch as th | |
| import torch.distributed as dist | |
| import torchvision | |
| from PIL import Image | |
| from torch.nn.parallel.distributed import DistributedDataParallel as DDP | |
| from torch.optim import AdamW | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
| from guided_diffusion import dist_util, logger | |
| from guided_diffusion.fp16_util import MixedPrecisionTrainer | |
| from guided_diffusion.nn import update_ema | |
| from guided_diffusion.resample import LossAwareSampler, UniformSampler | |
| from guided_diffusion.train_util import (calc_average_loss, | |
| find_ema_checkpoint, | |
| find_resume_checkpoint, | |
| get_blob_logdir, log_rec3d_loss_dict, | |
| parse_resume_step_from_filename) | |
| from .train_util import TrainLoop3DRec | |
| class TrainLoop3DRecEG3D(TrainLoop3DRec): | |
def __init__(self,
             *,
             G,
             rec_model,
             loss_class,
             data,
             eval_data,
             batch_size,
             microbatch,
             lr,
             ema_rate,
             log_interval,
             eval_interval,
             save_interval,
             resume_checkpoint,
             use_fp16=False,
             fp16_scale_growth=0.001,
             weight_decay=0,
             lr_anneal_steps=0,
             iterations=10001,
             load_submodule_name='',
             ignore_resume_opt=False,
             model_name='rec',
             use_amp=False,
             # hybrid_training=False,
             **kwargs):
    """Build the reconstruction trainer and attach the generator ``G``.

    Every argument except ``G`` is handed, by keyword and unchanged, to
    the parent constructor; any extra ``**kwargs`` are forwarded as well.

    Args:
        G: generator object; stored as ``self.G`` and queried by
            ``run_G`` through its ``mapping`` and ``synthesis`` methods.
        (all other arguments): forwarded verbatim to the parent
            ``__init__``; their meaning is defined there.
    """
    # Collect the explicitly named settings once so the parent call
    # stays a single, easy-to-audit expression.
    parent_settings = dict(
        rec_model=rec_model,
        loss_class=loss_class,
        data=data,
        eval_data=eval_data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        log_interval=log_interval,
        eval_interval=eval_interval,
        save_interval=save_interval,
        resume_checkpoint=resume_checkpoint,
        use_fp16=use_fp16,
        fp16_scale_growth=fp16_scale_growth,
        weight_decay=weight_decay,
        lr_anneal_steps=lr_anneal_steps,
        iterations=iterations,
        load_submodule_name=load_submodule_name,
        ignore_resume_opt=ignore_resume_opt,
        model_name=model_name,
        use_amp=use_amp,
    )
    super().__init__(**parent_settings, **kwargs)

    self.G = G
    # self.hybrid_training = hybrid_training

    # Adaptive average pool to 224x224; forward_backward applies it to the
    # generator image before passing it to the reconstruction model.
    self.pool_224 = th.nn.AdaptiveAvgPool2d((224, 224))
def run_G(
    self,
    z,
    c,
    swapping_prob,
    neural_rendering_resolution,
    update_emas=False,
    return_raw_only=False,
    truncation_psi=0.7,
    truncation_cutoff=None,
):
    """Sample the generator ``self.G`` with truncation applied.

    The mapping network is conditioned on an all-zero tensor shaped like
    ``c``; the real ``c`` is only given to the synthesis network.

    Args:
        z: latent noise passed to ``self.G.mapping``.
        c: conditioning passed to ``self.G.synthesis`` (the callers in
            this class pass the camera tensor ``batch['c']``).
        swapping_prob: accepted for signature compatibility; it is not
            used by this implementation.
        neural_rendering_resolution: forwarded to ``self.G.synthesis``.
        update_emas: forwarded to both ``mapping`` and ``synthesis``.
        return_raw_only: forwarded to ``self.G.synthesis``.
        truncation_psi: truncation strength forwarded to
            ``self.G.mapping``. Defaults to 0.7, the value that was
            previously hard-coded.
        truncation_cutoff: forwarded to ``self.G.mapping``. Defaults to
            ``None``, the value that was previously hard-coded.

    Returns:
        tuple: ``(gen_output, ws)`` where ``gen_output`` is whatever
        ``self.G.synthesis`` returns and ``ws`` is the output of
        ``self.G.mapping``.
    """
    # The mapping network never sees the real conditioning.
    c_gen_conditioning = th.zeros_like(c)

    ws = self.G.mapping(
        z,
        c_gen_conditioning,
        truncation_psi=truncation_psi,
        truncation_cutoff=truncation_cutoff,
        update_emas=update_emas,
    )

    # noise_mode='const' fixes the SynthesisLayer modulation noise;
    # otherwise the same latent code may output two different identities.
    gen_output = self.G.synthesis(
        ws,  # presumably (batch, 14, 512) -- TODO confirm against G
        c,
        neural_rendering_resolution=neural_rendering_resolution,
        update_emas=update_emas,
        noise_mode='const',
        return_raw_only=return_raw_only,
    )

    return gen_output, ws
def run_loop(self, batch=None):
    """Run the training loop until the step budget is exhausted.

    The loop continues while ``lr_anneal_steps`` is falsy or
    ``step + resume_step`` is still below it. Each iteration pulls a
    batch from ``self.data``, runs one optimisation step, then logs and
    saves at the configured intervals.

    Args:
        batch: ignored; it is overwritten by ``next(self.data)`` on every
            iteration and only kept for signature compatibility.

    Raises:
        SystemExit: once ``self.step`` exceeds ``self.iterations`` (after
            saving a final checkpoint if one is pending).
    """
    # Local import: `sys.exit()` replaces the `site`-provided `exit()`
    # helper, which is meant for interactive use and is missing when the
    # interpreter is started without the `site` module.
    import sys

    while (not self.lr_anneal_steps
           or self.step + self.resume_step < self.lr_anneal_steps):

        # let all processes sync up before starting a new iteration
        dist_util.synchronize()

        batch = next(self.data)
        self.run_step(batch)

        if self.step % self.log_interval == 0 and dist_util.get_rank() == 0:
            out = logger.dumpkvs()
            # * log to tensorboard
            for k, v in out.items():
                self.writer.add_scalar(f'Loss/{k}', v,
                                       self.step + self.resume_step)

        if self.step % self.eval_interval == 0 and self.step != 0:
            # Evaluation hooks (eval_loop / eval_novelview_loop) are
            # currently disabled; only the synchronisation remains.
            dist_util.synchronize()

        if self.step % self.save_interval == 0:
            self.save()
            dist_util.synchronize()
            # Run for a finite amount of time in integration tests.
            if os.environ.get("DIFFUSION_TRAINING_TEST",
                              "") and self.step > 0:
                return

        self.step += 1

        if self.step > self.iterations:
            print('reached maximum iterations, exiting')

            # Save the last checkpoint if it wasn't already saved.
            if (self.step - 1) % self.save_interval != 0:
                self.save()

            sys.exit()

    # Save the last checkpoint if it wasn't already saved.
    if (self.step - 1) % self.save_interval != 0:
        self.save()
def run_step(self, batch, *args):
    """Perform one optimisation step on ``batch``.

    Runs ``forward_backward``, asks the mixed-precision trainer to apply
    the optimiser, refreshes the EMA weights only when the optimiser
    actually stepped, then anneals the learning rate and logs the step.
    Extra positional arguments are accepted and ignored.
    """
    self.forward_backward(batch)

    # `optimize` reports whether the update was applied; the EMA is only
    # refreshed when it was.
    if self.mp_trainer_rec.optimize(self.opt):
        self._update_ema()

    self._anneal_lr()
    self.log_step()
def forward_backward(self, batch, *args, **kwargs):
    """Accumulate gradients of the generator-distillation loss.

    For every micro-batch of ``batch['c']`` the method:

    1. samples a ground-truth output from the generator via ``run_G``
       (without gradients),
    2. runs ``self.rec_model`` on the 224x224-pooled generator image,
    3. re-queries ``self.rec_model`` at the generator's ``fine_coords``
       to obtain densities comparable to the generator's,
    4. sums the ``loss_class`` loss, the shape loss, and the MSE losses on
       ``feature_volume`` and on the last ``ws`` entry vs ``sr_w_code``
       (both weighted by 0.1),
    5. back-propagates through ``self.mp_trainer_rec``.

    On rank 0, every 500 steps, a comparison image is written to the
    logger directory.

    Args:
        batch: mapping with at least the key ``'c'``; only ``'c'`` is read.
        *args, **kwargs: accepted and ignored.

    Returns:
        The ``self.rec_model`` output of the last micro-batch, or ``None``
        when the batch is empty.
    """
    self.mp_trainer_rec.zero_grad()
    batch_size = batch['c'].shape[0]
    # Defined up front so an empty batch returns None instead of raising
    # UnboundLocalError at the final return.
    pred_gen_output = None

    for i in range(0, batch_size, self.microbatch):

        # Slice the conditioning so each pass handles only its own
        # micro-batch (previously the full batch was re-processed on
        # every pass).
        micro = {
            'c': batch['c'][i:i + self.microbatch].to(dist_util.dev())
        }

        with th.no_grad():  # * infer gt
            eg3d_batch, ws = self.run_G(
                z=th.randn(micro['c'].shape[0], 512).to(dist_util.dev()),
                c=micro['c'],  # use real img pose here? or synthesized pose.
                swapping_prob=0,
                neural_rendering_resolution=128)

        micro.update({
            'img': eg3d_batch['image_raw'],  # gt
            'img_to_encoder': self.pool_224(eg3d_batch['image']),
            'depth': eg3d_batch['image_depth'],
            'img_sr': eg3d_batch['image'],
        })

        last_batch = (i + self.microbatch) >= batch_size

        # wrap forward within amp
        with th.autocast(device_type='cuda',
                         dtype=th.float16,
                         enabled=self.mp_trainer_rec.use_amp):

            pred_gen_output = self.rec_model(
                img=micro['img_to_encoder'],  # pooled generator image
                c=micro['c'])

            target = dict(
                img=eg3d_batch['image_raw'],
                shape_synthesized=eg3d_batch['shape_synthesized'],
                img_sr=eg3d_batch['image'],
            )

            pred_gen_output['shape_synthesized_query'] = {
                'coarse_densities':
                pred_gen_output['shape_synthesized']['coarse_densities'],
                'image_depth': pred_gen_output['image_depth'],
            }

            eg3d_batch['shape_synthesized']['image_depth'] = eg3d_batch[
                'image_depth']

            # Use a dedicated name here: reusing `batch_size` would
            # clobber the loop bound used by `last_batch` above.
            micro_batch_size, num_rays, _, _ = pred_gen_output[
                'shape_synthesized']['coarse_densities'].shape

            for coord_key in ['fine_coords']:  # TODO add surface points

                # Query the predicted field at the generator's sample
                # points, with random view directions.
                sigma = self.rec_model(
                    latent=pred_gen_output['latent_denormalized'],
                    coordinates=eg3d_batch['shape_synthesized'][coord_key],
                    directions=th.randn_like(
                        eg3d_batch['shape_synthesized'][coord_key]),
                    behaviour='triplane_renderer',
                )['sigma']

                rendering_kwargs = self.rec_model(
                    behaviour='get_rendering_kwargs')

                sigma = sigma.reshape(
                    micro_batch_size, num_rays,
                    rendering_kwargs['depth_resolution_importance'], 1)

                # 'fine_coords' -> 'fine_densities'
                pred_gen_output['shape_synthesized_query'][
                    f"{coord_key.split('_')[0]}_densities"] = sigma

            # * 2D reconstruction loss
            # NOTE(review): the PyTorch DDP docs state that no_sync only
            # skips gradient synchronisation when the forward pass runs
            # inside the context; here only the loss computation is
            # wrapped -- confirm whether this is intended.
            if last_batch or not self.use_ddp:
                loss, loss_dict = self.loss_class(pred_gen_output,
                                                  target,
                                                  test_mode=False)
            else:
                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict = self.loss_class(pred_gen_output,
                                                      target,
                                                      test_mode=False)

            # * fully mimic 3D geometry output
            loss_shape = self.calc_shape_rec_loss(
                pred_gen_output['shape_synthesized_query'],
                eg3d_batch['shape_synthesized'])
            loss += loss_shape.mean()

            # * add feature loss on feature_image
            loss_feature_volume = th.nn.functional.mse_loss(
                eg3d_batch['feature_volume'],
                pred_gen_output['feature_volume'])
            loss += loss_feature_volume * 0.1

            loss_ws = th.nn.functional.mse_loss(
                ws[:, -1:, :], pred_gen_output['sr_w_code'])
            loss += loss_ws * 0.1

            loss_dict.update(
                dict(loss_feature_volume=loss_feature_volume,
                     loss=loss,
                     loss_shape=loss_shape,
                     loss_ws=loss_ws))
            log_rec3d_loss_dict(loss_dict)

        self.mp_trainer_rec.backward(loss)

        if dist_util.get_rank() == 0 and self.step % 500 == 0:
            self._dump_eg3d_train_vis(micro, pred_gen_output)

    return pred_gen_output


def _dump_eg3d_train_vis(self, micro, pred_gen_output):
    """Save a ground-truth (top) vs prediction (bottom) image of sample 0.

    Args:
        micro: mapping with ``'img'``, ``'img_sr'`` and ``'depth'`` as
            filled in by ``forward_backward``.
        pred_gen_output: reconstruction output with ``'image_raw'``,
            ``'image_depth'`` and optionally ``'image_sr'``.
    """
    with th.no_grad():
        pred_img = pred_gen_output['image_raw']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        # Min-max normalise both depth maps for display.
        gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
                                                  gt_depth.min())
        pred_depth = pred_gen_output['image_depth']
        pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                        pred_depth.min())

        if 'image_sr' in pred_gen_output:
            # Bring raw renders and depth to the resolution produced by
            # self.pool_512 so they can sit next to the SR images.
            pred_img = th.cat(
                [self.pool_512(pred_img), pred_gen_output['image_sr']],
                dim=-1)
            pred_depth = self.pool_512(pred_depth)
            gt_depth = self.pool_512(gt_depth)
            gt_vis = th.cat([
                self.pool_512(micro['img']), micro['img_sr'],
                gt_depth.repeat_interleave(3, dim=1)
            ],
                            dim=-1)
        else:
            gt_vis = th.cat(
                [micro['img'],
                 gt_depth.repeat_interleave(3, dim=1)],
                dim=-1)

        pred_vis = th.cat(
            [pred_img, pred_depth.repeat_interleave(3, dim=1)],
            dim=-1)  # B, 3, H, W

        # Stack vertically, keep the first sample, convert to HWC.
        vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(1, 2, 0).cpu()

        # Images are treated as lying in [-1, 1] and mapped to [0, 255].
        vis = vis.numpy() * 127.5 + 127.5
        vis = vis.clip(0, 255).astype(np.uint8)

        out_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
        Image.fromarray(vis).save(out_path)
        print('log vis to: ', out_path)
def calc_shape_rec_loss(
    self,
    pred_shape: dict,
    gt_shape: dict,
):
    """Compute the shape reconstruction loss and log its components.

    Delegates to ``self.loss_class.calc_shape_rec_loss`` (passing the
    current device), logs each entry of the returned dict under the key
    ``Loss/3D/<name>`` and returns the loss value.

    Args:
        pred_shape: predicted shape quantities.
        gt_shape: reference shape quantities.

    Returns:
        The first element returned by
        ``self.loss_class.calc_shape_rec_loss``.
    """
    device = dist_util.dev()
    shape_loss, shape_terms = self.loss_class.calc_shape_rec_loss(
        pred_shape, gt_shape, device)

    # One log call per term, each under a 'Loss/3D/' prefixed key.
    for term_name, term_value in shape_terms.items():
        log_rec3d_loss_dict({f'Loss/3D/{term_name}': term_value})

    return shape_loss
| # @th.inference_mode() | |
def eval_novelview_loop(self):
    """Render a novel-view video from the first evaluation sample.

    The first sample of the first ``self.eval_data`` batch is kept as the
    encoder input for the whole run and repeated to the current batch
    size. It is rendered under the cameras ``c`` of every evaluation
    batch, and each rendered frame is appended to
    ``video_novelview_real_<step>.mp4`` in the logger directory.

    Batches are expected to provide ``'img'``, ``'img_to_encoder'`` and
    ``'c'`` (plus ``'img_sr'`` when the model outputs ``'image_sr'``).
    """
    video_out = imageio.get_writer(
        f'{logger.get_dir()}/video_novelview_real_{self.step+self.resume_step}.mp4',
        mode='I',
        fps=60,
        codec='libx264')

    novel_view_micro = {}

    try:
        # Inference only: without no_grad the model output requires grad
        # and the `.numpy()` conversion below raises.
        with th.no_grad():
            # TODO, larger batch size for eval
            for i, batch in enumerate(tqdm(self.eval_data)):
                micro = {
                    k: v.to(dist_util.dev())
                    for k, v in batch.items()
                }
                num_views = micro['img'].shape[0]

                # The first batch provides the reference sample; later
                # batches re-expand the stored one to their batch size.
                source = batch if i == 0 else novel_view_micro
                novel_view_micro = {
                    k: v[0:1].to(dist_util.dev()).repeat_interleave(
                        num_views, 0)
                    for k, v in source.items()
                }

                pred = self.rec_model(
                    img=novel_view_micro['img_to_encoder'], c=micro['c'])

                # Min-max normalise depth for display.
                pred_depth = pred['image_depth']
                pred_depth = (pred_depth - pred_depth.min()) / (
                    pred_depth.max() - pred_depth.min())

                if 'image_sr' in pred:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_512(pred['image_raw']),
                        pred['image_sr'],
                        self.pool_512(pred_depth).repeat_interleave(3,
                                                                    dim=1)
                    ],
                                      dim=-1)
                else:
                    pred_vis = th.cat([
                        self.pool_128(micro['img']), pred['image_raw'],
                        pred_depth.repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)  # B, 3, H, W

                # Images are treated as lying in [-1, 1]; map to uint8.
                vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                vis = vis * 127.5 + 127.5
                vis = vis.clip(0, 255).astype(np.uint8)

                for frame in vis:
                    video_out.append_data(frame)
    finally:
        # Always finalise the file, even if rendering fails midway.
        video_out.close()

    th.cuda.empty_cache()
def eval_novelview_loop_eg3d(self):
    """Render a novel-view video from generator-synthesised inputs.

    On the first ``self.eval_data`` batch the generator is sampled (via
    ``run_G``) under that batch's cameras to produce the encoder input.
    On later batches the first stored sample is repeated to the batch
    size. The reconstruction model then renders under each batch's
    cameras ``c`` and the frames are appended to
    ``video_novelview_synthetic_<step>.mp4`` in the logger directory.

    Batches are expected to provide ``'c'`` and ``'img'`` (plus
    ``'img_sr'`` when the model outputs ``'image_sr'``).
    """
    video_out = imageio.get_writer(
        f'{logger.get_dir()}/video_novelview_synthetic_{self.step+self.resume_step}.mp4',
        mode='I',
        fps=60,
        codec='libx264')

    novel_view_micro = {}

    try:
        # Inference only: without no_grad the model output requires grad
        # and the `.numpy()` conversion below raises.
        with th.no_grad():
            # TODO, larger batch size for eval
            for i, batch in enumerate(tqdm(self.eval_data)):
                micro = {
                    k: v.to(dist_util.dev())
                    for k, v in batch.items()
                }

                if i == 0:
                    # * infer gt
                    # NOTE(review): one latent is drawn per camera, so
                    # the frames of this first batch presumably show
                    # different identities, while later batches reuse
                    # sample 0 -- confirm this is intended.
                    eg3d_batch, _ = self.run_G(
                        z=th.randn(micro['c'].shape[0],
                                   512).to(dist_util.dev()),
                        c=micro['c'],  # use real img pose here? or synthesized pose.
                        swapping_prob=0,
                        neural_rendering_resolution=128)

                    novel_view_micro.update({
                        'img': eg3d_batch['image_raw'],  # gt
                        'img_to_encoder':
                        self.pool_224(eg3d_batch['image']),
                        'depth': eg3d_batch['image_depth'],
                    })
                else:
                    # Keep the first stored sample and expand it to the
                    # current batch size.
                    novel_view_micro = {
                        k: v[0:1].to(dist_util.dev()).repeat_interleave(
                            micro['img'].shape[0], 0)
                        for k, v in novel_view_micro.items()
                    }

                pred = self.rec_model(
                    img=novel_view_micro['img_to_encoder'], c=micro['c'])

                # Min-max normalise depth for display.
                pred_depth = pred['image_depth']
                pred_depth = (pred_depth - pred_depth.min()) / (
                    pred_depth.max() - pred_depth.min())

                if 'image_sr' in pred:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_512(pred['image_raw']),
                        pred['image_sr'],
                        self.pool_512(pred_depth).repeat_interleave(3,
                                                                    dim=1)
                    ],
                                      dim=-1)
                else:
                    pred_vis = th.cat([
                        self.pool_128(micro['img']), pred['image_raw'],
                        pred_depth.repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)  # B, 3, H, W

                # Images are treated as lying in [-1, 1]; map to uint8.
                vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                vis = vis * 127.5 + 127.5
                vis = vis.clip(0, 255).astype(np.uint8)

                for frame in vis:
                    video_out.append_data(frame)
    finally:
        # Always finalise the file, even if rendering fails midway.
        video_out.close()

    th.cuda.empty_cache()