class NeRF(Module):
    """
    Fully-connected field network with optional density and color heads.

    A shared trunk (`self.field`) maps `inputs.x` to a feature tensor; the density head and the
    color head both read that tensor. When `d_chns > 0` the color head additionally receives
    `inputs.d` concatenated (via `union`) to the trunk feature.
    """

    def __init__(self, *, x_chns, density_chns, color_chns, nf, n_layers,
                 d_chns=0, d_nf=0, act='relu', skips=None, with_layer_norm=False,
                 density_out_act='relu', color_out_act='sigmoid', feature_layer=False):
        """
        Build the trunk and the optional output heads.

        :param x_chns `int`: the channels of input "position" (`inputs.x`)
        :param density_chns `int`: the channels of output "density"; 0 disables the density head
        :param color_chns `int`: the channels of output "color"; 0 disables the color head
        :param nf `int`: width passed to the trunk `FcBlock`; also the input width of both heads
        :param n_layers `int`: number of layers passed to the trunk `FcBlock`
        :param d_chns `int`: (optional) the channels of input "direction"; when > 0 the color head
            becomes a one-layer `FcBlock` fed with feature + direction, defaults to 0
        :param d_nf `int`: (optional) width of the direction-aware color block; a falsy value
            selects `nf // 2`, defaults to 0
        :param act `str`: (optional) activation passed to the `FcBlock`s, defaults to "relu"
        :param skips `list`: (optional) skip specification passed to the trunk `FcBlock`,
            defaults to no skips
        :param with_layer_norm `bool`: (optional) forwarded as `with_ln` to the `FcBlock`s
        :param density_out_act `str`: (optional) activation of the density head
        :param color_out_act `str`: (optional) activation of the color head
        :param feature_layer `bool`: (optional) when true and `d_chns > 0`, an extra
            `FcLayer(nf, nf)` is applied to the trunk output before the direction is appended
        """
        super().__init__()
        self.x_chns = x_chns
        self.d_chns = d_chns
        # `None` default instead of a shared mutable `[]`; a fresh list is created per instance.
        self.field = FcBlock(in_chns=x_chns, out_chns=None, nf=nf, n_layers=n_layers,
                             skips=[] if skips is None else skips, act=act,
                             with_ln=with_layer_norm)
        self.density_out = FcLayer(nf, density_chns, density_out_act, with_ln=False) \
            if density_chns > 0 else None
        # Defined on every construction path so `forward` never meets a missing attribute.
        self.feature_layer = None
        self.with_dir = False
        if color_chns == 0:
            self.color_out = None
        elif d_chns > 0:
            if feature_layer:
                self.feature_layer = FcLayer(nf, nf, with_ln=False)
            self.color_out = FcBlock(in_chns=nf + d_chns, out_chns=color_chns,
                                     nf=d_nf or nf // 2, n_layers=1,
                                     act=act, out_act=color_out_act, with_ln=with_layer_norm)
            self.with_dir = True
        else:
            self.color_out = FcLayer(nf, color_chns, color_out_act, with_ln=False)

    def forward(self, inputs: NetInput, *outputs: str, field_out: torch.Tensor = None, **kwargs) -> NetOutput:
        """
        Evaluate the requested outputs.

        :param inputs `NetInput`: network input; `inputs.x` feeds the trunk and `inputs.d` is
            appended for the color head when the net was built with `d_chns > 0`
        :param outputs `str...`: names of requested outputs, any of "field_out", "densities", "colors"
        :param field_out `Tensor`: (optional) precomputed trunk output; when given the trunk is skipped
        :return `NetOutput`: holds only the requested outputs whose head exists
        """
        ret = NetOutput()
        if field_out is None:
            field_out = self.field(inputs.x)
        if 'field_out' in outputs:
            ret.field_out = field_out
        if 'densities' in outputs and self.density_out is not None:
            ret.densities = self.density_out(field_out)
        if 'colors' in outputs and self.color_out is not None:
            if self.with_dir:
                # Fall back to the raw trunk output when no feature layer was requested;
                # truthiness (not `is None`) also tolerates a stored `False`.
                h = self.feature_layer(field_out) if self.feature_layer else field_out
                h = union(h, inputs.d)
            else:
                h = field_out
            ret.colors = self.color_out(h)
        return ret

    def get_exporter(self):
        """Return a `ModelExporter` wrapping `infer` with outputs "densities", "colors" and inputs x, d."""
        return ModelExporter(self.infer, "densities", "colors", x=[self.x_chns], d=[self.d_chns])

    def infer(self, x: torch.Tensor, d: torch.Tensor = None, f: torch.Tensor = None):
        """
        Tensor-only entry point used for export: returns the network outputs as a tuple.

        NOTE(review): `_forward` is not defined in this class; presumably it is provided by the
        `Module` base class - confirm.
        """
        return tuple(self._forward(NetInput(x, d, f), "colors", "densities").values())


