import json
import sys
import os
import csv
import argparse
import math
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as nn_f
from tqdm import trange

sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
parser.add_argument('-s', '--stereo', action='store_true',
                    help="Render stereo video")
parser.add_argument('-d', '--disparity', type=float, default=0.06,
                    help="The stereo disparity")
parser.add_argument('-R', '--replace', action='store_true',
                    help="Replace existing frames in the intermediate output directory")
parser.add_argument('--noCE', action='store_true',
                    help="Disable contrast enhancement")
parser.add_argument('-i', '--input', type=str,
                    help="Path pattern (with a %%d placeholder) of already rendered frames, "
                         "relative to the directory of view_file")
parser.add_argument('-r', '--range', type=str,
                    help="The range of frames to render, specified as \"start,end\" or just \"end\"")
parser.add_argument('-f', '--fps', type=int,
                    help="The FPS of the output video. If not specified, a sequence of images is saved instead")
parser.add_argument('-m', '--model', type=str,
                    help="The directory containing the fovea* and periph* model files")
parser.add_argument('view_file', type=str,
                    help="The path to a .csv or .json file containing a sequence of poses and gazes")
opt = parser.parse_args()

from utils import netio, img, device
from utils.view import *
from utils.types import *
from components.fnr import FoveatedNeuralRenderer
from model import Model


def load_model(path: Path) -> Model:
    checkpoint, _ = netio.load_checkpoint(path)
    checkpoint["model"][1]["sampler"]["perturb"] = False
    model = Model.create(*checkpoint["model"])
    model.load_state_dict(checkpoint["states"]["model"])
    model.to(device.default()).eval()
    return model


