# test.py — evaluate a trained model on a dataset: run inference over the test views and save
# a performance report (PSNR/SSIM) plus color/diffuse/specular/depth images or videos.
import os
import argparse
import torch
import torch.nn.functional as nn_f
import cv2
import numpy as np
from pathlib import Path

# Command-line interface. Every option can alternatively be chosen interactively with -p.
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str,
                    help='The model file to load for testing')
parser.add_argument('-r', '--output-res', type=str,
                    help='Output resolution, in the form WIDTHxHEIGHT')
parser.add_argument('-o', '--output', nargs='*', type=str, default=['perf', 'color'],
                    help='Specify what to output '
                    '(perf, color, depth, layers, diffuse, specular, all)')
parser.add_argument('--output-type', type=str, default='image',
                    help='Specify the output type (image, video)')
parser.add_argument('--views', type=str,
                    help='Specify the range of views to test')
parser.add_argument('-s', '--samples', type=int,
                    help="Override the model's 'n_samples' argument")
parser.add_argument('-p', '--prompt', action='store_true',
                    help='Interactive prompt mode')
parser.add_argument('--time', action='store_true',
                    help='Enable time measurement')
parser.add_argument('dataset', type=str,
                    help='Dataset description file')
args = parser.parse_args()


# Project modules are imported only after argument parsing (presumably so that `--help` and
# argument errors don't pay for heavy imports — TODO confirm).
import model as mdl
from loss.ssim import ssim
from utils import color
from utils import interact
from utils import device
from utils import img
from utils import netio
from utils import math  # project module, shadows the stdlib `math` (used for prod/huge/nan)
from utils.perf import Perf, enable_perf, get_perf_result
from utils.progress_bar import progress_bar
from data import *  # presumably provides get_dataset_desc_path, DatasetFactory, DataLoader


# Passed to the DataLoader as `chunk_max_items`
DATA_LOADER_CHUNK_SIZE = 1e8
# Inference only: disable autograd globally
torch.set_grad_enabled(False)

# Resolve the dataset description and work relative to its directory from now on:
# checkpoints are searched for under "_nets", and the dataset is later loaded by file name.
data_desc_path = get_dataset_desc_path(args.dataset)
os.chdir(data_desc_path.parent)
nets_dir = Path("_nets")
data_desc_path = data_desc_path.name


def set_outputs(args, outputs_str: str):
    """Set ``args.output`` from a comma-separated list of output names.

    Surrounding whitespace of every name is stripped, so ``"perf, color"``
    becomes ``['perf', 'color']``.
    """
    names = outputs_str.split(',')
    args.output = list(map(str.strip, names))


if args.prompt:  # Interactively choose the test model, output resolution, outputs and output type
    model_files = [str(path.relative_to(nets_dir))
                   for pattern in ("*.tar", "*.pth")
                   for path in nets_dir.rglob(pattern)]
    args.model = interact.input_enum('Specify test model:', model_files,
                                     err_msg='No such model file')
    args.output_res = interact.input_ex('Specify output resolution:',
                                        default='')
    set_outputs(args, interact.input_ex('Specify the outputs | [perf,color,depth,layers,diffuse,specular]/all:',
                                        default='perf,color'))
    args.output_type = interact.input_enum('Specify the output type | image/video:',
                                           ['image', 'video'],
                                           err_msg='Wrong output type',
                                           default='image')

# The prompt takes a comma-separated list; accept the same on the command line, so that
# "-o perf,color" behaves like "-o perf color" instead of matching no output at all.
args.output = [name.strip() for item in args.output for name in item.split(',') if name.strip()]
# "WxH" -> (H, W); an empty/missing value gives None (the dataset's default resolution)
args.output_res = tuple(int(s) for s in reversed(args.output_res.split('x'))) if args.output_res \
    else None
# One boolean per known output kind; 'all' enables every one of them
args.output_flags = {
    item: item in args.output or 'all' in args.output
    for item in ['perf', 'color', 'depth', 'layers', 'diffuse', 'specular']
}
# "start-end" -> range(start, end) (end exclusive); "n" -> range(n)
args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None

if args.time:
    enable_perf()

