#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+ #define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
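These macros reinterpret a scalar lvalue as a wider vector type, so one instruction moves 2, 4, or 8 elements at once. As a minimal sketch of what an LDST128BITS copy amounts to (the helper name copy128 is illustrative, not part of this patch):

// Hypothetical helper: move 8 halves (128 bits) with one load and one store.
// Both pointers must be 16-byte aligned, or the access is undefined behavior.
__device__ __forceinline__ void copy128(half* dst, const half* src) {
  *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
}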
// -------------------------------------- FP32 --------------------------------------
// ElementWise Add
@@ -95,6 +96,23 @@ __global__ void elementwise_add_f16x8_kernel(half* a, half* b, half* c, int N) {
  if ((idx + 6) < N) { HALF2(c[idx + 6]) = reg_c_3; }
}

+ __global__ void elementwise_add_f16x8_pack_kernel(half* a, half* b, half* c, int N) {
+   int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+   // temporary storage: registers, or addressable .local space in PTX if spilled
+   half pack_a[8], pack_b[8], pack_c[8]; // 8 x 16 bits = 128 bits
+   // reinterpret as float4 and load 128 bits in a single memory transaction
+   LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
+   LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits
+
+   #pragma unroll
+   for (int i = 0; i < 8; i += 2) {
+     // four __hadd2 ops cover all 8 halves, two at a time
+     HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
+   }
+   // reinterpret as float4 and store 128 bits in a single memory transaction;
+   // note: if N is not a multiple of 8, the trailing N % 8 elements are skipped
+   if ((idx + 7) < N) { LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]); }
+ }
+
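A minimal host-side launch sketch for the packed kernel (the element count, block size, and lack of error handling are illustrative assumptions, not part of this patch; it assumes N is a multiple of 8 so the `(idx + 7) < N` guard covers whole packs, and relies on cudaMalloc's alignment guarantee for the 128-bit accesses):

#include <cuda_fp16.h>

int main() {
  const int N = 1 << 20;             // hypothetical element count, multiple of 8
  half *a, *b, *c;
  cudaMalloc(&a, N * sizeof(half));  // cudaMalloc memory is at least 256-byte aligned
  cudaMalloc(&b, N * sizeof(half));
  cudaMalloc(&c, N * sizeof(half));
  // each thread handles 8 halves, so size the grid over N / 8 threads
  dim3 block(256);
  dim3 grid((N / 8 + block.x - 1) / block.x);
  elementwise_add_f16x8_pack_kernel<<<grid, block>>>(a, b, c, N);
  cudaDeviceSynchronize();
  cudaFree(a); cudaFree(b); cudaFree(c);
  return 0;
}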
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
@@ -107,70 +125,59 @@ if(((T).options().dtype() != (th_type))) { \
  throw std::runtime_error("values must be " #th_type); \
}

- #define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
- if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }
-
#define TORCH_BINDING_ELEM_ADD(packed_type, th_type, element_type, n_elements) \
- torch::Tensor elementwise_add_##packed_type(torch::Tensor a, torch::Tensor b) { \
-   CHECK_TORCH_TENSOR_DTYPE(a, (th_type)) \
-   CHECK_TORCH_TENSOR_DTYPE(b, (th_type)) \
-   auto options = torch::TensorOptions().dtype((th_type)).device( \
-     torch::kCUDA, 0); \
-   const int N = a.size(0); \
-   CHECK_TORCH_TENSOR_SHAPE(b, N) \
-   auto c = torch::zeros({N}, options); \
-   static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-   const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-   dim3 block(NUM_THREADS_PER_BLOCK); \
-   dim3 grid(NUM_BLOCKS); \
-   elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
-     reinterpret_cast<element_type*>(a.data_ptr()), \
-     reinterpret_cast<element_type*>(b.data_ptr()), \
-     reinterpret_cast<element_type*>(c.data_ptr()), N); \
-   return c; \
- }
-
- #define TORCH_BINDING_ELEM_ADD_V2(packed_type, th_type, element_type, n_elements) \
- void elementwise_add_##packed_type##_v2( \
+ void elementwise_add_##packed_type( \
    torch::Tensor a, torch::Tensor b, torch::Tensor c) { \
    CHECK_TORCH_TENSOR_DTYPE(a, (th_type)) \
    CHECK_TORCH_TENSOR_DTYPE(b, (th_type)) \
    CHECK_TORCH_TENSOR_DTYPE(c, (th_type)) \
-   const int N = a.size(0); \
-   CHECK_TORCH_TENSOR_SHAPE(b, N) \
-   CHECK_TORCH_TENSOR_SHAPE(c, N) \
-   static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-   const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-   dim3 block(NUM_THREADS_PER_BLOCK); \
-   dim3 grid(NUM_BLOCKS); \
-   elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
+   const int ndim = a.dim(); \
+   if (ndim != 2) { \
+     int N = 1; \
+     for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+     dim3 block(256 / (n_elements)); \
+     dim3 grid((N + 256 - 1) / 256); \
+     elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
      reinterpret_cast<element_type*>(a.data_ptr()), \
      reinterpret_cast<element_type*>(b.data_ptr()), \
      reinterpret_cast<element_type*>(c.data_ptr()), N); \
+   } else { \
+     const int S = a.size(0); \
+     const int K = a.size(1); \
+     const int N = S * K; \
+     if ((K / (n_elements)) <= 1024) { \
+       dim3 block(K / (n_elements)); \
+       dim3 grid(S); \
+       elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
+         reinterpret_cast<element_type*>(a.data_ptr()), \
+         reinterpret_cast<element_type*>(b.data_ptr()), \
+         reinterpret_cast<element_type*>(c.data_ptr()), N); \
+     } else { \
+       int N = 1; \
+       for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+       dim3 block(256 / (n_elements)); \
+       dim3 grid((N + 256 - 1) / 256); \
+       elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
+         reinterpret_cast<element_type*>(a.data_ptr()), \
+         reinterpret_cast<element_type*>(b.data_ptr()), \
+         reinterpret_cast<element_type*>(c.data_ptr()), N); \
+     } \
+   } \
}
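For intuition about the launch selection in the 2-D branch, a standalone sketch with a hypothetical shape (the numbers are illustrative, not from this patch): for a (S, K) = (4096, 2048) half tensor with the f16x8 kernel, K / (n_elements) = 2048 / 8 = 256 <= 1024, so each row gets its own 256-thread block and the grid is S blocks.

#include <cstdio>

// Minimal standalone check of the 2-D launch arithmetic (hypothetical shape).
int main() {
  const int S = 4096, K = 2048, n_elements = 8; // f16x8: 8 halves per thread
  if (K / n_elements <= 1024) {                 // fits the 1024-thread block limit
    printf("block = %d threads, grid = %d blocks (one per row)\n",
           K / n_elements, S);                  // block = 256, grid = 4096
  }
  return 0;
}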
- TORCH_BINDING_ELEM_ADD(f32,   torch::kFloat32, float, 1)
- TORCH_BINDING_ELEM_ADD(f32x4, torch::kFloat32, float, 4)
- TORCH_BINDING_ELEM_ADD(f16,   torch::kHalf,    half,  1)
- TORCH_BINDING_ELEM_ADD(f16x2, torch::kHalf,    half,  2)
- TORCH_BINDING_ELEM_ADD(f16x8, torch::kHalf,    half,  8)
- // v2: no copy of c Tensor
- TORCH_BINDING_ELEM_ADD_V2(f32,   torch::kFloat32, float, 1)
- TORCH_BINDING_ELEM_ADD_V2(f32x4, torch::kFloat32, float, 4)
- TORCH_BINDING_ELEM_ADD_V2(f16,   torch::kHalf,    half,  1)
- TORCH_BINDING_ELEM_ADD_V2(f16x2, torch::kHalf,    half,  2)
- TORCH_BINDING_ELEM_ADD_V2(f16x8, torch::kHalf,    half,  8)
+ TORCH_BINDING_ELEM_ADD(f32,        torch::kFloat32, float, 1)
+ TORCH_BINDING_ELEM_ADD(f32x4,      torch::kFloat32, float, 4)
+ TORCH_BINDING_ELEM_ADD(f16,        torch::kHalf,    half,  1)
+ TORCH_BINDING_ELEM_ADD(f16x2,      torch::kHalf,    half,  2)
+ TORCH_BINDING_ELEM_ADD(f16x8,      torch::kHalf,    half,  8)
+ TORCH_BINDING_ELEM_ADD(f16x8_pack, torch::kHalf,    half,  8)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32x4)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x2)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8)
- TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32_v2)
- TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32x4_v2)
- TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16_v2)
- TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x2_v2)
- TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8_v2)
+ TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8_pack)
}