import sys import os import configargparse import argparse import torch from pathlib import Path from typing import Any sys.path.append(os.path.abspath(sys.path[0] + '/../')) parser = argparse.ArgumentParser() parser.add_argument('input', type=str) args = parser.parse_args() def get_sampler_mode(nerf_config: dict[str, str]) -> str: lindisp = nerf_config["lindisp"] == "True" spherical = nerf_config.get("spherical") == "True" if lindisp: return "spherical" if spherical else "xyz_disp" return "spherical_radius" if spherical else "xyz" def convert_network_state(input_network_state: dict[str, torch.Tensor], output_prefix: str) -> dict[str, torch.Tensor]: output_model_state: dict[str, torch.Tensor] = {} for key, value in input_network_state.items(): key_parts = key.split(".") module = key_parts[0] suffix = key_parts[-1] match module: case "pts_linears": output_model_state[f"{output_prefix}core.field.net.layers.{key_parts[1]}.net.0.{suffix}"] = value case "alpha_linear": output_model_state[f"{output_prefix}core.density_decoder.net.net.0.{suffix}"] = value case "feature_linear": output_model_state[f"{output_prefix}core.color_decoder.feature_layer.net.0.{suffix}"] = value case "views_linears" if nerf_config["use_viewdirs"] == "True": output_model_state[f"{output_prefix}core.color_decoder.net.layers.0.net.0.{suffix}"] = value case "rgb_linear": output_model_state[f"{output_prefix}core.color_decoder.net.layers.1.net.0.{suffix}"] = value case "output_linear": output_model_state[f"{output_prefix}core.density_decoder.net.net.0.{suffix}"] = value[3:] output_model_state[f"{output_prefix}core.color_decoder.net.net.0.{suffix}"] = value[:3] return output_model_state # Load nerf's args file and checkpoint file # checkpoint file is at [nerf_repo_path]/logs/[expname]/xxxx.tar ckpt_path = Path(args.input) expdir = ckpt_path.parent nerf_repo_path = expdir.parent.parent args_path = expdir / "args.txt" expname = expdir.stem ckpt_iters = int(ckpt_path.stem) config_parser = configargparse.DefaultConfigFileParser() with open(args_path) as fp: nerf_config: dict[str, str] = config_parser.parse(fp) input_checkpoint: dict[str, Any] = torch.load(ckpt_path) output_model_state = convert_network_state(input_checkpoint["network_fn_state_dict"], "") if "network_fine_state_dict" in input_checkpoint: output_model_state.update(convert_network_state( input_checkpoint["network_fine_state_dict"], "fine_")) output_args = { "model": "NeRF", "model_args":{ "color": "rgb", "n_samples": int(nerf_config["N_samples"]), "sample_mode": get_sampler_mode(nerf_config), "perturb_sampling": nerf_config["perturb"] == "1.0", "depth": int(nerf_config["netdepth"]), "width": int(nerf_config["netwidth"]), "skips": [4], "act": "relu", "ln": False, "color_decoder": "NeRF" if nerf_config["use_viewdirs"] == "True" else "Basic", "n_importance": int(nerf_config["N_importance"]), "fine_depth": int(nerf_config["netdepth_fine"]), "fine_width": int(nerf_config["netwidth_fine"]), "xfreqs": int(nerf_config["multires"]), "dfreqs": int(nerf_config["multires_views"]), "raw_noise_std": float(nerf_config["raw_noise_std"]), "near": float(nerf_config.get("near", "1")), "far": float(nerf_config.get("far", "10")), "white_bg": nerf_config["white_bkgd"] == "True", }, "trainer": None } output_path = nerf_repo_path.joinpath(nerf_config["datadir"]) / "_nets" / \ Path(nerf_config.get("train_file", "train.json")).stem / expname / \ f"checkpoint_{ckpt_iters}.tar" output_path.parent.mkdir(parents=True, exist_ok=True) torch.save({ "args": output_args, "states": { "model": output_model_state } }, output_path) print(f"Checkpoint is saved to {output_path}")