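"""Convert a NeRF checkpoint ([nerf_repo_path]/logs/[expname]/xxxx.tar) to this
repository's checkpoint format."""
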
import argparse
import os
import sys
from pathlib import Path
from typing import Any

import configargparse
import torch


sys.path.append(os.path.abspath(sys.path[0] + '/../'))


parser = argparse.ArgumentParser()
parser.add_argument('input', type=str, help="path to the NeRF checkpoint (xxxx.tar) to convert")
args = parser.parse_args()


def get_sampler_mode(nerf_config: dict[str, str]) -> str:
    """Map the NeRF config's `lindisp`/`spherical` flags to a sampler mode:

    lindisp=True,  spherical=True  -> "spherical"
    lindisp=True,  spherical=False -> "xyz_disp"
    lindisp=False, spherical=True  -> "spherical_radius"
    lindisp=False, spherical=False -> "xyz"
    """
    lindisp = nerf_config["lindisp"] == "True"
    spherical = nerf_config.get("spherical") == "True"
    if lindisp:
        return "spherical" if spherical else "xyz_disp"
    return "spherical_radius" if spherical else "xyz"


def convert_network_state(input_network_state: dict[str, torch.Tensor], output_prefix: str) -> dict[str, torch.Tensor]:
    """Rename the tensors of a NeRF state dict to this repository's key layout.

    Note: the "views_linears" case relies on the module-level `nerf_config`
    parsed below (for the `use_viewdirs` flag).
    """
    output_model_state: dict[str, torch.Tensor] = {}
    for key, value in input_network_state.items():
        key_parts = key.split(".")
        module = key_parts[0]
        suffix = key_parts[-1]
        match module:
            case "pts_linears":
                # Position MLP layers: "pts_linears.<i>.<weight|bias>"
                output_model_state[f"{output_prefix}core.field.net.layers.{key_parts[1]}.net.0.{suffix}"] = value
            case "alpha_linear":
                output_model_state[f"{output_prefix}core.density_decoder.net.net.0.{suffix}"] = value
            case "feature_linear":
                output_model_state[f"{output_prefix}core.color_decoder.feature_layer.net.0.{suffix}"] = value
            case "views_linears" if nerf_config["use_viewdirs"] == "True":
                output_model_state[f"{output_prefix}core.color_decoder.net.layers.0.net.0.{suffix}"] = value
            case "rgb_linear":
                output_model_state[f"{output_prefix}core.color_decoder.net.layers.1.net.0.{suffix}"] = value
            case "output_linear":
                # Without view directions, output_linear packs RGB (rows 0-2) and
                # density (remaining rows) in one layer; split it into the two decoders.
                output_model_state[f"{output_prefix}core.density_decoder.net.net.0.{suffix}"] = value[3:]
                output_model_state[f"{output_prefix}core.color_decoder.net.net.0.{suffix}"] = value[:3]
    return output_model_state
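
# Example of the renaming performed by convert_network_state (coarse network, output_prefix=""):
#   "pts_linears.0.weight" -> "core.field.net.layers.0.net.0.weight"
#   "alpha_linear.bias"    -> "core.density_decoder.net.net.0.bias"
# With output_prefix="fine_", the converted keys gain a "fine_" prefix.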


# Load NeRF's args file and checkpoint file.
# The checkpoint file is expected at [nerf_repo_path]/logs/[expname]/xxxx.tar
ckpt_path = Path(args.input)
expdir = ckpt_path.parent
nerf_repo_path = expdir.parent.parent
args_path = expdir / "args.txt"
expname = expdir.stem
ckpt_iters = int(ckpt_path.stem)
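# For example (hypothetical path): an input of [nerf_repo_path]/logs/lego_test/200000.tar
# gives expdir = [nerf_repo_path]/logs/lego_test, expname = "lego_test",
# args_path = [nerf_repo_path]/logs/lego_test/args.txt and ckpt_iters = 200000.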

config_parser = configargparse.DefaultConfigFileParser()
with open(args_path) as fp:
    nerf_config: dict[str, str] = config_parser.parse(fp)
# Load on CPU so the conversion also works on machines without a GPU
input_checkpoint: dict[str, Any] = torch.load(ckpt_path, map_location="cpu")

output_model_state = convert_network_state(input_checkpoint["network_fn_state_dict"], "")
if "network_fine_state_dict" in input_checkpoint:
    output_model_state.update(convert_network_state(
        input_checkpoint["network_fine_state_dict"], "fine_"))

output_args = {
    "model": "NeRF",
    "model_args": {
        "color": "rgb",
        "n_samples": int(nerf_config["N_samples"]),
        "sample_mode": get_sampler_mode(nerf_config),
        "perturb_sampling": float(nerf_config["perturb"]) > 0,
        "depth": int(nerf_config["netdepth"]),
        "width": int(nerf_config["netwidth"]),
        "skips": [4],
        "act": "relu",
        "ln": False,
        "color_decoder": "NeRF" if nerf_config["use_viewdirs"] == "True" else "Basic",
        "n_importance": int(nerf_config["N_importance"]),
        "fine_depth": int(nerf_config["netdepth_fine"]),
        "fine_width": int(nerf_config["netwidth_fine"]),
        "xfreqs": int(nerf_config["multires"]),
        "dfreqs": int(nerf_config["multires_views"]),
        "raw_noise_std": float(nerf_config["raw_noise_std"]),
        "near": float(nerf_config.get("near", "1")),
        "far": float(nerf_config.get("far", "10")),
        "white_bg": nerf_config["white_bkgd"] == "True",
    },
    "trainer": None
}

# The converted checkpoint is saved alongside the dataset:
# [nerf_repo_path]/[datadir]/_nets/[train_file stem]/[expname]/checkpoint_[iters].tar
output_path = nerf_repo_path.joinpath(nerf_config["datadir"]) / "_nets" / \
    Path(nerf_config.get("train_file", "train.json")).stem / expname / \
    f"checkpoint_{ckpt_iters}.tar"

output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
    "args": output_args,
    "states": {
        "model": output_model_state
    }
}, output_path)
print(f"Checkpoint is saved to {output_path}")
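
# Example usage (hypothetical paths, assuming the layout described above):
#   python convert_nerf_checkpoint.py /path/to/nerf/logs/my_experiment/200000.tar
# This would read /path/to/nerf/logs/my_experiment/args.txt and write the converted
# checkpoint under [datadir]/_nets/.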