#include "utils.h"

// Host-side launcher for the full-precision (float) multi-resolution hash
// encoding kernel. Only declared here; the definition lives in another
// translation unit (presumably the .cu file that holds the kernel).
// When requires_grad is false, o_weights / o_indices are passed as nullptr.
void multires_hash_encode_kernel_wrapper_fullp(const uint n, const uint levels, const uint coarse_levels,
                                               const uint dims, const uint feature_dims, const float *x,
                                               const uint *res_list, const void *hash_table,
                                               const uint *hash_table_offsets, void *o_encoded,
                                               const bool requires_grad, float *o_weights,
                                               uint *o_indices);

namespace {

// Validates the arguments shared by both entry points before any raw pointer
// is handed to the kernel wrapper.
void check_multires_hash_encode_args(const int levels, const int coarse_levels, const at::Tensor &x,
                                     const at::Tensor &res_list, const at::Tensor &hash_table,
                                     const at::Tensor &hash_table_offsets) {
    CHECK_CUDA_CONT_TENSOR(FLOAT, x);
    CHECK_CUDA_CONT_TENSOR(FLOAT, hash_table);
    CHECK_CUDA_CONT_TENSOR(INT, res_list);
    CHECK_CUDA_CONT_TENSOR(INT, hash_table_offsets);
    // Both counts are cast to uint for the kernel; a negative value would wrap
    // to a huge count instead of failing loudly.
    TORCH_CHECK(levels >= 0, "multires_hash_encode: levels must be non-negative, got ", levels);
    TORCH_CHECK(coarse_levels >= 0, "multires_hash_encode: coarse_levels must be non-negative, got ",
                coarse_levels);
    // x is consumed as a dense (n, dims) array: n = size(0), dims = size(1).
    TORCH_CHECK(x.dim() == 2, "multires_hash_encode: x must be a 2-D tensor of shape (n, dims), got ",
                x.dim(), " dims");
    // All raw pointers are dereferenced by a single kernel launch, so every
    // input must live on the same device as x (where the outputs are allocated).
    TORCH_CHECK(res_list.device() == x.device() && hash_table.device() == x.device() &&
                    hash_table_offsets.device() == x.device(),
                "multires_hash_encode: x, res_list, hash_table and hash_table_offsets must be on the same device");
}

}  // namespace

// Multi-resolution hash encoding, forward only.
//
// x                  float32 tensor of shape (n, dims).
// res_list           int32 tensor forwarded to the kernel as const uint*
//                    (presumably per-level grid resolutions — confirm against kernel).
// hash_table         float32 tensor whose last dimension is feature_dims.
// hash_table_offsets int32 tensor forwarded to the kernel as const uint*.
//
// Returns `encoded` of shape (levels, n, feature_dims), with hash_table's
// dtype, allocated on x's device.
at::Tensor multires_hash_encode(const int levels, const int coarse_levels, at::Tensor x, at::Tensor res_list,
                                at::Tensor hash_table, at::Tensor hash_table_offsets) {
    check_multires_hash_encode_args(levels, coarse_levels, x, res_list, hash_table, hash_table_offsets);

    const uint n = x.size(0);
    const uint dims = x.size(1);
    const uint feature_dims = hash_table.size(-1);

    at::Tensor encoded =
        torch::empty({levels, n, feature_dims}, at::device(x.device()).dtype(hash_table.dtype()));

    multires_hash_encode_kernel_wrapper_fullp(
        n, static_cast<uint>(levels), static_cast<uint>(coarse_levels), dims, feature_dims,
        x.data_ptr<float>(), static_cast<const uint *>(res_list.data_ptr()), hash_table.data_ptr(),
        static_cast<const uint *>(hash_table_offsets.data_ptr()), encoded.data_ptr(),
        /*requires_grad=*/false, /*o_weights=*/nullptr, /*o_indices=*/nullptr);
    return encoded;
}

// Multi-resolution hash encoding that additionally records the per-point data
// the kernel emits when requires_grad is true.
//
// Takes the same arguments as multires_hash_encode.
// Returns (encoded, weights, indices):
//   encoded  (levels, n, feature_dims), hash_table's dtype
//   weights  (levels, n, 2^dims), float32
//   indices  (levels, n, 2^dims), int32 (written by the kernel as uint)
// All three are allocated on x's device. weights/indices are presumably the
// per-cell-vertex interpolation weights and hash-table indices consumed by a
// backward pass — confirm against the kernel.
std::tuple<at::Tensor, at::Tensor, at::Tensor> multires_hash_encode_with_grad(const int levels,
                                                                              const int coarse_levels,
                                                                              at::Tensor x, at::Tensor res_list,
                                                                              at::Tensor hash_table,
                                                                              at::Tensor hash_table_offsets) {
    check_multires_hash_encode_args(levels, coarse_levels, x, res_list, hash_table, hash_table_offsets);

    const uint n = x.size(0);
    const uint dims = x.size(1);
    const uint feature_dims = hash_table.size(-1);

    // 2^dims slots per point and level. Computed as a 64-bit shift: the former
    // `1 << dims` was an int shift, undefined behaviour for dims >= 31.
    TORCH_CHECK(dims < 63, "multires_hash_encode_with_grad: dims is too large (", dims, ")");
    const int64_t corners = int64_t{1} << dims;

    at::Tensor encoded =
        torch::empty({levels, n, feature_dims}, at::device(x.device()).dtype(hash_table.dtype()));
    at::Tensor weights = torch::empty({levels, n, corners}, at::device(x.device()).dtype(at::kFloat));
    at::Tensor indices = torch::empty({levels, n, corners}, at::device(x.device()).dtype(at::kInt));

    multires_hash_encode_kernel_wrapper_fullp(
        n, static_cast<uint>(levels), static_cast<uint>(coarse_levels), dims, feature_dims,
        x.data_ptr<float>(), static_cast<const uint *>(res_list.data_ptr()), hash_table.data_ptr(),
        static_cast<const uint *>(hash_table_offsets.data_ptr()), encoded.data_ptr(),
        /*requires_grad=*/true, weights.data_ptr<float>(), static_cast<uint *>(indices.data_ptr()));
    return std::make_tuple(encoded, weights, indices);
}