// Multi-resolution hash encoding — forward pass kernel plus host-side
// dispatch wrappers (as in Müller et al., "Instant Neural Graphics
// Primitives with a Multiresolution Hash Encoding").
//
// NOTE(review): this file was recovered from a copy whose angle-bracket
// constructs had been stripped (the `#include <...>` header names, every
// `template <...>` parameter list, templated type arguments, and the
// `<<<...>>>` launch syntax were missing). All bracketed constructs below
// are reconstructed from the surrounding usage; the four bracketed header
// names in particular are inferred from the symbols this file uses
// (at::cuda::getCurrentCUDAStream, __half, std::invalid_argument,
// ceil/floor) — confirm them against the project's original build.
#include <ATen/cuda/CUDAContext.h> // at::cuda::getCurrentCUDAStream
#include <cuda_fp16.h>             // __half
#include <cmath>                   // ceil, floor
#include <stdexcept>               // std::invalid_argument
#include "cuda_utils.h"
#include "cutil_math.h" // required for float3 vector math

// Hashes a DIMS-dimensional integer grid coordinate into [0, hashmap_size).
// XOR of per-dimension coordinates multiplied by large primes, reduced
// modulo the table size (spatial hashing for hash-grid lookups).
template <uint DIMS>
__device__ uint fast_hash(const uvec<DIMS> gpos, const uint hashmap_size) {
    static_assert(DIMS <= 7, "fast_hash can only hash up to 7 dimensions.");

    // While 1 is technically not a good prime for hashing (or a prime at all),
    // it helps memory coherence and is sufficient for our use case of obtaining
    // a uniformly colliding index from high-dimensional coordinates.
    constexpr uint primes[7] = {1,          2654435761, 805459861, 3674653429,
                                2097192037, 1434869437, 2165219737};

    uint result = gpos[0];
#pragma unroll
    for (uint dim = 1; dim < DIMS; ++dim)
        result ^= gpos[dim] * primes[dim];
    return result % hashmap_size;
}

// Row-major linear index of grid coordinate `gpos` in a dense grid of
// per-dimension resolution `res`. Used for the densely-stored coarse levels.
template <uint DIMS>
__device__ uint gidx(const uvec<DIMS> gpos, const uvec<DIMS> res) {
    uint index = gpos[0] * res[1] + gpos[1];
#pragma unroll
    for (uint dim = 2; dim < DIMS; ++dim)
        index = index * res[dim] + gpos[dim];
    return index;
}

// Forward pass of the multi-resolution hash encoding.
//
// Launch layout: grid = (ceil(n / blockDim.x), levels, 1); blockIdx.y selects
// the resolution level, so one thread encodes one input point at one level.
//
// Each point x[i] (assumed to lie in [0, 1]^DIMS — the scaling below relies
// on it; TODO confirm against the caller) is scaled to the level's grid, the
// feature vectors of the 2^DIMS surrounding cell corners are fetched (dense
// indexing for levels < coarse_levels, hashed indexing otherwise) and
// N-linearly interpolated into o_encoded[level * n + i].
//
// When requires_grad is set, the interpolation weight and the absolute hash
// table index of every corner are saved into o_weights / o_indices for the
// backward pass.
template <typename T, uint DIMS, uint FEATURE_DIMS>
__global__ void multires_hash_encode_kernel(
    const uint n, const uint coarse_levels, const uvec<DIMS> *__restrict__ res_list,
    const vec<T, FEATURE_DIMS> *__restrict__ hash_table,
    const uint *__restrict__ hash_table_offsets, const fvec<DIMS> *__restrict__ x,
    vec<T, FEATURE_DIMS> *__restrict__ o_encoded, const bool requires_grad,
    float *__restrict__ o_weights, uint *__restrict__ o_indices) {
    const uint i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= n)
        return;

    const uint level = blockIdx.y;
    const uint hash_table_offset = hash_table_offsets[level];
    const uint hash_table_size = hash_table_offsets[level + 1] - hash_table_offset;
    const uvec<DIMS> res = res_list[level];
    hash_table += hash_table_offset; // jump to this level's sub-table

    // Split the scaled position into an integer cell coordinate (gpos) and
    // the fractional offset within the cell (pos).
    fvec<DIMS> pos = x[i];
    uvec<DIMS> gpos;
#pragma unroll
    for (uint dim = 0; dim < DIMS; ++dim) {
        pos[dim] *= res[dim] - 1;
        gpos[dim] = (uint)floor(pos[dim]);
        pos[dim] -= gpos[dim];
    }

    // Coarse levels fit densely in their sub-table; finer levels are hashed.
    auto hash_idx = [&](const uvec<DIMS> gpos) {
        uint idx;
        if (level >= coarse_levels)
            idx = fast_hash(gpos, hash_table_size);
        else
            idx = gidx(gpos, res);
        return idx;
    };

    // N-linear interpolation over the 2^DIMS cell corners.
    vec<T, FEATURE_DIMS> result = {};
    auto n_corners = (1 << DIMS);