class NerfAdvCore(Module):
    """
    Core net made of a density MLP, a color MLP and an optional specular MLP.

    The density net produces densities plus a feature vector; the color net consumes the
    position and that feature vector. How the color net output is turned into a final color
    depends on `appearance` (see `__init__`).
    """

    def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
                 density_net: dict, color_net: dict, specular_net: dict = None,
                 appearance="decomposite", with_layer_norm=False, f_chns=0):
        """
        Create a NeRF-Adv Core Net.
        Required parameters for the sub-mlps include: "nf", "n_layers", "skips" and "act".
        Other parameters will be properly set automatically.

        :param x_chns `int`: the channels of input "position"
        :param d_chns `int`: the channels of input "direction"
        :param density_chns `int`: the channels of output "density"
        :param color_chns `int`: the channels of output "color"
        :param density_net `dict`: parameters for the density net
        :param color_net `dict`: parameters for the color net
        :param specular_net `dict`: (optional) parameters for the optional specular net, defaults to None
        :param appearance `str`: (optional) options are [decomposite|combined|newtype|mlp_basis(N)];
            any value other than "decomposite", "newtype" and "mlp_basis(N)" takes the combined
            path, defaults to "decomposite"
        :param with_layer_norm `bool`: (optional) forwarded as `with_ln` to the sub-mlps, defaults to False
        :param f_chns `int`: (optional) the channels of input "feature" appended to the position
            for the density net, defaults to 0
        """
        super().__init__()
        self.input_f = f_chns > 0
        self.density_chns = density_chns
        self.color_chns = color_chns
        self.specular_feature_chns = color_net["nf"] if specular_net else 0
        self.color_feature_chns = density_net["nf"]
        self.appearance = appearance
        self.density_net = FcBlock(**density_net,
                                   in_chns=x_chns + f_chns,
                                   out_chns=self.density_chns + self.color_feature_chns,
                                   out_act='relu',
                                   with_ln=with_layer_norm)
        if self.appearance == "newtype":
            self.specular_feature_chns = d_chns * 3
            self.color_net = FcBlock(**color_net,
                                     in_chns=x_chns + self.color_feature_chns,
                                     out_chns=self.color_chns + self.specular_feature_chns,
                                     with_ln=with_layer_norm)
            # Truthy marker only: "newtype" has no specular MLP, but `forward` tests
            # `self.specular_net` to decide whether the direction is fed to the color net.
            self.specular_net = "Placeholder"
        else:
            # Raw string: `\(` and `\d` are regex escapes, not string escapes.
            match = re.match(r"mlp_basis\((\d+)\)", self.appearance)
            if match is not None:
                basis_dim = int(match.group(1))
                # NOTE(review): `with_ln` is not forwarded here, unlike the other branches - confirm intended.
                self.color_net = FcBlock(**color_net,
                                         in_chns=x_chns + self.color_feature_chns,
                                         out_chns=self.color_chns * basis_dim)
            elif self.appearance == "decomposite":
                self.color_net = FcBlock(**color_net,
                                         in_chns=x_chns + self.color_feature_chns,
                                         out_chns=self.color_chns + self.specular_feature_chns,
                                         with_ln=with_layer_norm)
            else:
                if specular_net:
                    self.color_net = FcBlock(**color_net,
                                             in_chns=x_chns + self.color_feature_chns,
                                             out_chns=self.specular_feature_chns,
                                             with_ln=with_layer_norm)
                else:
                    # Without a specular net the direction goes straight into the color net.
                    self.color_net = FcBlock(**color_net,
                                             in_chns=x_chns + d_chns + self.color_feature_chns,
                                             out_chns=self.color_chns,
                                             with_ln=with_layer_norm)
            self.specular_net = FcBlock(**specular_net,
                                        in_chns=d_chns + self.specular_feature_chns,
                                        out_chns=self.color_chns,
                                        with_ln=with_layer_norm) if specular_net else None

    def forward(self, inputs: NetInput, *outputs: str, features: torch.Tensor = None, **kwargs) -> NetOutput:
        """
        Evaluate the requested outputs.

        :param inputs `NetInput`: network input; reads `x`, `d`, `shape` and, when the net was
            built with `f_chns > 0`, `f`
        :param outputs `str...`: names of requested outputs, any of "densities", "features",
            "colors", "diffuses", "speculars"
        :param features `Tensor`: (optional) precomputed density-net features; replaced by freshly
            computed ones whenever the density net is run
        :return `NetOutput`: dict holding the requested outputs; "diffuses"/"speculars" are
            omitted when the appearance mode does not produce them
        """
        output_shape = inputs.shape

        ret: NetOutput = {}

        need_colors = 'colors' in outputs or 'speculars' in outputs or 'diffuses' in outputs

        # The density net also has to run when colors are wanted but no features were supplied,
        # because the color net cannot be fed without them.
        if 'densities' in outputs or 'features' in outputs or (need_colors and features is None):
            density_net_in = union(inputs.x, inputs.f) if self.input_f else inputs.x
            density_net_out: torch.Tensor = self.density_net(density_net_in)
            densities, features = split(density_net_out, self.density_chns, -1)
            if 'features' in outputs:
                ret['features'] = features
            if 'densities' in outputs:
                ret['densities'] = densities

        if need_colors:
            if 'densities' in ret:
                # Only evaluate the color nets where the first density channel is non-negligible;
                # the skipped entries are zero-filled by `postprocess` below.
                valid_mask = ret['densities'][..., 0].detach() >= 1e-4
                indices: tuple[torch.Tensor, ...] = valid_mask.nonzero(as_tuple=True)
                inputs, features = inputs[indices], features[indices]
            else:
                indices = None

            color_net_in = [inputs.x, features]
            if not self.specular_net:
                color_net_in.append(inputs.d)
            color_net_out: torch.Tensor = self.color_net(union(*color_net_in))
            diffuses = color_net_out[..., :self.color_chns]
            # Only meaningful when a specular path exists (`specular_feature_chns > 0`).
            specular_features = color_net_out[..., -self.specular_feature_chns:]

            if self.appearance == "newtype":
                # View the specular features as a matrix whose last dim matches the direction
                # channels and multiply it by the direction.
                speculars = torch.matmul(
                    specular_features.reshape(*inputs.shape, -1, inputs.d.shape[-1]),
                    inputs.d[..., None])[..., 0]
                # TODO relu or not?
                diffuses = diffuses.relu()
                speculars = speculars.relu()
                colors = diffuses + speculars
            else:
                if not self.specular_net:
                    colors = diffuses
                    diffuses = None
                    speculars = None
                else:
                    specular_net_out = self.specular_net(union(inputs.d, specular_features))
                    if self.appearance == "decomposite":
                        speculars = specular_net_out
                        colors = diffuses + speculars
                    else:
                        diffuses = None
                        speculars = None
                        colors = specular_net_out
                colors = torch.sigmoid(colors)  # TODO indent or not?

            def postprocess(data: torch.Tensor):
                # Scatter masked results back into a zero tensor of the original sample shape.
                return data.new_zeros(*output_shape, data.shape[-1]).index_put(indices, data)\
                    if indices is not None else data

            if 'colors' in outputs:
                ret['colors'] = postprocess(colors)
            if 'diffuses' in outputs and diffuses is not None:
                ret['diffuses'] = postprocess(diffuses)
            if 'speculars' in outputs and speculars is not None:
                ret['speculars'] = postprocess(speculars)
        return ret


