// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

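// Argument-validation helpers for tensors passed from Python into this CUDA extension.
// Each macro raises a Python-visible error via TORCH_CHECK when its condition fails.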
#define CHECK_CUDA(x)                                                                              \
    do {                                                                                           \
        TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor");                                     \
    } while (0)

#define CHECK_CONTIGUOUS(x)                                                                        \
    do {                                                                                           \
        TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor");                         \
    } while (0)

#define CHECK_IS_INT(x)                                                                            \
    do {                                                                                           \
        TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor");          \
    } while (0)

#define CHECK_IS_FLOAT(x)                                                                          \
    do {                                                                                           \
        TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, #x " must be a float tensor");       \
    } while (0)

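// Composite check: verifies that __VAR__ is contiguous, has the requested scalar type
// (INT or FLOAT, dispatched via token pasting to CHECK_IS_INT / CHECK_IS_FLOAT),
// and lives on a CUDA device.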
#define CHECK_CUDA_CONT_TENSOR(__TYPE__, __VAR__)                                                  \
    do {                                                                                           \
        CHECK_CONTIGUOUS(__VAR__);                                                                 \
        CHECK_IS_##__TYPE__(__VAR__);                                                              \
        CHECK_CUDA(__VAR__);                                                                       \
    } while (0)
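
// Usage sketch (illustrative only; `my_op_forward`, its arguments, and the kernel launch
// are hypothetical and not part of this repository):
//
//     #include "utils.h"
//
//     torch::Tensor my_op_forward(torch::Tensor points, torch::Tensor indices) {
//         CHECK_CUDA_CONT_TENSOR(FLOAT, points);  // contiguous float CUDA tensor
//         CHECK_CUDA_CONT_TENSOR(INT, indices);   // contiguous int CUDA tensor
//         // ... launch the CUDA kernel on points/indices ...
//         return points;
//     }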