# Ground-truth images are only loaded when a performance report is requested
dataset = DatasetFactory.load(data_desc_path, res=args.output_res,
                              load_images=args.output_flags['perf'],
                              views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")

import json  # used to print the model args; previously only reachable implicitly, if at all

# Rays per inference batch: a quarter of one view's pixels
RAYS_PER_BATCH = dataset.res[0] * dataset.res[1] // 4

if not args.model:
    raise SystemExit("No model specified: use -m/--model or -p for interactive mode")
model_path: Path = nets_dir / args.model
model_name = model_path.parent.name

states, _ = netio.load_checkpoint(model_path)
if args.samples:
    states['args']['n_samples'] = args.samples
model = mdl.deserialize(states,
                        raymarching_early_stop_tolerance=0.01,
                        raymarching_chunk_size_or_sections=None,
                        perturb_sample=False).to(device.default()).eval()
print(f"model: {model_name} ({model._get_name()})")
print("args:", json.dumps(model.args))

# Results go to "<run dir>/output_<N>", where N is the number after the last '_' in the
# checkpoint's file stem, e.g. "checkpoint_50.tar" -> "output_50".
run_dir = model_path.parent
output_dir = run_dir / f"output_{int(model_path.stem.split('_')[-1])}"
output_dataset_id = dataset.name + (
    f'_{args.output_res[1]}x{args.output_res[0]}' if args.output_res else ''
)


# 1. Initialize data loader
data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                         shuffle=False, enable_preload=not args.time,
                         color=color.from_str(model.args['color']))

# 3. Test on dataset
print("Begin test, batch size is %d" % RAYS_PER_BATCH)

i = 0  # number of batches processed so far (drives the progress bar)
offset = 0  # flat pixel index of the first ray in the current batch
chns = model.chns('color')
n = dataset.n_views
total_pixels = math.prod([n, *dataset.res])

