import torch
import torch.nn.functional as nn_f

from utils import view
from utils import math  # project-local helpers (provides `tiny`), not the stdlib module


class GuideRefinement:
    """Refine rendered images using the relative error observed on guide images.

    At construction time the network is evaluated at every guide view and the
    relative difference between the real guide images and the network's
    output is stored. `refine_by_guide` later warps that difference onto a
    newly rendered image and uses it as a multiplicative correction.
    """

    def __init__(self, guides_image, guides_view: view.Trans,
                 guides_cam: view.CameraParam, net) -> None:
        """Precompute the per-guide relative difference maps.

        Args:
            guides_image: guide images; `size(0)` is used as the number of
                guides, and the tensor must be broadcast-compatible with the
                stacked network output (guide, channel, res[0], res[1]).
            guides_view: view transform(s) of the guides.
            guides_cam: camera parameters shared by the guides.
            net: callable invoked as `net(rays_o, rays_d)` for one guide; its
                output is reshaped to `(res[0], res[1], -1)`.
        """
        rays_o, rays_d = guides_cam.get_global_rays(guides_view, flatten=True)
        height, width = guides_cam.res[0], guides_cam.res[1]

        # Run the network once per guide and bring each result to
        # channel-first layout before stacking along a new leading dimension.
        guides_inferred = torch.stack([
            net(rays_o[i], rays_d[i]).view(height, width, -1).permute(2, 0, 1)
            for i in range(guides_image.size(0))
        ], 0)

        # Relative difference; `math.tiny` is added to the denominator
        # (presumably a small epsilon guarding against division by zero).
        self.guides_diff = (
            (guides_image - guides_inferred) / (guides_inferred + math.tiny)
        )
        self.guides_view = guides_view
        self.guides_cam = guides_cam

    def get_warp(self, rays_o, rays_d, depthmap, tgt_trans: view.Trans,
                 tgt_cam):
        """Project the points hit by the given rays into a target camera.

        Args:
            rays_o: ray origins, viewable as `depthmap.size() + [3]`.
            rays_d: ray directions, viewable as `depthmap.size() + [3]`.
            depthmap: per-ray depth used to scale `rays_d`.
            tgt_trans: transform of the target view; points are mapped with
                `trans_point(..., inverse=True)`.
            tgt_cam: target camera providing `proj` and `res`.

        Returns:
            Projected positions rescaled so that a pixel coordinate range of
            `[0, res]` maps to `[-1, 1]` on both axes (component 0 is scaled
            by `res[1]`, component 1 by `res[0]`).
        """
        rays_size = list(depthmap.size()) + [3]
        origins = rays_o.view(rays_size)
        directions = rays_d.view(rays_size)

        # Back-project every ray to a 3D point using its depth.
        pcloud = origins + directions * depthmap[..., None]
        pcloud_in_tgt = tgt_trans.trans_point(pcloud, inverse=True)

        pixel_positions = tgt_cam.proj(pcloud_in_tgt)
        # Normalise in place: divide by half the resolution, then shift by -1.
        pixel_positions[..., 0] /= tgt_cam.res[1] * 0.5
        pixel_positions[..., 1] /= tgt_cam.res[0] * 0.5
        pixel_positions -= 1
        return pixel_positions

    def refine_by_guide(self, image, depthmap, rays_o, rays_d, is_lr):
        """Correct `image` with the guide differences warped into its view.

        Args:
            image: rendered image, channel-first (a batch dimension is added
                temporarily when upsampling).
            depthmap: depth of `image`, one value per pixel.
            rays_o: ray origins matching the (possibly upsampled) depth map.
            rays_d: ray directions matching the (possibly upsampled) depth map.
            is_lr: when true, `image` and `depthmap` are first upsampled by a
                factor of 2 with bicubic interpolation.

        Returns:
            `image * (1 + d)` where `d` is the guide difference sampled at the
            warped positions and averaged over dimension 0.
        """
        if is_lr:
            # `interpolate` replaces the deprecated `upsample`;
            # align_corners=False is what the implicit default resolved to.
            image = nn_f.interpolate(
                image[None, ...], scale_factor=2, mode='bicubic',
                align_corners=False)[0]
            depthmap = nn_f.interpolate(
                depthmap[None, None, ...], scale_factor=2, mode='bicubic',
                align_corners=False)[0, 0]

        warp = self.get_warp(rays_o, rays_d, depthmap,
                             self.guides_view, self.guides_cam)
        # NOTE(review): grid_sample needs `warp` to carry the same leading
        # batch size as `guides_diff`; presumably `trans_point` broadcasts
        # over the guide views - confirm.
        warped_diff = nn_f.grid_sample(self.guides_diff, warp,
                                       align_corners=False)
        avg_diff = torch.mean(warped_diff, 0)
        return image * (1 + avg_diff)