import argparse
import csv
import json
import os
import shutil
import sys

import torch
import torch.nn as nn
import torch.nn.functional as nn_f
from tqdm import trange

sys.path.append(os.path.abspath(sys.path[0] + '/../'))

Nianchen Deng's avatar
Nianchen Deng committed
14
parser = argparse.ArgumentParser()
Nianchen Deng's avatar
sync    
Nianchen Deng committed
15
16
17
18
19
20
21
22
parser.add_argument('-s', '--stereo', action='store_true',
                    help="Render stereo video")
parser.add_argument('-d', '--disparity', type=float, default=0.06,
                    help="The stereo disparity")
parser.add_argument('-R', '--replace', action='store_true',
                    help="Replace the existed frames in the intermediate output directory")
parser.add_argument('--noCE', action='store_true',
                    help="Disable constrast enhancement")
Nianchen Deng's avatar
Nianchen Deng committed
23
parser.add_argument('-i', '--input', type=str)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
24
25
26
27
28
29
30
31
parser.add_argument('-r', '--range', type=str,
                    help="The range of frames to render, specified as format: start,end")
parser.add_argument('-f', '--fps', type=int,
                    help="The FPS of output video. if not specified, a sequence of images will be saved instead")
parser.add_argument('-m', '--model', type=str,
                    help="The directory containing fovea* and periph* model file")
parser.add_argument('view_file', type=str,
                    help="The path to .csv or .json file which contains a sequence of poses and gazes")
Nianchen Deng's avatar
Nianchen Deng committed
32
33
34
opt = parser.parse_args()


Nianchen Deng's avatar
sync    
Nianchen Deng committed
35
from utils import netio, img, device
Nianchen Deng's avatar
Nianchen Deng committed
36
from utils.view import *
Nianchen Deng's avatar
sync    
Nianchen Deng committed
37
from utils.types import *
Nianchen Deng's avatar
Nianchen Deng committed
38
from components.fnr import FoveatedNeuralRenderer
Nianchen Deng's avatar
sync    
Nianchen Deng committed
39
from model import Model
Nianchen Deng's avatar
Nianchen Deng committed
40
41


def load_model(path: Path) -> Model:
    checkpoint, _ = netio.load_checkpoint(path)
    checkpoint["model"][1]["sampler"]["perturb"] = False
    model = Model.create(*checkpoint["model"])
    model.load_state_dict(checkpoint["states"]["model"])
    model.to(device.default()).eval()
    return model
Nianchen Deng's avatar
Nianchen Deng committed
49

Nianchen Deng's avatar
Nianchen Deng committed
50