def load_csv(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
    def to_tensor(line_content):
        return torch.tensor([float(value) for value in line_content.split(',')])

    with open(data_desc_file, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    # Detect the CSV layout: simple (3 lines per view), old (7) or new (9)
    simple_fmt = len(lines[3].split(',')) == 1
    old_fmt = len(lines[7].split(',')) == 1
    lines_per_view = 3 if simple_fmt else 7 if old_fmt else 9
    n = len(lines) // lines_per_view
    gaze_dirs = torch.empty(n, 2, 3)
    gazes = torch.empty(n, 2, 2)
    view_mats = torch.empty(n, 2, 4, 4)
    view_idx = 0
    for i in range(0, len(lines), lines_per_view):
        if simple_fmt:
            if lines[i + 1].startswith('0,0,0,0'):
                continue
            gazes[view_idx] = to_tensor(lines[i + 1]).view(2, 2)
            view_mats[view_idx] = to_tensor(lines[i + 2]).view(1, 4, 4).expand(2, -1, -1)
        else:
            if lines[i + 1].startswith('0,0,0') or lines[i + 2].startswith('0,0,0'):
                continue
            j = i + 1
            gaze_dirs[view_idx, 0] = to_tensor(lines[j])
            j += 1
            gaze_dirs[view_idx, 1] = to_tensor(lines[j])
            j += 1
            if not old_fmt:
                gazes[view_idx, 0] = to_tensor(lines[j])
                j += 1
                gazes[view_idx, 1] = to_tensor(lines[j])
                j += 1
            view_mats[view_idx, 0] = to_tensor(lines[j]).view(4, 4)
            j += 1
            view_mats[view_idx, 1] = to_tensor(lines[j]).view(4, 4)
        view_idx += 1
    view_mats = view_mats[:view_idx]
    gaze_dirs = gaze_dirs[:view_idx]
    gazes = gazes[:view_idx]

    if old_fmt:
        # The old format stores gaze directions; convert them to screen-space gaze positions
        gaze2 = -gaze_dirs[..., :2] / gaze_dirs[..., 2:]
        fov = math.radians(55)
        tan = torch.tensor([math.tan(fov), math.tan(fov) * 1600 / 1440])
        gazeper = gaze2 / tan
        gazes = 0.5 * torch.stack([1440 * gazeper[..., 0], 1600 * gazeper[..., 1]], -1)

    view_t = 0.5 * (view_mats[:, 0, :3, 3] + view_mats[:, 1, :3, 3])
    return Trans(view_t, view_mats[:, 0, :3, :3]), gazes


def load_json(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
    view_t = torch.tensor(data['view_centers'])
    view_r = torch.tensor(data['view_rots']).view(-1, 3, 3)
    if data.get("gl_coord"):
        # Convert from OpenGL coordinates by flipping the z axis
        view_t[:, 2] *= -1
        view_r[:, 2] *= -1
        view_r[..., 2] *= -1
    if data.get('gazes'):
        if len(data['gazes'][0]) == 2:
            # A single gaze per view: duplicate it for both eyes
            # (.contiguous() so the in-place averaging below does not write to shared memory)
            gazes = torch.tensor(data['gazes']).view(-1, 1, 2).expand(-1, 2, -1).contiguous()
        else:
            gazes = torch.tensor(data['gazes']).view(-1, 2, 2)
    else:
        gazes = torch.zeros(view_t.size(0), 2, 2)
    return Trans(view_t, view_r), gazes


def load_views_and_gazes(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
    if data_desc_file.suffix == '.csv':
        views, gazes = load_csv(data_desc_file)
    else:
        views, gazes = load_json(data_desc_file)
    # Use the same vertical gaze position for both eyes
    gazes[:, :, 1] = (gazes[:, :1, 1] + gazes[:, 1:, 1]) * 0.5
    return views, gazes


torch.set_grad_enabled(False)

view_file = Path(opt.view_file)
stereo_disparity = opt.disparity
res_full = (1600, 1440)

if opt.model:
    model_dir = Path(opt.model)
    fov_list = [20.0, 45.0, 110.0]
    res_list = [(256, 256), (256, 256), (256, 230)]
    fovea_net = load_model(next(model_dir.glob("fovea*.tar")))
    periph_net = load_model(next(model_dir.glob("periph*.tar")))
    renderer = FoveatedNeuralRenderer(fov_list, res_list,
                                      nn.ModuleList([fovea_net, periph_net, periph_net]),
                                      res_full, device=device.default())
else:
    renderer = None

# Load dataset
views, gazes = load_views_and_gazes(Path(opt.view_file))
if opt.range:
    opt.range = [int(val) for val in opt.range.split(",")]
    if len(opt.range) == 1:
        opt.range = [0, opt.range[0]]
    views = views[opt.range[0]:opt.range[1]]
    gazes = gazes[opt.range[0]:opt.range[1]]
views = views.to(device.default())
n_views = views.shape[0]
print('Dataset loaded. Views:', n_views)

videodir = view_file.absolute().parent
tempdir = Path('/dev/shm/dvs_tmp/video')
if opt.input:
    videoname = Path(opt.input).parent.stem
else:
    videoname = f"{view_file.stem}_{('stereo' if opt.stereo else 'mono')}"
    if opt.noCE:
        videoname += "_noCE"
gazeout = videodir / f"{videoname}_gaze.csv"
if opt.fps:
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{tempdir}/{videoname}/%04d.bmp"
    hintout = f"{tempdir}/{videoname}_hint/%04d.bmp"
else:
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{videodir}/{videoname}/%04d.png"
    hintout = f"{videodir}/{videoname}_hint/%04d.png"

scale = img.load(inferout % 0).shape[-1] / res_full[1] if opt.input else 1
hint = img.load(f"{sys.path[0]}/fovea_hint.png", with_alpha=True).to(device=device.default())
hint = nn_f.interpolate(hint, mode='bilinear', scale_factor=scale, align_corners=False)

print("Video dir:", videodir)
print("Infer out:", inferout)
print("Hint out:", hintout)


def add_hint(image, center, right_center=None):
    if right_center is not None:
        # Stereo frame: draw a hint on each half
        add_hint(image[..., :image.size(-1) // 2], center)
        add_hint(image[..., image.size(-1) // 2:], right_center)
        return
    center = (center[0] * scale, center[1] * scale)
    fovea_origin = (
        int(center[0]) + image.size(-1) // 2 - hint.size(-1) // 2,
        int(center[1]) + image.size(-2) // 2 - hint.size(-2) // 2
    )
    fovea_region = (
        ...,
        slice(fovea_origin[1], fovea_origin[1] + hint.size(-2)),
        slice(fovea_origin[0], fovea_origin[0] + hint.size(-1)),
    )
    try:
        # Alpha-blend the hint image over the fovea region
        image[fovea_region] = image[fovea_region] * (1 - hint[:, 3:]) + \
            hint[:, :3] * hint[:, 3:]
    except Exception:
        print(fovea_region, image.shape, hint.shape)
        sys.exit(1)


os.makedirs(os.path.dirname(inferout), exist_ok=True)
os.makedirs(os.path.dirname(hintout), exist_ok=True)

# Skip frames that already exist unless --replace is given
hint_offset = infer_offset = 0
if not opt.replace:
    for view_idx in range(n_views):
        if os.path.exists(inferout % view_idx):
            infer_offset += 1
        if os.path.exists(hintout % view_idx):
            hint_offset += 1
    hint_offset = max(0, hint_offset - 1)
    infer_offset = n_views if opt.input else max(0, infer_offset - 1)
for view_idx in trange(n_views):
    if view_idx < hint_offset:
        continue
    gaze = gazes[view_idx]
    if not opt.stereo:
        # Average the two eyes' gazes for mono rendering
        gaze = gaze.sum(0, True) * 0.5
    gaze = gaze.tolist()
    if view_idx < infer_offset:
        # Reuse the already rendered frame
        frame = img.load(inferout % view_idx).to(device=device.default())
    else:
        if renderer is None:
            raise RuntimeError("Frames need to be rendered but no model directory is specified (-m)")
        key = 'blended_raw' if opt.noCE else 'blended'
        view_trans = views[view_idx]
        if opt.stereo:
            left_images, right_images = renderer(view_trans, *gaze,
                                                 stereo_disparity=stereo_disparity,
                                                 mono_periph_mode=3, ret_raw=True)
            frame = torch.cat([left_images[key], right_images[key]], -1)
        else:
            frame = renderer(view_trans, *gaze, ret_raw=True)[key]
        img.save(frame, inferout % view_idx)
    add_hint(frame, *gaze)
    img.save(frame, hintout % view_idx)

# Save the gaze positions alongside the frames
gazes_out = gazes.reshape(-1, 4) if opt.stereo else gazes.sum(1) * 0.5
with open(gazeout, 'w') as fp:
    csv_writer = csv.writer(fp)
    csv_writer.writerows(gazes_out.tolist())

if opt.fps:
    # Generate video without hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i {inferout} -c:v libx264 {videodir}/{videoname}.mp4')
    # Generate video with hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i {hintout} -c:v libx264 {videodir}/{videoname}_hint.mp4')

    # Clean temp images
    if not opt.input:
        shutil.rmtree(os.path.dirname(inferout))
    shutil.rmtree(os.path.dirname(hintout))