import json
import sys
import os
import argparse
import math
import shutil
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as nn_f

sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
parser.add_argument('-s', '--stereo', action='store_true')
parser.add_argument('-R', '--replace', action='store_true')
parser.add_argument('--noCE', action='store_true')
parser.add_argument('-i', '--input', type=str)
parser.add_argument('-r', '--range', type=str)
parser.add_argument('-f', '--fps', type=int)
parser.add_argument('--device', type=int, default=0, help='Which CUDA device to use.')
parser.add_argument('scene', type=str)
parser.add_argument('view_file', type=str)
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
torch.autograd.set_grad_enabled(False)

# Project imports (kept after device selection)
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import misc
from utils import img
from utils import device
from utils.view import *
from utils import sphere
from components.fnr import FoveatedNeuralRenderer
from utils.progress_bar import progress_bar


def load_net(path):
    config = SphericalViewSynConfig()
    config.from_id(path[:-4])
    config.sa['perturb_sample'] = False
    # config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net


def find_file(prefix):
    for path in os.listdir():
        if path.startswith(prefix):
            return path
    return None


def clamp_gaze(gaze):
    # Clamping is currently disabled; the conversion below is unreachable.
    return gaze
    scoord = sphere.cartesian2spherical(gaze)


def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
    # The CSV capture comes in three layouts, detected from the line structure:
    # 3 lines per view (simple), 7 lines per view (old) or 9 lines per view (new).
    def to_tensor(line_content):
        return torch.tensor([float(val) for val in line_content.split(',')])

    with open(data_desc_file, 'r', encoding='utf-8') as file:
        lines = file.readlines()
    simple_fmt = len(lines[3].split(',')) == 1
    old_fmt = len(lines[7].split(',')) == 1
    lines_per_view = 3 if simple_fmt else 7 if old_fmt else 9
    n = len(lines) // lines_per_view
    gaze_dirs = torch.empty(n, 2, 3)
    gazes = torch.empty(n, 2, 2)
    view_mats = torch.empty(n, 2, 4, 4)
    view_idx = 0
    for i in range(0, len(lines), lines_per_view):
        if simple_fmt:
            if lines[i + 1].startswith('0,0,0,0'):
                continue
            gazes[view_idx] = to_tensor(lines[i + 1]).view(2, 2)
            view_mats[view_idx] = to_tensor(lines[i + 2]).view(1, 4, 4).expand(2, -1, -1)
        else:
            if lines[i + 1].startswith('0,0,0') or lines[i + 2].startswith('0,0,0'):
                continue
            j = i + 1
            gaze_dirs[view_idx, 0] = clamp_gaze(to_tensor(lines[j]))
            j += 1
            gaze_dirs[view_idx, 1] = clamp_gaze(to_tensor(lines[j]))
            j += 1
            if not old_fmt:
                gazes[view_idx, 0] = to_tensor(lines[j])
                j += 1
                gazes[view_idx, 1] = to_tensor(lines[j])
                j += 1
            view_mats[view_idx, 0] = to_tensor(lines[j]).view(4, 4)
            j += 1
            view_mats[view_idx, 1] = to_tensor(lines[j]).view(4, 4)
        view_idx += 1
    view_mats = view_mats[:view_idx]
    gaze_dirs = gaze_dirs[:view_idx]
    gazes = gazes[:view_idx]
    if old_fmt:
        # The old format only stores 3D gaze directions; convert them to pixel
        # coordinates in the full frame, using tan(55 deg) as the half-frame extent.
        gaze2 = -gaze_dirs[..., :2] / gaze_dirs[..., 2:]
        fov = math.radians(55)
        tan = torch.tensor([math.tan(fov), math.tan(fov) * 1600 / 1440])
        gazeper = gaze2 / tan
        gazes = 0.5 * torch.stack([1440 * gazeper[..., 0], 1600 * gazeper[..., 1]], -1)
    # Use the midpoint of the two eyes as the view center, with the first eye's rotation
    view_t = 0.5 * (view_mats[:, 0, :3, 3] + view_mats[:, 1, :3, 3])
    return Trans(view_t, view_mats[:, 0, :3, :3]), gazes


def load_json(data_desc_file) -> Tuple[Trans, torch.Tensor]:
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
    view_t = torch.tensor(data['view_centers'])
    view_r = torch.tensor(data['view_rots']).view(-1, 3, 3)
    if data.get('gazes'):
        if len(data['gazes'][0]) == 2:
            # A single gaze per view: duplicate it for both eyes
            gazes = torch.tensor(data['gazes']).view(-1, 1, 2).expand(-1, 2, -1)
        else:
            gazes = torch.tensor(data['gazes']).view(-1, 2, 2)
    else:
        gazes = torch.zeros(view_t.size(0), 2, 2)
    return Trans(view_t, view_r), gazes


def load_views(data_desc_file: str) -> Tuple[Trans, torch.Tensor]:
    if data_desc_file.endswith('.csv'):
        return load_csv(data_desc_file)
    return load_json(data_desc_file)


rot_range = {
    'classroom': [120, 80],
    'barbershop': [360, 80],
    'lobby': [360, 80],
    'stones': [360, 80]
}
trans_range = {
    'classroom': 0.6,
    'barbershop': 0.3,
    'lobby': 1.0,
    'stones': 1.0
}
fov_list = [20, 45, 110]
res_list = [(256, 256), (256, 256), (400, 360)]
res_full = (1600, 1440)
stereo_disparity = 0.06

# Load networks
cwd = os.getcwd()
os.chdir(f"{sys.path[0]}/../data/__new/{opt.scene}_all")
fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))
renderer = FoveatedNeuralRenderer(fov_list, res_list,
                                  nn.ModuleList([fovea_net, periph_net, periph_net]),
                                  res_full, device=device.default())
os.chdir(cwd)

# Load Dataset
views, gazes = load_views(opt.view_file)
if opt.range:
    opt.range = [int(val) for val in opt.range.split(",")]
    if len(opt.range) == 1:
        opt.range = [0, opt.range[0]]
    views = views.get(range(opt.range[0], opt.range[1]))
    gazes = gazes[opt.range[0]:opt.range[1]]
views = views.to(device.default())
n_views = views.size()[0]
print('Dataset loaded. Views:', n_views)

videodir = os.path.dirname(os.path.abspath(opt.view_file))
tempdir = '/dev/shm/dvs_tmp/realvideo'
videoname = f"{os.path.splitext(os.path.split(opt.view_file)[-1])[0]}_{'stereo' if opt.stereo else 'mono'}"
gazeout = f"{videodir}/{videoname}_gaze.csv"
if opt.noCE:
    videoname += "_noCE"
if opt.fps:
    # When encoding a video, write intermediate frames as BMPs in a tmpfs directory
    if opt.input:
        videoname = os.path.split(opt.input)[0]
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{tempdir}/{videoname}/%04d.bmp"
    hintout = f"{tempdir}/{videoname}_hint/%04d.bmp"
else:
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{videodir}/{videoname}/%04d.png"
    hintout = f"{videodir}/{videoname}_hint/%04d.png"

if opt.input:
    scale = img.load(inferout % 0).shape[-1] / res_full[1]
else:
    scale = 1
hint = img.load(f"{sys.path[0]}/fovea_hint.png", with_alpha=True).to(device=device.default())
hint = nn_f.interpolate(hint, mode='bilinear', scale_factor=scale, align_corners=False)

print("Video dir:", videodir)
print("Infer out:", inferout)
print("Hint out:", hintout)


def add_hint(image, center, right_center=None):
    # Overlay the gaze hint at the gaze position; for stereo frames, handle each half separately
    if right_center is not None:
        add_hint(image[..., :image.size(-1) // 2], center)
        add_hint(image[..., image.size(-1) // 2:], right_center)
        return
    center = (center[0] * scale, center[1] * scale)
    fovea_origin = (
        int(center[0]) + image.size(-1) // 2 - hint.size(-1) // 2,
        int(center[1]) + image.size(-2) // 2 - hint.size(-2) // 2
    )
    fovea_region = (
        ...,
        slice(fovea_origin[1], fovea_origin[1] + hint.size(-2)),
        slice(fovea_origin[0], fovea_origin[0] + hint.size(-1)),
    )
    try:
        # Alpha-blend the hint onto the frame
        image[fovea_region] = image[fovea_region] * (1 - hint[:, 3:]) + \
            hint[:, :3] * hint[:, 3:]
    except Exception:
        print(fovea_region, image.shape, hint.shape)
        exit()


misc.create_dir(os.path.dirname(inferout))
misc.create_dir(os.path.dirname(hintout))

# Skip frames that already exist unless --replace is specified
hint_offset = infer_offset = 0
if not opt.replace:
    for view_idx in range(n_views):
        if os.path.exists(inferout % view_idx):
            infer_offset += 1
        if os.path.exists(hintout % view_idx):
            hint_offset += 1
    hint_offset = max(0, hint_offset - 1)
    infer_offset = n_views if opt.input else max(0, infer_offset - 1)

if opt.stereo:
    gazes_out = torch.empty(n_views, 4)
    for view_idx in range(n_views):
        # Blend the two eyes' horizontal gaze positions to shrink their disparity
        shift = gazes[view_idx, 0, 0] - gazes[view_idx, 1, 0]
        # print(shift.item())
        gazel = ((gazes[view_idx, 1, 0] + 0.4 * shift).item(),
                 0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
        gazer = ((gazes[view_idx, 0, 0] - 0.4 * shift).item(), gazel[1])
        # gazel = ((gazes[view_idx, 0, 0]).item(),
        #          0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
        # gazer = ((gazes[view_idx, 1, 0]).item(), gazel[1])
        gazes_out[view_idx] = torch.tensor([gazel[0], gazel[1], gazer[0], gazer[1]])
        if view_idx < hint_offset:
            continue
        if view_idx < infer_offset:
            frame = img.load(inferout % view_idx).to(device=device.default())
        else:
            view_trans = views.get(view_idx)
            left_images, right_images = renderer(view_trans, gazel, gazer,
                                                 stereo_disparity=stereo_disparity,
                                                 mono_periph_mode=3, ret_raw=True)
            frame = torch.cat([
                left_images['blended_raw'] if opt.noCE else left_images['blended'],
                right_images['blended_raw'] if opt.noCE else right_images['blended']
            ], -1)
            frame = img.translate(frame, (0.5, 0.5))
            img.save(frame, inferout % view_idx)
        add_hint(frame, gazel, gazer)
        img.save(frame, hintout % view_idx)
        progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)
else:
    gazes_out = torch.empty(n_views, 2)
    for view_idx in range(n_views):
        # Average the two eyes' gazes for monocular rendering
        gaze = 0.5 * (gazes[view_idx, 0] + gazes[view_idx, 1])
        gaze = (gaze[0].item(), gaze[1].item())
        gazes_out[view_idx] = torch.tensor([gaze[0], gaze[1]])
        if view_idx < hint_offset:
            continue
        if view_idx < infer_offset:
            frame = img.load(inferout % view_idx).to(device=device.default())
        else:
            view_trans = views.get(view_idx)
            frame = renderer(view_trans, gaze, ret_raw=True)['blended_raw' if opt.noCE else 'blended']
            frame = img.translate(frame, (0.5, 0.5))
            img.save(frame, inferout % view_idx)
        add_hint(frame, gaze)
        img.save(frame, hintout % view_idx)
        progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)

# Save per-frame gaze positions next to the output video
with open(gazeout, 'w') as fp:
    for i in range(n_views):
        fp.write(','.join([f'{val.item()}' for val in gazes_out[i]]))
        fp.write('\n')

if opt.fps:
    # Generate video without hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i {inferout} -c:v libx264 {videodir}/{videoname}.mp4')
    if not opt.input:
        shutil.rmtree(os.path.dirname(inferout))
    # Generate video with hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i {hintout} -c:v libx264 {videodir}/{videoname}_hint.mp4')
    shutil.rmtree(os.path.dirname(hintout))
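
# Example invocation (a sketch only: the script filename, scene name, and paths below are
# placeholders rather than fixed parts of the repository; adjust them to your checkout and
# data layout):
#
#   python render_real_video.py --device 0 -s -f 30 -r 0,300 barbershop captures/views.csv
#
# With these flags the script renders stereo frames for views 0..299 of the 'barbershop'
# scene, writes a per-frame gaze CSV next to views.csv, and encodes two 30 fps videos
# (with and without the gaze hint overlay) into the same directory.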