|
2 | 2 |
|
3 | 3 | #include <torch/extension.h> |
4 | 4 |
|
| 5 | +#include "compat.h" |
| 6 | + |
5 | 7 | #define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \ |
6 | 8 | [&] { \ |
7 | | - TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \ |
| 9 | + TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \ |
8 | 10 | auto TENSOR1##_size = TENSOR1.size(DIM); \ |
9 | 11 | auto TENSOR1##_stride = TENSOR1.stride(DIM); \ |
10 | 12 | \ |
11 | | - TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \ |
| 13 | + TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \ |
12 | 14 | auto TENSOR2##_size = TENSOR2.size(DIM); \ |
13 | 15 | auto TENSOR2##_stride = TENSOR2.stride(DIM); \ |
14 | 16 | \ |
15 | | - TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \ |
| 17 | + TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \ |
16 | 18 | auto TENSOR3##_size = TENSOR3.size(DIM); \ |
17 | 19 | auto TENSOR3##_stride = TENSOR3.stride(DIM); \ |
18 | 20 | \ |
19 | 21 | auto dims = TENSOR1.dim(); \ |
20 | 22 | auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \ |
21 | | - auto counter = zeros.data<int64_t>(); \ |
| 23 | + auto counter = zeros.DATA_PTR<int64_t>(); \ |
22 | 24 | bool has_finished = false; \ |
23 | 25 | \ |
24 | 26 | while (!has_finished) { \ |
|
59 | 61 | #define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \ |
60 | 62 | TENSOR4, DIM, CODE) \ |
61 | 63 | [&] { \ |
62 | | - TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \ |
| 64 | + TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \ |
63 | 65 | auto TENSOR1##_size = TENSOR1.size(DIM); \ |
64 | 66 | auto TENSOR1##_stride = TENSOR1.stride(DIM); \ |
65 | 67 | \ |
66 | | - TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \ |
| 68 | + TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \ |
67 | 69 | auto TENSOR2##_size = TENSOR2.size(DIM); \ |
68 | 70 | auto TENSOR2##_stride = TENSOR2.stride(DIM); \ |
69 | 71 | \ |
70 | | - TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \ |
| 72 | + TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \ |
71 | 73 | auto TENSOR3##_size = TENSOR3.size(DIM); \ |
72 | 74 | auto TENSOR3##_stride = TENSOR3.stride(DIM); \ |
73 | 75 | \ |
74 | | - TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>(); \ |
| 76 | + TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR<TYPE4>(); \ |
75 | 77 | auto TENSOR4##_size = TENSOR4.size(DIM); \ |
76 | 78 | auto TENSOR4##_stride = TENSOR4.stride(DIM); \ |
77 | 79 | \ |
78 | 80 | auto dims = TENSOR1.dim(); \ |
79 | 81 | auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \ |
80 | | - auto counter = zeros.data<int64_t>(); \ |
| 82 | + auto counter = zeros.DATA_PTR<int64_t>(); \ |
81 | 83 | bool has_finished = false; \ |
82 | 84 | \ |
83 | 85 | while (!has_finished) { \ |
|
0 commit comments