 #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
 #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
 
 // -------------------------------------- FP32 --------------------------------------
 // Relu x: N, y: N, y = max(0, x)
@@ -81,6 +82,24 @@ __global__ void relu_f16x8_kernel(half* x, half* y, int N) {
   if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
 }
 
+__global__ void relu_f16x8_pack_kernel(half* x, half* y, int N) {
+  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+  const half2 z2 = {__float2half(0.0f), __float2half(0.0f)};
+  // temporary per-thread storage (registers, or addressable .local space in PTX)
+  half pack_x[8], pack_y[8]; // 8 x 16 bits = 128 bits
+  // reinterpret as float4 and load 128 bits in one memory issue.
+  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
+
+  #pragma unroll
+  for (int i = 0; i < 8; i += 2) {
+    // __hmax2 applied to each of the 4 half2 pairs
+    HALF2(pack_y[i]) = __hmax2(HALF2(pack_x[i]), z2);
+  }
+  // reinterpret as float4 and store 128 bits in one memory issue.
+  if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+}
+
+
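Aside (not part of the patch): a minimal host-side launch sketch for the packed kernel, assuming it sits in the same .cu file as the kernel above and that N is a multiple of 8, since each thread moves 8 half values per 128-bit load/store. The helper name launch_relu_f16x8_pack is hypothetical.

// hypothetical launcher, for illustration only
void launch_relu_f16x8_pack(half* d_x, half* d_y, int N) {
  const int threads = 256;                              // threads per block
  const int blocks  = (N / 8 + threads - 1) / threads;  // one thread covers 8 halves
  relu_f16x8_pack_kernel<<<blocks, threads>>>(d_x, d_y, N);
}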
 
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -92,61 +111,54 @@ if(((T).options().dtype() != (th_type))) { \
   throw std::runtime_error("values must be " #th_type); \
 }
 
-#define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
-if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }
-
 #define TORCH_BINDING_RELU(packed_type, th_type, element_type, n_elements) \
-torch::Tensor relu_##packed_type(torch::Tensor x) { \
-  CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
-  auto options = torch::TensorOptions().dtype((th_type)).device( \
-    torch::kCUDA, 0); \
-  const int N = x.size(0); \
-  auto y = torch::zeros({N}, options); \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  relu_##packed_type##_kernel<<<grid, block>>>( \
-    reinterpret_cast<element_type*>(x.data_ptr()), \
-    reinterpret_cast<element_type*>(y.data_ptr()), N); \
-  return y; \
-}
-
-#define TORCH_BINDING_RELU_V2(packed_type, th_type, element_type, n_elements) \
-void relu_##packed_type##_v2(torch::Tensor x, torch::Tensor y) { \
+void relu_##packed_type(torch::Tensor x, torch::Tensor y) { \
   CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
   CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \
-  const int N = x.size(0); \
-  CHECK_TORCH_TENSOR_SHAPE(y, N) \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  relu_##packed_type##_kernel<<<grid, block>>>( \
+  const int ndim = x.dim(); \
+  if (ndim != 2) { \
+    int N = 1; \
+    for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \
+    dim3 block(256 / (n_elements)); \
+    dim3 grid((N + 256 - 1) / 256); \
+    relu_##packed_type##_kernel<<<grid, block>>>( \
       reinterpret_cast<element_type*>(x.data_ptr()), \
       reinterpret_cast<element_type*>(y.data_ptr()), N); \
+  } else { \
+    const int S = x.size(0); \
+    const int K = x.size(1); \
+    const int N = S * K; \
+    if ((K / (n_elements)) <= 1024) { \
+      dim3 block(K / (n_elements)); \
+      dim3 grid(S); \
+      relu_##packed_type##_kernel<<<grid, block>>>( \
+        reinterpret_cast<element_type*>(x.data_ptr()), \
+        reinterpret_cast<element_type*>(y.data_ptr()), N); \
+    } else { \
+      int N = 1; \
+      for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \
+      dim3 block(256 / (n_elements)); \
+      dim3 grid((N + 256 - 1) / 256); \
+      relu_##packed_type##_kernel<<<grid, block>>>( \
+        reinterpret_cast<element_type*>(x.data_ptr()), \
+        reinterpret_cast<element_type*>(y.data_ptr()), N); \
+    } \
+  } \
 }
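To illustrate the new dispatch with hypothetical shapes: a 2D half tensor of shape (4096, 4096) bound with n_elements = 8 gives K / n_elements = 512 <= 1024, so the macro launches grid(S) = grid(4096) with block(512), one block per row; a (4096, 16384) tensor gives 2048 > 1024, so it falls back to the flattened path with block(256 / 8) = block(32) and grid((N + 255) / 256), which still covers 32 * 8 = 256 elements per block.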
 
-TORCH_BINDING_RELU(f32,      torch::kFloat32,    float,    1)
-TORCH_BINDING_RELU(f32x4,    torch::kFloat32,    float,    4)
-TORCH_BINDING_RELU(f16,      torch::kHalf,       half,     1)
-TORCH_BINDING_RELU(f16x2,    torch::kHalf,       half,     2)
-TORCH_BINDING_RELU(f16x8,    torch::kHalf,       half,     8)
-TORCH_BINDING_RELU_V2(f32,   torch::kFloat32,    float,    1)
-TORCH_BINDING_RELU_V2(f32x4, torch::kFloat32,    float,    4)
-TORCH_BINDING_RELU_V2(f16,   torch::kHalf,       half,     1)
-TORCH_BINDING_RELU_V2(f16x2, torch::kHalf,       half,     2)
-TORCH_BINDING_RELU_V2(f16x8, torch::kHalf,       half,     8)
+
+TORCH_BINDING_RELU(f32,        torch::kFloat32,    float,    1)
+TORCH_BINDING_RELU(f32x4,      torch::kFloat32,    float,    4)
+TORCH_BINDING_RELU(f16,        torch::kHalf,       half,     1)
+TORCH_BINDING_RELU(f16x2,      torch::kHalf,       half,     2)
+TORCH_BINDING_RELU(f16x8,      torch::kHalf,       half,     8)
+TORCH_BINDING_RELU(f16x8_pack, torch::kHalf,       half,     8)
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(relu_f32)
   TORCH_BINDING_COMMON_EXTENSION(relu_f32x4)
   TORCH_BINDING_COMMON_EXTENSION(relu_f16)
   TORCH_BINDING_COMMON_EXTENSION(relu_f16x2)
   TORCH_BINDING_COMMON_EXTENSION(relu_f16x8)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f32_v2)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f32x4_v2)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f16_v2)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f16x2_v2)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_v2)
+  TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_pack)
 }
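For completeness, a hedged usage sketch (illustrative, not part of the patch): the refactored macro expands to void relu_f16x8_pack(torch::Tensor x, torch::Tensor y), so the new binding could be exercised from C++ roughly as below, relying on the torch/extension.h include already at the top of this file. The function name relu_f16x8_pack_example and the shapes are assumptions for illustration.

// hypothetical caller, for illustration only
void relu_f16x8_pack_example() {
  auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA, 0);
  auto x = torch::randn({4096, 4096}, opts); // 2D input: S = 4096 rows, K = 4096 cols
  auto y = torch::empty_like(x);             // caller pre-allocates the output tensor
  relu_f16x8_pack(x, y);                     // K / 8 = 512 <= 1024, so grid(S), block(512)
}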