Encoder.cu 1.65 KB
Newer Older
Nianchen Deng's avatar
Nianchen Deng committed
1
#include "Encoder.h"
Nianchen Deng's avatar
sync    
Nianchen Deng committed
2
#include "../utils/cuda.h"
Nianchen Deng's avatar
Nianchen Deng committed
3

Nianchen Deng's avatar
sync    
Nianchen Deng committed
4
5
/// idx3.z = 0: x, y, z, sin(x), sin(y), sin(z), cos(x), cos(y), cos(z)
/// idx3.z = 1: sin(2x), sin(2y), sin(2z), cos(2x), cos(2y), cos(2z)
Nianchen Deng's avatar
Nianchen Deng committed
6
/// ...
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
/// idx3.z = n_freq-1: sin(2^(n_freq-1)x), sin(2^(n_freq-1)y), sin(2^(n_freq-1)z),
Nianchen Deng's avatar
Nianchen Deng committed
8
///                    cos(2^(n_freq-1)x), cos(2^(n_freq-1)y), cos(2^(n_freq-1)z)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
9
/// Dispatch (n_batch, n_chns, n_freqs)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
10
/// @param output  device buffer written as one row per point; each row holds
///                inChns raw values followed, for every frequency, by inChns
///                sines and then inChns cosines (inChns * (2 * nFreqs + 1) floats)
/// @param input   device buffer read as one row of inChns floats per point
/// @param freqs   device buffer with one multiplier per frequency index (idx3.z)
/// @param n       number of points; threads whose idx3.x is >= n exit immediately
__global__ void cu_encode(float *output, float *input, float *freqs, uint n) {
    // NOTE(review): IDX3 comes from ../utils/cuda.h; presumably it is the global
    // 3D thread index (blockIdx * blockDim + threadIdx) -- confirm.
    glm::uvec3 idx3 = IDX3;
    if (idx3.x >= n)
        return;

    // The block shape doubles as the problem shape: the y extent is the number
    // of input channels and the z extent is the number of frequencies.
    uint inChns = blockDim.y, nFreqs = blockDim.z;
    uint i = idx3.x, chn = idx3.y, freq = idx3.z;
    uint outChns = inChns * (nFreqs * 2 + 1);

    // Row offsets are widened to size_t so that i * outChns cannot wrap around
    // 32 bits when the batch is large.
    size_t elem = (size_t)i * inChns + chn;
    size_t base = (size_t)i * outChns + chn;

    // The raw value has to be copied once per (point, channel) pair, so it is
    // done by the thread handling frequency 0 of that pair. Guarding on the
    // point index instead would copy it only for the very first point and
    // leave the identity slot of every other row unwritten.
    if (freq == 0)
        output[base] = input[elem];

    float x = freqs[freq] * input[elem];
    float s, c;
    // Fast-math intrinsic: trades accuracy for speed.
    __sincosf(x, &s, &c);
    output[base + inChns * (freq * 2 + 1)] = s;
    output[base + inChns * (freq * 2 + 2)] = c;
}

Nianchen Deng's avatar
sync    
Nianchen Deng committed
28
void Encoder::encode(sptr<CudaArray<float>> output, sptr<CudaArray<float>> input) {
    // cu_encode reads the channel count from blockDim.y and the frequency count
    // from blockDim.z, so those two extents must be exactly _chns and _multires;
    // the x extent takes what is left of a 1024-thread budget.
    // NOTE(review): blkSize / grdSize are not referenced by name below, so
    // CU_INVOKE presumably expands to a launch that uses them -- keep the names.
    dim3 blkSize(1024 / _chns / _multires, _chns, _multires);

    // Integer ceil-division. Going through float, as ceil(n / (float)blkSize.x)
    // would, cannot represent every count above 2^24 and may round the block
    // count down, leaving the last points unprocessed.
    // When blkSize.x is 0 (_chns * _multires > 1024) the division is skipped to
    // avoid dividing by zero on the host; the launch itself is then invalid and
    // is expected to surface through the cudaGetLastError() check below.
    uint nBlocks = 1;
    if (blkSize.x != 0)
        nBlocks = (uint)(((size_t)input->n() + blkSize.x - 1) / blkSize.x);
    dim3 grdSize(nBlocks, 1, 1);

    // NOTE(review): the kernel treats its last argument as the number of points
    // (rows of _chns floats) -- confirm that this is what input->n() reports.
    CU_INVOKE(cu_encode)(output->getBuffer<float>(), *input, *_freqs, input->n());
    CHECK_EX(cudaGetLastError());
}

Nianchen Deng's avatar
sync    
Nianchen Deng committed
35
/// Builds the frequency table 1, 2, 4, ..., 2^(_multires - 1) on the host,
/// uploads it to a freshly allocated CudaArray and stores that in _freqs.
void Encoder::_genFreqArray() {
    // Allocate the device-side array first: if that fails there is no host
    // staging buffer yet that could leak.
    sptr<CudaArray<float>> freqs(new CudaArray<float>(_multires));

    float *arr = new float[_multires];
    // Filling from index 0 with a running power of two yields the same values
    // as seeding arr[0] = 1 and doubling, but writes nothing when _multires is 0.
    float f = 1.0f;
    for (auto i = 0; i < _multires; ++i) {
        arr[i] = f;
        f *= 2.0f;
    }

    // Capture the status and release the staging buffer before checking it,
    // so the buffer is freed even if CHECK_EX raises on failure.
    cudaError_t err =
        cudaMemcpy(freqs->getBuffer(), arr, _multires * sizeof(float), cudaMemcpyHostToDevice);
    delete[] arr;
    CHECK_EX(err);

    // Publish only a fully initialised table.
    _freqs = freqs;
}