import sys
import os
import argparse
import json
import numpy as np

sys.path.append(os.path.abspath(sys.path[0] + '/../'))

from utils import misc
from utils.colmap_read_model import read_model

# Parse the command line. The single positional argument is the dataset
# directory: the COLMAP model is read from <dataset>/sparse/0 and the
# generated description is written to <dataset>/train.json.
parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_dir = args.dataset

out_desc_path = os.path.join(data_dir, "train.json")
# Make sure the output directory exists (no-op when it is already present).
os.makedirs(data_dir, exist_ok=True)


# Load the binary ('.bin') COLMAP reconstruction stored under <dataset>/sparse/0
# and report its size.
cameras, images, points3D = read_model(os.path.join(data_dir, 'sparse/0'), '.bin')
print("Model loaded.")
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))

# Every later stage indexes into or reduces over these collections; fail with a
# readable message instead of an IndexError / numpy reduction error further down.
if not cameras or not images or not points3D:
    sys.exit("Error: the COLMAP model must contain at least one camera, "
             "one image and one 3D point.")

# The output format holds a single set of intrinsics, so only the first camera
# is used for every view. Say so explicitly when the model has several.
first_cam_id = next(iter(cameras))
cam = cameras[first_cam_id]
if len(cameras) > 1:
    print(f"Warning: {len(cameras)} cameras found; using the intrinsics of "
          f"camera {first_cam_id} for all views.", file=sys.stderr)

# Gather, for every registered image, its numeric view index, its translation
# vector and its flattened (row-major, 9-element) 3x3 rotation matrix, then
# order all three arrays by view index.
# The view index is parsed from characters 5..8 of the image file name, which
# matches names such as "image0001.jpg".
# NOTE(review): in COLMAP, qvec/tvec describe the world-to-camera transform, so
# tvec is not the camera centre (that would be -R^T @ tvec). Confirm that the
# consumer of 'view_centers' / 'view_rots' expects these values as stored here.
view_ids, center_rows, rot_rows = [], [], []
for image in images.values():
    view_ids.append(int(image.name[5:9]))
    center_rows.append(image.tvec)
    rot_rows.append(image.qvec2rotmat().reshape([9]).tolist())

views = np.array(view_ids)
order = np.argsort(views)
views = views[order]
view_centers = np.array(center_rows)[order]
view_rots = np.array(rot_rows)[order]

# Euclidean norm of every sparse 3D point, used for 'depth_range' below.
# NOTE(review): this is the distance to the world origin, not to any camera;
# confirm that is the intended notion of depth.
pts = np.array([points3D[pt_id].xyz for pt_id in points3D])
zvals = np.sqrt(np.sum(pts * pts, 1))

# COLMAP camera models whose parameter vector begins with two focal lengths
# (fx, fy, cx, cy, ...). Any other model is read with the single shared focal
# length layout (f, cx, cy, ...), e.g. SIMPLE_PINHOLE / SIMPLE_RADIAL / RADIAL.
_SEPARATE_FOCAL_MODELS = frozenset({
    'PINHOLE', 'OPENCV', 'OPENCV_FISHEYE', 'FULL_OPENCV', 'FOV',
    'THIN_PRISM_FISHEYE',
})


def _pinhole_intrinsics(model, params):
    """Return ``(fx, fy, cx, cy)`` as Python floats for a COLMAP camera.

    Args:
        model: COLMAP camera model name, e.g. ``'SIMPLE_RADIAL'``.
        params: the camera parameter vector in COLMAP order.

    Distortion coefficients following the pinhole terms are ignored.
    """
    if model in _SEPARATE_FOCAL_MODELS:
        fx, fy, cx, cy = params[:4]
    else:
        fx, cx, cy = params[:3]
        fy = fx
    # Plain floats so json.dump never sees a numpy scalar type it can't encode.
    return float(fx), float(fy), float(cx), float(cy)


fx, fy, cx, cy = _pinhole_intrinsics(cam.model, cam.params)

dataset_desc = {
    'view_file_pattern': "images/image%04d.jpg",
    'gl_coord': True,
    'view_res': {
        'x': cam.width,
        'y': cam.height
    },
    'cam_params': {
        'fx': fx,
        'fy': fy,
        'cx': cx,
        'cy': cy
    },
    # Per-axis bounding box of the stored view translations, padded with two
    # zeros (presumably placeholder ranges for extra dimensions — confirm).
    'range': {
        'min': np.min(view_centers, 0).tolist() + [0, 0],
        'max': np.max(view_centers, 0).tolist() + [0, 0]
    },
    'depth_range': {
        'min': float(np.min(zvals)),
        'max': float(np.max(zvals))
    },
    'samples': [len(view_centers)],
    'view_centers': view_centers.tolist(),
    'view_rots': view_rots.tolist(),
    'views': views.tolist()
}

# Write the description as JSON indented by 4 spaces. json's default
# ensure_ascii=True keeps the output pure ASCII, so the explicit encoding does
# not change the bytes written.
serialized = json.dumps(dataset_desc, indent=4)
with open(out_desc_path, mode='w', encoding='utf-8') as desc_file:
    desc_file.write(serialized)