# Flat per-pixel buffers for every requested output, filled batch by batch.
# Key order matters: the keys are passed to the model as the list of requested outputs.
out = {}
if args.output_flags['perf'] or args.output_flags['color']:
    out['color'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['diffuse']:
    out['diffuse'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['specular']:
    out['specular'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['depth']:
    # Pixels the model never writes keep math.huge (presumably "infinitely far")
    out['depth'] = torch.full([total_pixels, 1], math.huge, device=device.default())
# Ground-truth colors are only used by the perf report, and the dataset is loaded with
# load_images=args.output_flags['perf']; so only collect them when perf is requested. This
# also avoids indexing out['color'] when neither perf nor color output was requested.
gt_images = torch.empty_like(out['color']) \
    if args.output_flags['perf'] and dataset.image_path else None

tot_time = 0
tot_iters = len(data_loader)
progress_bar(i, tot_iters, 'Inferring...')
for data in data_loader:
    if args.output_flags['perf']:
        test_perf = Perf.Node("Test")
    n_rays = data['rays_o'].size(0)
    idx = slice(offset, offset + n_rays)
    ret = model(data, *out.keys())

    if ret is not None:
        for key, buffer in out.items():
            if buffer is None:
                # Dropped because an earlier batch did not produce it; writing into it
                # would fail, so keep it dropped.
                continue
            if key not in ret:
                out[key] = None  # replacing a value (not adding a key) is safe while iterating
            elif 'rays_filter' in ret:
                # Only the rays selected by the filter have results in this batch
                buffer[idx][ret['rays_filter']] = ret[key]
            else:
                buffer[idx] = ret[key]
    if args.output_flags['perf']:
        test_perf.close()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        tot_time += test_perf.duration()
    if gt_images is not None:
        gt_images[idx] = data['color']
    i += 1
    progress_bar(i, tot_iters, 'Inferring...')
    offset += n_rays

# 4. Save results
print('Saving results...')
output_dir.mkdir(parents=True, exist_ok=True)

# Drop outputs the model did not produce, then reshape the flat per-pixel buffers to
# per-view images; color-like outputs are made channels-first ([n, C, H, W]).
out = {key: value for key, value in out.items() if value is not None}
for key in out:
    out[key] = out[key].reshape([n, *dataset.res, *out[key].shape[1:]])
    if key in ('color', 'diffuse', 'specular'):
        out[key] = out[key].permute(0, 3, 1, 2)

if args.output_flags['perf']:
    # Per-view metrics stay NaN when no ground truth is available
    perf_errors = torch.full([n], math.nan)
    perf_ssims = torch.full([n], math.nan)
    if gt_images is not None:
        gt_images = gt_images.reshape(n, *dataset.res, chns).permute(0, 3, 1, 2)
        for i in range(n):
            perf_errors[i] = nn_f.mse_loss(gt_images[i], out['color'][i]).item()
            perf_ssims[i] = ssim(gt_images[i:i + 1], out['color'][i:i + 1]).item() * 100
    # Mean inference time per view (reported with an "ms" suffix in the file name)
    perf_mean_time = tot_time / n
    perf_mean_error = torch.mean(perf_errors).item()
    perf_name = f'perf_{output_dataset_id}_{perf_mean_time:.1f}ms_{perf_mean_error:.2e}.csv'

    # Remove old performance reports of exactly this dataset id. A plain prefix glob would
    # also delete reports of other resolutions/datasets sharing the prefix (e.g.
    # "perf_scene_512x512_..." when the id is "scene"); the time/error fields have no '_'.
    import re  # local: only needed to recognize old report file names
    old_report = re.compile(rf'perf_{re.escape(output_dataset_id)}_[^_]+ms_[^_]+\.csv')
    for file in output_dir.glob('perf_*.csv'):
        if old_report.fullmatch(file.name):
            file.unlink()

    # Save new performance report
    with (output_dir / perf_name).open('w') as fp:
        fp.write('View, PSNR, SSIM\n')
        fp.writelines([
            f'{dataset.indices[i]}, '
            f'{img.mse2psnr(perf_errors[i].item()):.2f}, {perf_ssims[i].item():.2f}\n'
            for i in range(n)
        ])

    # Error heat maps need ground truth (gt_images is None without it)
    if gt_images is not None:
        # Per-pixel MSE over channels; an MSE of 1e-2 or more saturates the color map
        error_images = ((gt_images - out['color'])**2).sum(1, True) / chns
        error_images = (error_images / 1e-2).clamp(0, 1) * 255
        error_images = np.asarray(img.torch2np(error_images), dtype=np.uint8)
        output_subdir = output_dir / f"{output_dataset_id}_error"
        output_subdir.mkdir(exist_ok=True)
        for i in range(n):
            # Note: cv2.applyColorMap returns a 3-channel heat map in OpenCV's BGR channel
            # order, which is the order cv2.imwrite expects.
            heat_img = cv2.applyColorMap(error_images[i], cv2.COLORMAP_JET)
            cv2.imwrite(f'{output_subdir}/{dataset.indices[i]:0>4d}.png', heat_img)

def _save_output(images, suffix: str) -> None:
    """Save ``images`` under ``output_dir`` as ``<dataset id>_<suffix>``.

    With ``--output-type video`` this writes one 30 fps mp4 file. Otherwise it writes one
    PNG per view into a sub-directory, named by the zero-padded view index.
    """
    if args.output_type == 'video':
        img.save_video(images, output_dir / f"{output_dataset_id}_{suffix}.mp4", 30)
        return
    target_dir = output_dir / f"{output_dataset_id}_{suffix}"
    target_dir.mkdir(exist_ok=True)
    img.save(images, [f'{target_dir}/{i:0>4d}.png' for i in dataset.indices])


# Color-like outputs are saved as-is; depth is colorized over the model's sample range first
for kind in ('color', 'diffuse', 'specular'):
    if kind in out:
        _save_output(out[kind], kind)

if 'depth' in out:
    _save_output(img.colorize_depthmap(out['depth'][..., 0], model.args['sample_range']),
                 'depth')

if args.time:
    # Indent every timing entry by its depth in the "/"-separated perf-node path
    perf_result = get_perf_result()
    report_lines = ["Performance Report ==>"]
    if perf_result is None:
        report_lines.append("No available data.")
    else:
        for node_path, elapsed in perf_result.items():
            *parents, leaf = node_path.split("/")
            report_lines.append("  " * len(parents) + f"{leaf}: {elapsed:.1f}ms")
    print("\n".join(report_lines) + "\n")