import bpy
import json
import os
import math
import numpy as np
from typing import List, Tuple
from itertools import product


class Gen:
    """Base class for rendering a dataset of camera views with Blender.

    Subclasses supply :meth:`init_desc` (the initial dataset description
    dict) and :meth:`gen_rand` (random view sampling). Calling an instance
    loads or creates the description, re-renders recorded views whose image
    file is missing, and then generates new views either on a grid or at
    random depending on ``desc['samples']``.
    """

    def __init__(self, root_dir: str, dataset_name: str, *,
                 res: tuple[int, int],
                 fov: float,
                 samples: list[int]) -> None:
        """Configure the active scene's render resolution and camera.

        Args:
            root_dir: Directory holding the dataset folder and its JSON
                description file.
            dataset_name: Images go to ``root_dir/dataset_name/`` and the
                description to ``root_dir/dataset_name.json``.
            res: Render resolution ``(x, y)`` in pixels.
            fov: Field of view in degrees; a negative value selects an
                equirectangular panorama camera.
            samples: A single entry requests random sampling (total view
                count); several entries request a grid with one sample count
                per pose dimension.
        """
        self.res = res
        self.fov = fov
        self.samples = samples

        self.scene = bpy.context.scene
        self.cam_obj = self.scene.camera
        self.cam = self.cam_obj.data
        self.scene.render.resolution_x = self.res[0]
        self.scene.render.resolution_y = self.res[1]
        self.init_camera()

        self.root_dir = root_dir
        self.data_dir = f"{root_dir}/{dataset_name}/"
        self.data_name = dataset_name
        self.data_desc_file = f'{root_dir}/{dataset_name}.json'

    def init_camera(self):
        """Set the scene camera to a panorama (fov < 0) or a perspective camera."""
        if self.fov < 0:
            self.cam.type = 'PANO'
            self.cam.cycles.panorama_type = 'EQUIRECTANGULAR'
        else:
            self.cam.type = 'PERSP'
            self.cam.lens_unit = 'FOV'
            self.cam.angle = math.radians(self.fov)
        self.cam.dof.use_dof = False
        self.cam.clip_start = 0.1
        self.cam.clip_end = 1000

    def init_desc(self):
        """Return the initial description dict; subclasses must override this."""
        return None

    def save_desc(self):
        """Write ``self.desc`` to the dataset's JSON description file."""
        with open(self.data_desc_file, 'w', encoding='utf-8') as fp:
            json.dump(self.desc, fp, indent=4)

    def add_sample(self, i, x: list[float], render_only=False):
        """Render view ``i`` from camera pose ``x`` and optionally record it.

        ``x[:3]`` is the camera location. When ``x`` is longer, ``x[4]`` and
        ``x[3]`` (degrees) become the camera's X and Y Euler rotations; the
        Z rotation is fixed at 0.

        Unless ``render_only`` is set, the pose is appended to the
        description (``view_centers``, plus ``view_rots`` for posed views)
        and the description file is rewritten.
        """
        self.cam_obj.location = x[:3]
        if len(x) > 3:
            self.cam_obj.rotation_euler = [math.radians(x[4]), math.radians(x[3]), 0]
        self.scene.render.filepath = self.data_dir + self.desc['color_file'] % i
        bpy.ops.render.render(write_still=True)
        if not render_only:
            # Store plain floats so the JSON dump never depends on numpy scalar types.
            self.desc['view_centers'].append([float(v) for v in x[:3]])
            if len(x) > 3:
                self.desc['view_rots'].append([float(v) for v in x[3:]])
            self.save_desc()

    def _stored_pose(self, i: int) -> list[float]:
        """Return a new list with view ``i``'s center followed by its rotation, if any."""
        # Copy first: extending the stored list in place would corrupt desc['view_centers'].
        pose = list(self.desc['view_centers'][i])
        if 'view_rots' in self.desc:
            pose.extend(self.desc['view_rots'][i])
        return pose

    def gen_grid(self):
        """Render every grid view not yet recorded in the description.

        One ``np.linspace`` axis is built per entry of ``desc['samples']``
        between ``desc['range']['min']`` and ``desc['range']['max']``. The
        Cartesian product is enumerated in a fixed order, so the first
        ``len(view_centers)`` points are skipped as already rendered.
        """
        start_view = len(self.desc['view_centers'])
        range_min = self.desc['range']['min']
        range_max = self.desc['range']['max']
        axes = [
            np.linspace(range_min[d], range_max[d], count)
            for d, count in enumerate(self.desc['samples'])
        ]
        for i, x in enumerate(product(*axes)):
            if i >= start_view:
                self.add_sample(i, list(x))

    def gen_rand(self):
        """Random view generation hook; the base class does nothing."""
        pass

    def __call__(self):
        """Load or create the description, re-render missing images, then sample new views."""
        os.makedirs(self.data_dir, exist_ok=True)
        if os.path.exists(self.data_desc_file):
            with open(self.data_desc_file, 'r', encoding='utf-8') as fp:
                self.desc = json.load(fp)
        else:
            self.desc = self.init_desc()

        # Re-render recorded views whose image is missing, without modifying the description.
        for i in range(len(self.desc['view_centers'])):
            if not os.path.exists(self.data_dir + self.desc['color_file'] % i):
                self.add_sample(i, self._stored_pose(i), render_only=True)

        # One sample count means random sampling; several mean one count per grid axis.
        if len(self.desc['samples']) == 1:
            self.gen_rand()
        else:
            self.gen_grid()


