# voxel_sampler_export3d.py
# (Removed non-code residue from a web blob view: file-size line, author/commit
# lines, and gutter line numbers that were pasted into the source.)
import argparse
import json
import os
import sys

import torch

# Make the project root importable when this script is run from its own folder.
sys.path.append(os.path.abspath(sys.path[0] + '/../'))

# Command-line interface. Parsing happens before the heavy project imports
# below so `-h` and argument errors return immediately.
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str,
                    help='The model file to load for testing')
parser.add_argument('-r', '--output-rays', type=int, default=100,
                    help='How many rays to output')
parser.add_argument('-p', '--prompt', action='store_true',
                    help='Interactive prompt mode')
parser.add_argument('dataset', type=str,
                    help='Dataset description file')
args = parser.parse_args()


# Project imports are deliberately placed after argument parsing (see above);
# importing `model`/`torch`-backed modules is slow.
import model as mdl
from utils import misc
from utils import color
from utils import interact
from utils import device
from data.dataset_factory import *
from data.loader import DataLoader
from modules import Samples, Voxels
from model.nsvf import NSVF

# Module-level declarations for the objects filled in by the main block below.
model: NSVF
samples: Samples

# Upper bound on items per data-loader chunk.
DATA_LOADER_CHUNK_SIZE = 1e8


# Resolve the dataset description file: the positional argument may be either
# a .json description file itself or a dataset directory holding train.json.
if args.dataset.endswith('.json'):
    data_desc_path = args.dataset
else:
    data_desc_path = os.path.join(args.dataset, 'train.json')
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
data_dir = os.path.dirname(data_desc_path) + '/'


def get_model_files(datadir):
    """Recursively collect checkpoint files (``*.tar``, ``*.pth``) under *datadir*.

    Args:
        datadir: Directory to search.

    Returns:
        List of checkpoint paths relative to *datadir*, in walk order.
    """
    model_files = []
    for root, _, files in os.walk(datadir):
        model_files += [
            # relpath strips only the leading *datadir* prefix; the original
            # str.replace(datadir, '') would mangle any path that happened to
            # contain the directory name a second time.
            os.path.relpath(os.path.join(root, file), datadir)
            for file in files if file.endswith(('.tar', '.pth'))
        ]
    return model_files


if args.prompt:  # Prompt test model, output resolution, output mode
    # Interactive mode: list every checkpoint under the dataset directory and
    # let the user pick one, then ask for the number of rays to export.
    model_files = get_model_files(data_dir)
    args.model = interact.input_enum('Specify test model:', model_files,
                                     err_msg='No such model file')
    # NOTE(review): the interactive default (10) differs from the CLI
    # default (100) — confirm this is intentional.
    args.output_rays = interact.input_ex('Specify number of rays to output:',
                                         interact.input_to_int(), default=10)

# Resolve the checkpoint path relative to the dataset directory and load it
# with sample perturbation disabled, so exported sample positions are
# deterministic.
model_path = os.path.join(data_dir, args.model)
model_name = os.path.splitext(os.path.basename(model_path))[0]
model, iters = mdl.load(model_path, {"perturb_sample": False})  # `iters` unused below
model.to(device.default()).eval()
model_class = model.__class__.__name__
model_args = model.args
print(f"model: {model_name} ({model_class})")
# NOTE(review): `json` is never imported by name in this file — this line works
# only if the star-import above provides it. Also `args0` here vs `args` two
# lines up: confirm which attribute holds the original constructor arguments.
print("args:", json.dumps(model.args0))

dataset = DatasetFactory.load(data_desc_path)
print("Dataset loaded: " + data_desc_path)

run_dir = os.path.dirname(model_path) + '/'
# Assumes the checkpoint name ends with "_<iteration>", e.g. "model_50000".
output_dir = f"{run_dir}output_{int(model_name.split('_')[-1])}"


# Debug switch: also export per-voxel access counts. Disabled by default
# because it requires a full forward pass over the whole dataset.
# (Replaces a magic `if False:` gate.)
EXPORT_VOXEL_ACCESS_COUNTS = False

if __name__ == "__main__":
    with torch.no_grad():
        # 1. Initialize data loader
        data_loader = DataLoader(dataset, args.output_rays, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                                 shuffle=True, enable_preload=True,
                                 color=color.from_str(model.args['color']))

        # 2. Sample the first batch of rays only. The sampler also returns a
        # mask separating rays that hit the scene volume from those that miss.
        sys.stdout.write("Export samples...\r")
        for _, rays_o, rays_d, _ in data_loader:
            samples, rays_mask = model.sampler(rays_o, rays_d, model.space)
            invalid_mask = torch.logical_not(rays_mask)  # hoisted: used twice below
            invalid_rays_o = rays_o[invalid_mask]
            invalid_rays_d = rays_d[invalid_mask]
            rays_o = rays_o[rays_mask]
            rays_d = rays_d[rays_mask]
            break
        print("Export samples...Done")

        os.makedirs(output_dir, exist_ok=True)

        export_data = {}

        # 3. Export the scene-space description: bounding box and, for
        # voxel-based spaces, the voxel grid itself.
        if model.space.bbox is not None:
            export_data['bbox'] = model.space.bbox.tolist()
        if isinstance(model.space, Voxels):
            export_data['voxel_size'] = model.space.voxel_size.tolist()
            export_data['voxels'] = model.space.voxels.tolist()

            if EXPORT_VOXEL_ACCESS_COUNTS:
                voxel_access_counts = torch.zeros(model.space.n_voxels, dtype=torch.long,
                                                  device=device.default())
                iters_in_epoch = 0
                data_loader.batch_size = 2 ** 20
                for _, rays_o1, rays_d1, _ in data_loader:
                    model(rays_o1, rays_d1,
                          raymarching_tolerance=0.5,
                          raymarching_chunk_size=0,
                          voxel_access_counts=voxel_access_counts)
                    iters_in_epoch += 1
                    percent = iters_in_epoch / len(data_loader) * 100
                    sys.stdout.write(f'Export voxel access counts...{percent:.1f}%   \r')
                export_data['voxel_access_counts'] = voxel_access_counts.tolist()
                print("Export voxel access counts...Done  ")

        # 4. Export the sampled rays (valid and invalid separately) and the
        # per-ray sample data.
        export_data.update({
            'rays_o': rays_o.tolist(),
            'rays_d': rays_d.tolist(),
            'invalid_rays_o': invalid_rays_o.tolist(),
            'invalid_rays_d': invalid_rays_d.tolist(),
            'samples': {
                'depths': samples.depths.tolist(),
                'dists': samples.dists.tolist(),
                'voxel_indices': samples.voxel_indices.tolist()
            }
        })
        with open(f'{output_dir}/debug_voxel_sampler_export3d.json', 'w') as fp:
            json.dump(export_data, fp)
        print("Write JSON file...Done")

        # voxel_indices of -1 mark samples outside every voxel, so ne(-1)
        # counts the real samples per ray. (A stray no-op `args.output_rays`
        # statement that stood here was removed.)
        print(f"Rays: total {args.output_rays}, valid {rays_o.size(0)}")
        print(f"Samples: average {samples.voxel_indices.ne(-1).sum(-1).float().mean().item()} per ray")