#pragma unroll
    for (uint corner_idx = 0; corner_idx < n_corners; ++corner_idx) {
        float weight = 1;
        uvec<DIMS> corner_gpos;
#pragma unroll
        for (uint dim = 0; dim < DIMS; ++dim) {
            // Bit `dim` of corner_idx selects the low/high corner on this axis;
            // the interpolation weight is the matching product of (1-pos)/pos.
            if ((corner_idx & (1 << dim)) == 0) {
                weight *= 1 - pos[dim];
                corner_gpos[dim] = gpos[dim];
            } else {
                weight *= pos[dim];
                corner_gpos[dim] = gpos[dim] + 1;
            }
        }

        auto idx = hash_idx(corner_gpos);
        auto val = hash_table[idx];
#pragma unroll
        for (uint feature = 0; feature < FEATURE_DIMS; ++feature) {
            // Accumulate in float, then cast back to the storage type T.
            result[feature] += (T)(weight * (float)val[feature]);
        }

        // For backward
        if (requires_grad) {
            auto j = (level * n + i) * n_corners + corner_idx;
            o_indices[j] = idx + hash_table_offset; // absolute table index
            o_weights[j] = weight;
        }
    }
    o_encoded[level * n + i] = result;
}

// Dispatches on the spatial dimensionality (2 or 3) and launches the kernel
// on the current torch CUDA stream, one grid row (blockIdx.y) per level.
// Throws std::invalid_argument for unsupported `dims`.
template <typename T, uint FEATURE_DIMS>
void multires_hash_encode_kernel_wrapper(const uint n, const uint levels,
                                         const uint coarse_levels, const uint dims,
                                         const float *x, const uint *res_list,
                                         const vec<T, FEATURE_DIMS> *hash_table,
                                         const uint *hash_table_offsets,
                                         vec<T, FEATURE_DIMS> *o_encoded,
                                         const bool requires_grad, float *o_weights,
                                         uint *o_indices) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    const uint threads = opt_n_threads(n);
    const dim3 blocks = {(uint)ceil((float)n / threads), levels, 1};

#define DISPATCH_KERNEL_CASE(__DIMS__)                                                            \
    case __DIMS__:                                                                                \
        multires_hash_encode_kernel<T, __DIMS__, FEATURE_DIMS><<<blocks, threads, 0, stream>>>(   \
            n, coarse_levels, (const uvec<__DIMS__> *)res_list, hash_table, hash_table_offsets,   \
            (const fvec<__DIMS__> *)x, o_encoded, requires_grad, o_weights, o_indices);           \
        break;

    switch (dims) {
        DISPATCH_KERNEL_CASE(2)
        DISPATCH_KERNEL_CASE(3)
    default:
        throw std::invalid_argument("'dims' should be 2 or 3");
    }
    CUDA_CHECK_ERRORS();

#undef DISPATCH_KERNEL_CASE
}

// Dispatches on the compile-time feature width (1/2/4/8/16), reinterpreting
// the type-erased table/output pointers as the matching vec<T, F> arrays.
// Throws std::invalid_argument for unsupported `feature_dims`.
template <typename T>
void multires_hash_encode_kernel_wrapper(const uint n, const uint levels,
                                         const uint coarse_levels, const uint dims,
                                         const uint feature_dims, const float *x,
                                         const uint *res_list, const void *hash_table,
                                         const uint *hash_table_offsets, void *o_encoded,
                                         const bool requires_grad, float *o_weights,
                                         uint *o_indices) {
#define KERNEL_WRAPPER_CASE(__FEATURE_DIMS__)                                                     \
    case __FEATURE_DIMS__:                                                                        \
        multires_hash_encode_kernel_wrapper<T, __FEATURE_DIMS__>(                                 \
            n, levels, coarse_levels, dims, x, res_list,                                          \
            (const vec<T, __FEATURE_DIMS__> *)hash_table, hash_table_offsets,                     \
            (vec<T, __FEATURE_DIMS__> *)o_encoded, requires_grad, o_weights, o_indices);          \
        break;

    switch (feature_dims) {
        KERNEL_WRAPPER_CASE(1)
        KERNEL_WRAPPER_CASE(2)
        KERNEL_WRAPPER_CASE(4)
        KERNEL_WRAPPER_CASE(8)
        KERNEL_WRAPPER_CASE(16)
    default:
        throw std::invalid_argument("'feature_dims' should be 1, 2, 4, 8, 16");
    }

#undef KERNEL_WRAPPER_CASE
}

#if !defined(__CUDA_NO_HALF_CONVERSIONS__)
// Half-precision entry point: hash table and encoded output stored as __half.
void multires_hash_encode_kernel_wrapper_halfp(const uint n, const uint levels,
                                               const uint coarse_levels, const uint dims,
                                               const uint feature_dims, const float *x,
                                               const uint *res_list, const void *hash_table,
                                               const uint *hash_table_offsets, void *o_encoded,
                                               const bool requires_grad, float *o_weights,
                                               uint *o_indices) {
    multires_hash_encode_kernel_wrapper<__half>(n, levels, coarse_levels, dims, feature_dims, x,
                                                res_list, hash_table, hash_table_offsets,
                                                o_encoded, requires_grad, o_weights, o_indices);
}
#endif

// Full-precision entry point: hash table and encoded output stored as float.
void multires_hash_encode_kernel_wrapper_fullp(const uint n, const uint levels,
                                               const uint coarse_levels, const uint dims,
                                               const uint feature_dims, const float *x,
                                               const uint *res_list, const void *hash_table,
                                               const uint *hash_table_offsets, void *o_encoded,
                                               const bool requires_grad, float *o_weights,
                                               uint *o_indices) {
    multires_hash_encode_kernel_wrapper<float>(n, levels, coarse_levels, dims, feature_dims, x,
                                               res_list, hash_table, hash_table_offsets,
                                               o_encoded, requires_grad, o_weights, o_indices);
}