Nianchen Deng's avatar
sync    
Nianchen Deng committed
51
def load_csv(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
Nianchen Deng's avatar
Nianchen Deng committed
52
53
54
    def to_tensor(line_content):
        return torch.tensor([float(str) for str in line_content.split(',')])

Nianchen Deng's avatar
Nianchen Deng committed
55
56
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        lines = file.readlines()
Nianchen Deng's avatar
Nianchen Deng committed
57
58
59
60
61
62
63
        simple_fmt = len(lines[3].split(',')) == 1
        old_fmt = len(lines[7].split(',')) == 1
        lines_per_view = 3 if simple_fmt else 7 if old_fmt else 9
        n = len(lines) // lines_per_view
        gaze_dirs = torch.empty(n, 2, 3)
        gazes = torch.empty(n, 2, 2)
        view_mats = torch.empty(n, 2, 4, 4)
Nianchen Deng's avatar
Nianchen Deng committed
64
        view_idx = 0
Nianchen Deng's avatar
Nianchen Deng committed
65
66
67
68
69
70
71
72
73
74
        for i in range(0, len(lines), lines_per_view):
            if simple_fmt:
                if lines[i + 1].startswith('0,0,0,0'):
                    continue
                gazes[view_idx] = to_tensor(lines[i + 1]).view(2, 2)
                view_mats[view_idx] = to_tensor(lines[i + 2]).view(1, 4, 4).expand(2, -1, -1)
            else:
                if lines[i + 1].startswith('0,0,0') or lines[i + 2].startswith('0,0,0'):
                    continue
                j = i + 1
Nianchen Deng's avatar
sync    
Nianchen Deng committed
75
                gaze_dirs[view_idx, 0] = to_tensor(lines[j])
Nianchen Deng's avatar
Nianchen Deng committed
76
                j += 1
Nianchen Deng's avatar
sync    
Nianchen Deng committed
77
                gaze_dirs[view_idx, 1] = to_tensor(lines[j])
Nianchen Deng's avatar
Nianchen Deng committed
78
79
80
81
82
83
84
85
86
                j += 1
                if not old_fmt:
                    gazes[view_idx, 0] = to_tensor(lines[j])
                    j += 1
                    gazes[view_idx, 1] = to_tensor(lines[j])
                    j += 1
                view_mats[view_idx, 0] = to_tensor(lines[j]).view(4, 4)
                j += 1
                view_mats[view_idx, 1] = to_tensor(lines[j]).view(4, 4)
Nianchen Deng's avatar
Nianchen Deng committed
87
88
            view_idx += 1

Nianchen Deng's avatar
Nianchen Deng committed
89
90
91
    view_mats = view_mats[:view_idx]
    gaze_dirs = gaze_dirs[:view_idx]
    gazes = gazes[:view_idx]
Nianchen Deng's avatar
Nianchen Deng committed
92

Nianchen Deng's avatar
Nianchen Deng committed
93
94
95
96
97
98
    if old_fmt:
        gaze2 = -gaze_dirs[..., :2] / gaze_dirs[..., 2:]
        fov = math.radians(55)
        tan = torch.tensor([math.tan(fov), math.tan(fov) * 1600 / 1440])
        gazeper = gaze2 / tan
        gazes = 0.5 * torch.stack([1440 * gazeper[..., 0], 1600 * gazeper[..., 1]], -1)
Nianchen Deng's avatar
Nianchen Deng committed
99

Nianchen Deng's avatar
Nianchen Deng committed
100
    view_t = 0.5 * (view_mats[:, 0, :3, 3] + view_mats[:, 1, :3, 3])
Nianchen Deng's avatar
Nianchen Deng committed
101

Nianchen Deng's avatar
Nianchen Deng committed
102
103
    return Trans(view_t, view_mats[:, 0, :3, :3]), gazes

Nianchen Deng's avatar
sync    
Nianchen Deng committed
104
105

def load_json(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
Nianchen Deng's avatar
Nianchen Deng committed
106
107
108
109
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
        view_t = torch.tensor(data['view_centers'])
        view_r = torch.tensor(data['view_rots']).view(-1, 3, 3)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
110
111
112
113
        if data.get("gl_coord"):
            view_t[:, 2] *= -1
            view_r[:, 2] *= -1
            view_r[..., 2] *= -1
Nianchen Deng's avatar
Nianchen Deng committed
114
115
116
117
118
119
120
121
122
        if data.get('gazes'):
            if len(data['gazes'][0]) == 2:
                gazes = torch.tensor(data['gazes']).view(-1, 1, 2).expand(-1, 2, -1)
            else:
                gazes = torch.tensor(data['gazes']).view(-1, 2, 2)
        else:
            gazes = torch.zeros(view_t.size(0), 2, 2)
    return Trans(view_t, view_r), gazes

Nianchen Deng's avatar
sync    
Nianchen Deng committed
123
124
125
126
127
128
129
130
131
132
133
134
135

def load_views_and_gazes(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
    if data_desc_file.suffix == '.csv':
        views, gazes = load_csv(data_desc_file)
    else:
        views, gazes = load_json(data_desc_file)
    gazes[:, :, 1] = (gazes[:, :1, 1] + gazes[:, 1:, 1]) * 0.5
    return views, gazes


torch.set_grad_enabled(False)
view_file = Path(opt.view_file)
stereo_disparity = opt.disparity
Nianchen Deng's avatar
Nianchen Deng committed
136
res_full = (1600, 1440)
Nianchen Deng's avatar
Nianchen Deng committed
137

Nianchen Deng's avatar
sync    
Nianchen Deng committed
138
139
140
141
142
143
144
145
146
147
148
if opt.model:
    model_dir = Path(opt.model)
    fov_list = [20.0, 45.0, 110.0]
    res_list = [(256, 256), (256, 256), (256, 230)]
    fovea_net = load_model(next(model_dir.glob("fovea*.tar")))
    periph_net = load_model(next(model_dir.glob("periph*.tar")))
    renderer = FoveatedNeuralRenderer(fov_list, res_list,
                                    nn.ModuleList([fovea_net, periph_net, periph_net]),
                                    res_full, device=device.default())
else:
    renderer = None
Nianchen Deng's avatar
Nianchen Deng committed
149
150

# Load Dataset
Nianchen Deng's avatar
sync    
Nianchen Deng committed
151
views, gazes = load_views_and_gazes(Path(opt.view_file))
Nianchen Deng's avatar
Nianchen Deng committed
152
153
154
155
if opt.range:
    opt.range = [int(val) for val in opt.range.split(",")]
    if len(opt.range) == 1:
        opt.range = [0, opt.range[0]]
Nianchen Deng's avatar
sync    
Nianchen Deng committed
156
    views = views[opt.range[0]:opt.range[1]]
Nianchen Deng's avatar
Nianchen Deng committed
157
158
    gazes = gazes[opt.range[0]:opt.range[1]]
views = views.to(device.default())
Nianchen Deng's avatar
sync    
Nianchen Deng committed
159
n_views = views.shape[0]
Nianchen Deng's avatar
Nianchen Deng committed
160
161
print('Dataset loaded. Views:', n_views)

Nianchen Deng's avatar
sync    
Nianchen Deng committed
162
163
164
165
166
167
168
169
170
videodir = view_file.absolute().parent
tempdir = Path('/dev/shm/dvs_tmp/video')
if opt.input:
    videoname = Path(opt.input).parent.stem
else:
    videoname = f"{view_file.stem}_{('stereo' if opt.stereo else 'mono')}"
    if opt.noCE:
        videoname += "_noCE"
gazeout = videodir / f"{videoname}_gaze.csv"
Nianchen Deng's avatar
Nianchen Deng committed
171
172
173
174
175
176
if opt.fps:
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{tempdir}/{videoname}/%04d.bmp"
    hintout = f"{tempdir}/{videoname}_hint/%04d.bmp"
else:
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{videodir}/{videoname}/%04d.png"
    hintout = f"{videodir}/{videoname}_hint/%04d.png"
Nianchen Deng's avatar
sync    
Nianchen Deng committed
177
scale = img.load(inferout % 0).shape[-1] / res_full[1] if opt.input else 1
Nianchen Deng's avatar
Nianchen Deng committed
178
179
hint = img.load(f"{sys.path[0]}/fovea_hint.png", with_alpha=True).to(device=device.default())
hint = nn_f.interpolate(hint, mode='bilinear', scale_factor=scale, align_corners=False)
Nianchen Deng's avatar
Nianchen Deng committed
180

Nianchen Deng's avatar
Nianchen Deng committed
181
182
183
print("Video dir:", videodir)
print("Infer out:", inferout)
print("Hint out:", hintout)
Nianchen Deng's avatar
Nianchen Deng committed
184
185


Nianchen Deng's avatar
Nianchen Deng committed
186
187
188
189
190
191
def add_hint(image, center, right_center=None):
    if right_center is not None:
        add_hint(image[..., :image.size(-1) // 2], center)
        add_hint(image[..., image.size(-1) // 2:], right_center)
        return
    center = (center[0] * scale, center[1] * scale)
Nianchen Deng's avatar
Nianchen Deng committed
192
    fovea_origin = (
Nianchen Deng's avatar
Nianchen Deng committed
193
194
        int(center[0]) + image.size(-1) // 2 - hint.size(-1) // 2,
        int(center[1]) + image.size(-2) // 2 - hint.size(-2) // 2
Nianchen Deng's avatar
Nianchen Deng committed
195
196
197
198
199
200
    )
    fovea_region = (
        ...,
        slice(fovea_origin[1], fovea_origin[1] + hint.size(-2)),
        slice(fovea_origin[0], fovea_origin[0] + hint.size(-1)),
    )
Nianchen Deng's avatar
Nianchen Deng committed
201
202
203
204
205
206
207
    try:
        image[fovea_region] = image[fovea_region] * (1 - hint[:, 3:]) + \
            hint[:, :3] * hint[:, 3:]
    except Exception:
        print(fovea_region, image.shape, hint.shape)
        exit()

Nianchen Deng's avatar
Nianchen Deng committed
208

Nianchen Deng's avatar
sync    
Nianchen Deng committed
209
210
os.makedirs(os.path.dirname(inferout), exist_ok=True)
os.makedirs(os.path.dirname(hintout), exist_ok=True)
Nianchen Deng's avatar
Nianchen Deng committed
211

Nianchen Deng's avatar
Nianchen Deng committed
212
213
214
215
216
217
218
219
220
hint_offset = infer_offset = 0
if not opt.replace:
    for view_idx in range(n_views):
        if os.path.exists(inferout % view_idx):
            infer_offset += 1
        if os.path.exists(hintout % view_idx):
            hint_offset += 1
    hint_offset = max(0, hint_offset - 1)
    infer_offset = n_views if opt.input else max(0, infer_offset - 1)
Nianchen Deng's avatar
Nianchen Deng committed
221

Nianchen Deng's avatar
sync    
Nianchen Deng committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
for view_idx in trange(n_views):
    if view_idx < hint_offset:
        continue
    gaze = gazes[view_idx]
    if not opt.stereo:
        gaze = gaze.sum(0, True) * 0.5
    gaze = gaze.tolist()
    if view_idx < infer_offset:
        frame = img.load(inferout % view_idx).to(device=device.default())
    else:
        if renderer is None:
            raise Exception
        key = 'blended_raw' if opt.noCE else 'blended'
        view_trans = views[view_idx]
        if opt.stereo:
            left_images, right_images = renderer(view_trans, *gaze,
                                                    stereo_disparity=stereo_disparity,
                                                    mono_periph_mode=3, ret_raw=True)
            frame = torch.cat([left_images[key], right_images[key]], -1)
Nianchen Deng's avatar
Nianchen Deng committed
241
        else:
Nianchen Deng's avatar
sync    
Nianchen Deng committed
242
243
244
245
            frame = renderer(view_trans, *gaze, ret_raw=True)[key]
        img.save(frame, inferout % view_idx)
    add_hint(frame, *gaze)
    img.save(frame, hintout % view_idx)
Nianchen Deng's avatar
Nianchen Deng committed
246

Nianchen Deng's avatar
sync    
Nianchen Deng committed
247
gazes_out = gazes.reshape(-1, 4) if opt.stereo else gazes.sum(1) * 0.5
Nianchen Deng's avatar
Nianchen Deng committed
248
with open(gazeout, 'w') as fp:
Nianchen Deng's avatar
sync    
Nianchen Deng committed
249
250
    csv_writer = csv.writer(fp)
    csv_writer.writerows(gazes_out.tolist())
Nianchen Deng's avatar
Nianchen Deng committed
251

Nianchen Deng's avatar
Nianchen Deng committed
252
253
254
255
256
257
if opt.fps:
    # Generate video without hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i {inferout} -c:v libx264 {videodir}/{videoname}.mp4')

    # Generate video with hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i {hintout} -c:v libx264 {videodir}/{videoname}_hint.mp4')
Nianchen Deng's avatar
sync    
Nianchen Deng committed
258
259
260
261

    # Clean temp images
    if not opt.input:
        shutil.rmtree(os.path.dirname(inferout))
Nianchen Deng's avatar
Nianchen Deng committed
262
    shutil.rmtree(os.path.dirname(hintout))