import shutil
import sys
import os
import argparse
import json
import re
import numpy as np
from typing import Any
from pathlib import Path

sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from utils.colmap_read_model import read_model, Image


def check_model_path(path: Path) -> bool:
    """
    Check whether the specified path contains colmap model files.

    :param path `Path`: path to check
    :return `bool`: whether the specified path contains colmap model files
    """
    return all([
        (path / f"{f}.bin").exists()
        for f in ['cameras', 'images', 'points3D']
    ])


def get_image_id(im: Image) -> int:
    """
    Extract image id from an image filename like xxxx001.png.

    :param im `Image`: colmap's image info
    :return `int`: image id
    """
    return int(re.match(r"\D+(\d+)\.\w+", os.path.split(im.name)[1]).group(1))


def normalize(x: np.ndarray) -> np.ndarray:
    """Normalize a vector to unit length."""
    return x / np.linalg.norm(x)


def view_matrix(z: np.ndarray, up: np.ndarray, pos: np.ndarray) -> np.ndarray:
    """
    Construct a view matrix from z, up and position.

    :param z `ndarray(3)`: z axis
    :param up `ndarray(3)`: up direction
    :param pos `ndarray(3)`: center position
    :return `ndarray(3, 4)`: view matrix
    """
    vec2 = normalize(z)
    vec0 = normalize(np.cross(up, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    return np.stack([vec0, vec1, vec2, pos], 1)


def poses_avg(poses: np.ndarray) -> np.ndarray:
    """
    Calculate the average of the given poses.

    :param poses `ndarray(B, 3, 4)`: poses
    :return `ndarray(3, 4)`: average pose
    """
    center = np.mean(poses[..., 3], 0)  # average over the batch dimension
    vec2 = normalize(np.sum(poses[..., 2], 0))
    up = np.sum(poses[..., 1], 0)
    return view_matrix(vec2, up, center)


def recenter(poses: np.ndarray, pts: np.ndarray):
    """
    Translate poses and points so the average camera center is at the origin.

    :param poses `ndarray(B, 3, 4)`: camera-to-world poses
    :param pts `ndarray(N, 3)`: 3D points
    :return: recentered poses and points
    """
    center = poses[..., 3:].mean(0)  # (1, 3, 1)
    return np.concatenate([poses[..., :3], poses[..., 3:] - center], -1), pts - center[..., 0]
    # The LLFF-style full recentering below (aligning rotations to the average
    # pose as well) was unreachable dead code after the return statement; it is
    # kept commented out for reference:
    # poses_ = poses + 0
    # bottom = np.reshape([0, 0, 0, 1.], [1, 4])
    # c2w = poses_avg(poses)
    # c2w = np.concatenate([c2w[:3, :4], bottom], -2)
    # bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    # poses = np.concatenate([poses[:, :3, :4], bottom], -2)
    # poses = np.linalg.inv(c2w) @ poses
    # poses_[:, :3, :4] = poses[:, :3, :4]
    # poses = poses_
    # return poses


parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str)
parser.add_argument('--scale-down', type=int, default=1)
args = parser.parse_args()

data_dir = Path(args.dataset)
scale_down = args.scale_down

if check_model_path(data_dir / "input"):
    model_path = data_dir / "input"
else:
    raise RuntimeError("No colmap model found.")

cameras, images, points3D = read_model(model_path, '.bin')
print("Colmap model loaded.")
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))

cam = cameras[1]  # COLMAP camera ids are 1-based; assume a single shared camera
images = [im for im in images.values()]

# Build world-to-camera matrices from COLMAP's quaternion + translation,
# then invert them to get camera-to-world poses
w2c_mats = np.stack([
    np.concatenate([
        np.concatenate([im.qvec2rotmat(), im.tvec.reshape([3, 1])], 1),
        np.array([[0, 0, 0, 1.]])
    ], 0)
    for im in images
], 0)  # (B, 4, 4)
c2w_mats = np.linalg.inv(w2c_mats)
poses = c2w_mats[:, :3, :]
poses[..., 1:3] *= -1  # colmap: [x,-y,-z] -> conventional: [x,y,z]

pts = np.array([p.xyz for p in points3D.values()])
poses, pts = recenter(poses, pts)
norms = np.linalg.norm(pts, axis=1)
near, far = np.percentile(norms, 1), np.percentile(norms, 99)
trans_range = np.max(np.linalg.norm(poses[..., 3], axis=1))
print(f"Near: {near}, far: {far}, trans range: {trans_range}")

if scale_down > 1:
    print("Scale images...")
    from tools import image_scale
    image_scale.run(data_dir / "input/images", data_dir / f"input/images{scale_down}",
                    data_dir / "input/images", 1. / scale_down)
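# For reference, a sketch of the per-dataset JSON descriptor assembled below
# (the values shown are illustrative, not from any real scene):
#
# {
#     "color_file": "view%04d.png",
#     "gl_coord": true,
#     "view_res": {"x": 1920, "y": 1080},
#     "cam_params": {"f": 1665.1, "cx": 960.0, "cy": 540.0},
#     "depth_range": {"min": 1.2, "max": 9.8},
#     "samples": [100],
#     "view_centers": [[x, y, z], ...],
#     "view_rots": [[r00, r01, ..., r22], ...]
# }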
general_desc = {
    'color_file': f"view%04d{os.path.splitext(images[0].name)[1]}",
    'gl_coord': True,
    'view_res': {
        'x': cam.width // scale_down,
        'y': cam.height // scale_down
    },
    'cam_params': {  # assumes a SIMPLE_PINHOLE-style camera: params = [f, cx, cy, ...]
        'f': cam.params[0] / scale_down,
        'cx': cam.params[1] / scale_down,
        'cy': cam.params[2] / scale_down
    },
    'depth_range': {
        'min': max(near, trans_range * 1.1),
        'max': far
    },
    # 'samples': [poses.shape[0]],
    # 'view_centers': poses[..., 3].tolist(),
    # 'view_rots': poses[:, :3, :3].reshape([-1, 9]).tolist(),
    # 'views': views
}

with open(data_dir / "input/dataset.json") as fp:
    datasets: dict[str, Any] = json.load(fp)

for dataset, image_dirs in datasets.items():
    if scale_down > 1:
        dataset = f"{dataset}{scale_down}"
    view_centers = []
    view_rots = []
    im_names = []
    for image_dir in image_dirs:
        for i, im in enumerate(images):
            if im.name.startswith(image_dir):
                view_centers.append(poses[i, :, 3].tolist())
                view_rots.append(poses[i, :3, :3].flatten().tolist())
                im_names.append(im.name)

    # Create symbolic links to the input images
    shutil.rmtree(data_dir / dataset, ignore_errors=True)
    (data_dir / dataset).mkdir()
    for i, im_name in enumerate(im_names):
        (data_dir / dataset / (general_desc["color_file"] % i)).symlink_to(
            f"../input/images{scale_down if scale_down > 1 else ''}/{im_name}")

    dataset_desc = {
        **general_desc,
        "samples": [len(view_centers)],
        "view_centers": view_centers,
        "view_rots": view_rots
    }
    with open(data_dir / f"{dataset}.json", 'w') as fp:
        json.dump(dataset_desc, fp, indent=4)
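# Usage sketch (the script filename and scene path below are hypothetical):
#
#   python colmap2dataset.py /data/my_scene --scale-down 2
#
# The scene directory is expected to contain:
#   input/cameras.bin, input/images.bin, input/points3D.bin   COLMAP model
#   input/images/                                             source images
#   input/dataset.json  a mapping from dataset names to lists of image-path
#                       prefixes, e.g. {"train": ["train/"], "test": ["test/"]}
#
# For each dataset entry the script writes <scene>/<name>.json and a directory
# of symlinked views named according to 'color_file'.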