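"""Render a foveated video sequence with a trained spherical view synthesis model.

Loads per-frame view poses and gaze positions from a CSV or JSON view file,
renders each frame through a FoveatedNeuralRenderer (side-by-side stereo with
-s/--stereo), saves a copy of each frame with a gaze hint marker overlaid and,
when -f/--fps is given, encodes both frame sets to MP4 with ffmpeg. --noCE
writes the raw blended output ('blended_raw') instead of the default blend;
-R/--replace re-renders frames that already exist on disk.

Example invocation (scene name and view file path are illustrative):

    python gen_video.py -s -f 30 classroom path/to/seq0_views.csv
"""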
import json
import sys
import os
import argparse
import torch
import shutil
import torch.nn as nn
import torch.nn.functional as nn_f

sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
parser.add_argument('-s', '--stereo', action='store_true')
parser.add_argument('-R', '--replace', action='store_true')
parser.add_argument('--noCE', action='store_true')
parser.add_argument('-i', '--input', type=str)
parser.add_argument('-r', '--range', type=str)
parser.add_argument('-f', '--fps', type=int)
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
parser.add_argument('scene', type=str)
parser.add_argument('view_file', type=str)
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
torch.autograd.set_grad_enabled(False)

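# Project-level imports are deliberately placed after device selection, since
# utils.device may capture the current CUDA device at import time.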
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import misc
from utils import img
from utils import device
from utils.view import *
from utils import sphere
from components.fnr import FoveatedNeuralRenderer
from utils.progress_bar import progress_bar


def load_net(path):
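    """Instantiate a network from the config id encoded in the checkpoint
    file name (extension stripped) and load its weights."""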
    config = SphericalViewSynConfig()
    config.from_id(path[:-4])
    config.sa['perturb_sample'] = False
    # config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net


def find_file(prefix):
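    """Return the first entry in the working directory whose name starts with
    prefix, or None if there is no match."""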
    for path in os.listdir():
        if path.startswith(prefix):
            return path
    return None


def clamp_gaze(gaze):
    # Clamping is currently disabled: gaze directions pass through unchanged.
    return gaze
    # Unreachable remnant of the disabled implementation:
    # scoord = sphere.cartesian2spherical(gaze)


def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
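    """Load per-view transforms and gaze positions from a CSV view file.

    Three layouts are supported and detected below by probing field counts:
    a simple format (3 lines per view), an old format (7 lines per view, gaze
    stored as per-eye 3D directions) and the current format (9 lines per view).
    Views whose gaze lines are all zeros are skipped.
    """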
    def to_tensor(line_content):
        return torch.tensor([float(v) for v in line_content.split(',')])

    with open(data_desc_file, 'r', encoding='utf-8') as file:
        lines = file.readlines()
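        # Detect the layout by field counts: the simple format has a
        # single-field 4th line, the old format a single-field 8th line.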
        simple_fmt = len(lines[3].split(',')) == 1
        old_fmt = len(lines[7].split(',')) == 1
        lines_per_view = 3 if simple_fmt else 7 if old_fmt else 9
        n = len(lines) // lines_per_view
        gaze_dirs = torch.empty(n, 2, 3)
        gazes = torch.empty(n, 2, 2)
        view_mats = torch.empty(n, 2, 4, 4)
        view_idx = 0
        for i in range(0, len(lines), lines_per_view):
            if simple_fmt:
                if lines[i + 1].startswith('0,0,0,0'):
                    continue
                gazes[view_idx] = to_tensor(lines[i + 1]).view(2, 2)
                view_mats[view_idx] = to_tensor(lines[i + 2]).view(1, 4, 4).expand(2, -1, -1)
            else:
                if lines[i + 1].startswith('0,0,0') or lines[i + 2].startswith('0,0,0'):
                    continue
                j = i + 1
                gaze_dirs[view_idx, 0] = clamp_gaze(to_tensor(lines[j]))
                j += 1
                gaze_dirs[view_idx, 1] = clamp_gaze(to_tensor(lines[j]))
                j += 1
                if not old_fmt:
                    gazes[view_idx, 0] = to_tensor(lines[j])
                    j += 1
                    gazes[view_idx, 1] = to_tensor(lines[j])
                    j += 1
                view_mats[view_idx, 0] = to_tensor(lines[j]).view(4, 4)
                j += 1
                view_mats[view_idx, 1] = to_tensor(lines[j]).view(4, 4)
            view_idx += 1

    view_mats = view_mats[:view_idx]
    gaze_dirs = gaze_dirs[:view_idx]
    gazes = gazes[:view_idx]

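    # The old format stores per-eye 3D gaze directions; project them to 2D
    # screen-space gaze offsets in pixels using the hard-coded 55-degree FOV
    # constant and 1440x1600 image size below.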
    if old_fmt:
        gaze2 = -gaze_dirs[..., :2] / gaze_dirs[..., 2:]
        fov = math.radians(55)
        tan = torch.tensor([math.tan(fov), math.tan(fov) * 1600 / 1440])
        gazeper = gaze2 / tan
        gazes = 0.5 * torch.stack([1440 * gazeper[..., 0], 1600 * gazeper[..., 1]], -1)

    view_t = 0.5 * (view_mats[:, 0, :3, 3] + view_mats[:, 1, :3, 3])

    return Trans(view_t, view_mats[:, 0, :3, :3]), gazes

def load_json(data_desc_file) -> Tuple[Trans, torch.Tensor]:
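    """Load per-view transforms and gaze positions from a JSON view file.

    Gazes may be stored per view (shared by both eyes) or per eye; if absent,
    they default to (0, 0), i.e. the image center.
    """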
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
        view_t = torch.tensor(data['view_centers'])
        view_r = torch.tensor(data['view_rots']).view(-1, 3, 3)
        if data.get('gazes'):
            if len(data['gazes'][0]) == 2:
                gazes = torch.tensor(data['gazes']).view(-1, 1, 2).expand(-1, 2, -1)
            else:
                gazes = torch.tensor(data['gazes']).view(-1, 2, 2)
        else:
            gazes = torch.zeros(view_t.size(0), 2, 2)
    return Trans(view_t, view_r), gazes

def load_views(data_desc_file: str) -> Tuple[Trans, torch.Tensor]:
    if data_desc_file.endswith('.csv'):
        return load_csv(data_desc_file)
    return load_json(data_desc_file)

rot_range = {
    'classroom': [120, 80],
    'barbershop': [360, 80],
    'lobby': [360, 80],
    'stones': [360, 80]
}
trans_range = {
    'classroom': 0.6,
    'barbershop': 0.3,
    'lobby': 1.0,
    'stones': 1.0
}
fov_list = [20, 45, 110]
res_list = [(256, 256), (256, 256), (400, 360)]
res_full = (1600, 1440)
stereo_disparity = 0.06

cwd = os.getcwd()
os.chdir(f"{sys.path[0]}/../data/__new/{opt.scene}_all")
fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))
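# Three-layer foveated renderer: the fovea net drives the innermost layer and
# the periph net is shared by the mid- and far-periphery layers.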
renderer = FoveatedNeuralRenderer(fov_list, res_list,
                                  nn.ModuleList([fovea_net, periph_net, periph_net]),
                                  res_full, device=device.default())
os.chdir(cwd)

# Load Dataset
views, gazes = load_views(opt.view_file)
if opt.range:
    opt.range = [int(val) for val in opt.range.split(",")]
    if len(opt.range) == 1:
        opt.range = [0, opt.range[0]]
    views = views.get(range(opt.range[0], opt.range[1]))
    gazes = gazes[opt.range[0]:opt.range[1]]
views = views.to(device.default())
n_views = views.size()[0]
print('Dataset loaded. Views:', n_views)


videodir = os.path.dirname(os.path.abspath(opt.view_file))
tempdir = '/dev/shm/dvs_tmp/video'
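# When encoding a video (--fps), intermediate frames are staged as BMP files
# in tmpfs for speed and removed after encoding; otherwise frames are kept as
# PNG files next to the view file.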
videoname = f"{os.path.splitext(os.path.split(opt.view_file)[-1])[0]}_{'stereo' if opt.stereo else 'mono'}"
gazeout = f"{videodir}/{videoname}_gaze.csv"
if opt.noCE:
    videoname += "_noCE"
if opt.fps:
    if opt.input:
        videoname = os.path.split(opt.input)[0]
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{tempdir}/{videoname}/%04d.bmp"
    hintout = f"{tempdir}/{videoname}_hint/%04d.bmp"
else:
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{videodir}/{videoname}/%04d.png"
    hintout = f"{videodir}/{videoname}_hint/%04d.png"
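# With pre-rendered input frames (--input), rescale the hint overlay to match
# their resolution relative to the full render resolution.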
if opt.input:
    scale = img.load(inferout % 0).shape[-1] / res_full[1]
else:
    scale = 1
hint = img.load(f"{sys.path[0]}/fovea_hint.png", with_alpha=True).to(device=device.default())
hint = nn_f.interpolate(hint, mode='bilinear', scale_factor=scale, align_corners=False)

print("Video dir:", videodir)
print("Infer out:", inferout)
print("Hint out:", hintout)


def add_hint(image, center, right_center=None):
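    """Blend the fovea hint marker into image at the given gaze position(s).

    Gaze coordinates are offsets from the image center. If right_center is
    given, image is treated as a side-by-side stereo frame and a hint is
    drawn on each half.
    """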
    if right_center is not None:
        add_hint(image[..., :image.size(-1) // 2], center)
        add_hint(image[..., image.size(-1) // 2:], right_center)
        return
    center = (center[0] * scale, center[1] * scale)
    fovea_origin = (
        int(center[0]) + image.size(-1) // 2 - hint.size(-1) // 2,
        int(center[1]) + image.size(-2) // 2 - hint.size(-2) // 2
    )
    fovea_region = (
        ...,
        slice(fovea_origin[1], fovea_origin[1] + hint.size(-2)),
        slice(fovea_origin[0], fovea_origin[0] + hint.size(-1)),
    )
    try:
        # Alpha-blend the hint overlay onto the fovea region.
        image[fovea_region] = image[fovea_region] * (1 - hint[:, 3:]) + \
            hint[:, :3] * hint[:, 3:]
    except Exception:
        print('Failed to blend hint:', fovea_region, image.shape, hint.shape)
        sys.exit(1)


os.makedirs(os.path.dirname(inferout), exist_ok=True)
os.makedirs(os.path.dirname(hintout), exist_ok=True)

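# Unless -R/--replace is given, resume from frames already on disk: skip
# re-rendering all but the last existing frame (which is regenerated).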
hint_offset = infer_offset = 0
if not opt.replace:
    for view_idx in range(n_views):
        if os.path.exists(inferout % view_idx):
            infer_offset += 1
        if os.path.exists(hintout % view_idx):
            hint_offset += 1
    hint_offset = max(0, hint_offset - 1)
    infer_offset = n_views if opt.input else max(0, infer_offset - 1)

if opt.stereo:
    gazes_out = torch.empty(n_views, 4)
    for view_idx in range(n_views):
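        # Derive per-eye gaze positions: pull each eye's x-coordinate 40% of
        # the measured horizontal disparity toward the other eye and share the
        # averaged y-coordinate between both eyes.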
        shift = gazes[view_idx, 0, 0] - gazes[view_idx, 1, 0]
        # print(shift.item())
        gazel = ((gazes[view_idx, 1, 0] + 0.4 * shift).item(),
                 0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
        gazer = ((gazes[view_idx, 0, 0] - 0.4 * shift).item(), gazel[1])
        # gazel = ((gazes[view_idx, 0, 0]).item(),
        #         0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
        #gazer = ((gazes[view_idx, 1, 0]).item(), gazel[1])
        gazes_out[view_idx] = torch.tensor([gazel[0], gazel[1], gazer[0], gazer[1]])
        if view_idx < hint_offset:
            continue
        if view_idx < infer_offset:
            frame = img.load(inferout % view_idx).to(device=device.default())
        else:
            view_trans = views.get(view_idx)
            left_images, right_images = renderer(view_trans, gazel, gazer,
                                                 stereo_disparity=stereo_disparity,
                                                 mono_periph_mode=3, ret_raw=True)
            frame = torch.cat([
                left_images['blended_raw'] if opt.noCE else left_images['blended'],
                right_images['blended_raw'] if opt.noCE else right_images['blended']], -1)
            frame = img.translate(frame, (0.5, 0.5))
            img.save(frame, inferout % view_idx)
        add_hint(frame, gazel, gazer)
        img.save(frame, hintout % view_idx)
        progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)
else:
    gazes_out = torch.empty(n_views, 2)
    for view_idx in range(n_views):
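        # Monocular rendering uses the average of both eyes' gaze positions.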
        gaze = 0.5 * (gazes[view_idx, 0] + gazes[view_idx, 1])
        gaze = (gaze[0].item(), gaze[1].item())
        gazes_out[view_idx] = torch.tensor([gaze[0], gaze[1]])
        if view_idx < hint_offset:
            continue
        if view_idx < infer_offset:
            frame = img.load(inferout % view_idx).to(device=device.default())
        else:
            view_trans = views.get(view_idx)
            frame = renderer(view_trans, gaze,
                             ret_raw=True)['blended_raw' if opt.noCE else 'blended']
            frame = img.translate(frame, (0.5, 0.5))
            img.save(frame, inferout % view_idx)
        add_hint(frame, gaze)
        img.save(frame, hintout % view_idx)
        progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)

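# Write the gaze positions used for each frame (x,y per eye in stereo mode,
# a single x,y otherwise) next to the video output.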
with open(gazeout, 'w') as fp:
    for i in range(n_views):
        fp.write(','.join([f'{val.item()}' for val in gazes_out[i]]))
        fp.write('\n')

if opt.fps:
    # Generate video without hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i {inferout} -c:v libx264 {videodir}/{videoname}.mp4')
    if not opt.input:
        shutil.rmtree(os.path.dirname(inferout))

    # Generate video with hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i {hintout} -c:v libx264 {videodir}/{videoname}_hint.mp4')
    shutil.rmtree(os.path.dirname(hintout))