// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cmath>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
// log2 of the largest thread count opt_n_threads() hands out per dimension.
#define LOG2_TOTAL_THREADS 10
// Upper bound on the total number of threads in one block built by
// opt_block_config(). CUDA caps a block at 1024 threads, so this must equal
// 1 << LOG2_TOTAL_THREADS; 2 << LOG2_TOTAL_THREADS (2048) would let
// opt_block_config() return e.g. a 1024 x 2 block, which fails to launch.
#define TOTAL_THREADS (1 << LOG2_TOTAL_THREADS)

// Returns the largest power of two that is <= work_size, clamped to the range
// [1, 1 << LOG2_TOTAL_THREADS]. A work_size of 0 yields 1.
inline uint opt_n_threads(uint work_size) {
    // Integer floor(log2(work_size)), stopping at LOG2_TOTAL_THREADS. This
    // avoids std::log(0) == -inf (whose conversion to uint is undefined) and
    // does not depend on log(x) / log(2) rounding to an exact integer.
    uint pow_2 = 0;
    while (pow_2 < static_cast<uint>(LOG2_TOTAL_THREADS) && (work_size >> (pow_2 + 1u)) != 0u)
        ++pow_2;
    return 1u << pow_2;
}

// Picks a (x_threads, y_threads, 1) block shape for a 2-D workload of x by y
// elements. Each dimension gets a power-of-two thread count from
// opt_n_threads(); y is then reduced so that x_threads * y_threads never
// exceeds TOTAL_THREADS. Note: x and y are converted to uint, so negative
// sizes behave like very large ones.
inline dim3 opt_block_config(int x, int y) {
    const uint x_threads = opt_n_threads(x);
    // Threads left over for the y dimension; >= 1 because x_threads <= TOTAL_THREADS.
    const uint y_budget = TOTAL_THREADS / x_threads;
    uint y_threads = opt_n_threads(y);
    if (y_threads > y_budget) y_threads = y_budget;
    if (y_threads < 1u) y_threads = 1u;
    return dim3(x_threads, y_threads, 1);
}

// Checks (and clears) the last error recorded by the CUDA runtime via
// cudaGetLastError(). If an error is pending, prints its description together
// with the enclosing function, line and file to stderr, then terminates the
// process with exit(-1). It does not synchronize, so faults raised inside a
// still-running kernel may only surface at a later synchronizing call.
#define CUDA_CHECK_ERRORS()                                                                        \
    do {                                                                                           \
        const cudaError_t cuda_check_status_ = cudaGetLastError();                                 \
        if (cuda_check_status_ != cudaSuccess) {                                                   \
            fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n",                         \
                    cudaGetErrorString(cuda_check_status_), __PRETTY_FUNCTION__, __LINE__,         \
                    __FILE__);                                                                     \
            exit(-1);                                                                              \
        }                                                                                          \
    } while (0)

#endif

// This section lies after the header's main include-guard #endif, so it has
// its own guard; without one, including this header twice in a translation
// unit would redefine `vec` and its aliases.
#ifndef CUDA_UTILS_VEC_H_
#define CUDA_UTILS_VEC_H_

// Fixed-size array of N_ELEMS elements of type T, usable from host and device
// code. It is a plain aggregate (no constructors), so it can be
// brace-initialised, e.g. fvec<3>{{x, y, z}}. Indexing is not bounds-checked.
template <typename T, uint N_ELEMS> struct vec {
    // Mutable element access; idx must be < N_ELEMS.
    __host__ __device__ T &operator[](uint idx) { return data[idx]; }

    // Read-only element access (returns a copy); idx must be < N_ELEMS.
    __host__ __device__ T operator[](uint idx) const { return data[idx]; }

    T data[N_ELEMS];
    // Element count as a compile-time constant.
    static constexpr uint N = N_ELEMS;
};

// Shorthand for float, half-precision (__half) and unsigned-int vectors.
template <uint N_FLOATS> using fvec = vec<float, N_FLOATS>;

template <uint N_HALFS> using hvec = vec<__half, N_HALFS>;

template <uint N_UINTS> using uvec = vec<uint, N_UINTS>;

#endif // CUDA_UTILS_VEC_H_