class MultiNerf(Module):
    """
    Stack of per-level nets; the net of level L receives the "features" output of level L-1.
    """

    @property
    def n_levels(self):
        """Number of levels, i.e. number of wrapped nets."""
        return len(self.nets)

    def __init__(self, nets: Iterable[Module]):
        """
        :param nets `Iterable[Module]`: one net per level, ordered from level 0 upwards
        """
        super().__init__()
        # Materialize once: any iterable is accepted, but levels are counted and indexed later.
        self.nets = list(nets)
        for i, net in enumerate(self.nets):
            self.add_module(f"Level {i}", net)

    def set_frozen(self, level: int, on: bool):
        """
        Switch the net of one level between eval mode (`on=True`) and train mode (`on=False`).

        :param level `int`: index of the level to switch
        :param on `bool`: whether the level is frozen
        """
        self.nets[level].train(not on)

    def forward(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
        """
        Evaluate the net of level `samples.level`.

        Level 0 is called directly on `inputs`. Higher levels take `samples.features` as their
        feature input; when it is missing, the features are computed by chaining levels
        0 .. L-1 and stored back on `samples.features` before the target level is evaluated.

        :param inputs `NetInput`: network input; `x` and `d` are forwarded to every level
        :param outputs `str...`: names of requested outputs, passed to the target level's net
        :param samples `Samples`: provides `level` and the optional cached `features`
        :return `NetOutput`: output of the target level's net
        """
        L = samples.level
        if L == 0:
            return self.nets[L](inputs, *outputs)
        if samples.features is not None:
            return self.nets[L](NetInput(inputs.x, inputs.d, samples.features), *outputs)
        features = self.nets[0](inputs, 'features')['features']
        for i in range(1, L):
            features = self.nets[i](NetInput(inputs.x, inputs.d, features), 'features')['features']
        # Cache on `samples` so the re-entrant call below takes the precomputed-features path.
        samples.features = features
        return self(inputs, *outputs, samples=samples)