class GenView(Gen):
    """Perspective view generator with poses drawn from a translation/rotation box.

    A pose is ``[x, y, z, r0, r1]``. The translation lies within
    ``±tbox / 2`` per axis and the two rotation angles (degrees) lie within
    ``±rbox / 2``.
    """

    def __init__(self, root_dir: str, dataset_name: str, *,
                 res: tuple[int, int], fov: float, samples: list[int],
                 tbox: tuple[float, float, float], rbox: tuple[float, float]) -> None:
        """Create the generator.

        Args:
            root_dir, dataset_name, res, fov, samples: Forwarded to :class:`Gen`.
            tbox: Full extent of the translation box along x, y, z, centered
                at the origin.
            rbox: Full extent of the two rotation ranges in degrees, centered
                at 0.
        """
        super().__init__(root_dir, dataset_name, res=res, fov=fov, samples=samples)
        self.tbox = tbox
        self.rbox = rbox

    def init_desc(self):
        """Return a fresh description with no recorded views."""
        return {
            'color_file': 'view_%04d.png',
            "gl_coord": True,
            'view_res': {
                'x': self.res[0],
                'y': self.res[1]
            },
            'cam_params': {
                'fov': self.fov,
                'cx': 0.5,
                'cy': 0.5,
                'normalized': True
            },
            'range': {
                'min': [-self.tbox[0] / 2, -self.tbox[1] / 2, -self.tbox[2] / 2,
                        -self.rbox[0] / 2, -self.rbox[1] / 2],
                'max': [self.tbox[0] / 2, self.tbox[1] / 2, self.tbox[2] / 2,
                        self.rbox[0] / 2, self.rbox[1] / 2]
            },
            'samples': self.samples,
            'view_centers': [],
            'view_rots': []
        }

    def gen_rand(self):
        """Render uniformly random poses until ``desc['samples'][0]`` views exist.

        Each pose is drawn uniformly from the box between
        ``desc['range']['min']`` and ``desc['range']['max']``, with one
        dimension per range entry.
        """
        start_view = len(self.desc['view_centers'])
        n = self.desc['samples'][0] - start_view
        if n <= 0:
            # Enough views already exist; np.random.rand would reject a negative size.
            return
        range_min = np.array(self.desc['range']['min'])
        range_max = np.array(self.desc['range']['max'])
        poses = (range_max - range_min) * np.random.rand(n, range_min.size) + range_min
        for offset, pose in enumerate(poses):
            self.add_sample(start_view + offset, list(pose))


class GenPano(Gen):
    """Equirectangular panorama generator rendering at a fixed 4096x2048 resolution."""

    def __init__(self, root_dir: str, dataset_name: str, *,
                 samples: list[int], depth_range: tuple[float, float],
                 tbox: tuple[float, float, float] = None) -> None:
        """Create the generator.

        Args:
            root_dir, dataset_name, samples: Forwarded to :class:`Gen`.
            depth_range: ``(min, max)`` stored in the description's
                ``depth_range``; ``min`` is also used as the radius for random
                view sampling.
            tbox: Optional full extent of the translation box along x, y, z.
                When given, the description gets a ``range`` entry (needed for
                grid sampling).
        """
        self.depth_range = depth_range
        self.tbox = tbox
        # A negative fov makes Gen.init_camera select the panorama camera.
        super().__init__(root_dir, dataset_name, res=[4096, 2048], fov=-1, samples=samples)

    def init_desc(self):
        """Return a fresh description with no recorded views."""
        range_entry = {
            'range': {
                'min': [-self.tbox[0] / 2, -self.tbox[1] / 2, -self.tbox[2] / 2],
                'max': [self.tbox[0] / 2, self.tbox[1] / 2, self.tbox[2] / 2]
            }
        } if self.tbox else {}
        return {
            'color_file': 'view_%04d.png',
            "gl_coord": True,
            'view_res': {
                'x': self.res[0],
                'y': self.res[1]
            },
            "cam_params": {
                "type": "pano"
            },
            **range_entry,
            "depth_range": {
                "min": self.depth_range[0],
                "max": self.depth_range[1]
            },
            'samples': self.samples,
            'view_centers': []
        }

    @staticmethod
    def _sample_in_ball(n: int, radius: float) -> np.ndarray:
        """Return an ``(n, 3)`` array of points uniform in the open ball of ``radius``.

        Rejection-samples from the enclosing cube and repeats until at least
        ``n`` points have been accepted.

        Raises:
            ValueError: If ``radius`` is not positive (no point could ever be accepted).
        """
        if n <= 0:
            return np.empty((0, 3))
        if radius <= 0:
            raise ValueError(f"sampling radius must be positive, got {radius}")
        accepted = np.empty((0, 3))
        while len(accepted) < n:
            cube_pts = (np.random.rand(n * 5, 3) - 0.5) * 2 * radius
            inside = cube_pts[np.linalg.norm(cube_pts, axis=1) < radius]
            accepted = np.concatenate([accepted, inside])
        return accepted[:n]

    def gen_rand(self):
        """Render random view centers until ``desc['samples'][0]`` views exist."""
        start_view = len(self.desc['view_centers'])
        n = self.desc['samples'][0] - start_view
        if n <= 0:
            return
        # NOTE(review): the radius is the *minimum* depth, presumably to keep
        # cameras inside the nearest geometry — confirm this is intended.
        radius = self.desc['depth_range']['min']
        for offset, center in enumerate(self._sample_in_ball(n, radius)):
            self.add_sample(start_view + offset, list(center))