@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
   }
 }
 
+template <typename T>
+__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
+                                               float alpha, float limit) {
+  // clamp gate: min=None, max=limit
+  const float gate_f = (float)gate;
+  const float clamped_gate = gate_f > limit ? limit : gate_f;
+
+  // clamp up: min=-limit, max=limit
+  const float up_f = (float)up;
+  const float clamped_up =
+      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
+
+  // glu = gate * sigmoid(gate * alpha)
+  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
+  const float glu = clamped_gate * sigmoid_val;
+
+  // (up + 1) * glu
+  return (T)((clamped_up + 1.0f) * glu);
+}
+
+template <typename scalar_t,
+          scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
+                             const float)>
+__global__ void swigluoai_and_mul_kernel(
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d, const float alpha, const float limit) {
+  const int64_t token_idx = blockIdx.x;
+  // TODO: Vectorize loads and stores.
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    // gate = x[..., ::2] (even indices)
+    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
+    // up = x[..., 1::2] (odd indices)
+    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
+
+    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
+  }
+}
+
 }  // namespace vllm
 
 #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
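For reference, the device function added above is a clamped SwiGLU: clamp the gate from above, clamp up symmetrically, then compute (up + 1) * gate * sigmoid(alpha * gate). A host-side C++ translation of the same arithmetic, handy for checking kernel output in a test, might look like the standalone sketch below (not part of the patch; the alpha/limit values in main are illustrative assumptions):

// Host-side reference of the same math; standalone sketch, not part of the patch.
#include <algorithm>
#include <cmath>
#include <cstdio>

float swigluoai_ref(float gate, float up, float alpha, float limit) {
  const float g = std::min(gate, limit);                  // clamp from above only
  const float u = std::max(-limit, std::min(up, limit));  // symmetric clamp
  const float glu = g / (1.0f + std::exp(-alpha * g));    // g * sigmoid(alpha * g)
  return (u + 1.0f) * glu;
}

int main() {
  // Example values; alpha/limit here are assumptions, not taken from the diff.
  std::printf("%f\n", swigluoai_ref(0.5f, -0.25f, 1.702f, 7.0f));
}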
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
           PARAM);                                                    \
       });
 
+#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT)                \
+  int d = input.size(-1) / 2;                                        \
+  int64_t num_tokens = input.numel() / input.size(-1);               \
+  dim3 grid(num_tokens);                                             \
+  dim3 block(std::min(d, 1024));                                     \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));  \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();      \
+  VLLM_DISPATCH_FLOATING_TYPES(                                      \
+      input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] {  \
+        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>>   \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),   \
+                                         input.data_ptr<scalar_t>(), d, ALPHA, \
+                                         LIMIT);                     \
+      });
+
 void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                      torch::Tensor& input,  // [..., 2 * d]
                      double threshold) {
   LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
 }
+void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input,  // [..., 2 * d]
+                       double alpha, double limit) {
+  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
+}
 namespace vllm {
 
 // Element-wise activation kernel template.
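Calling the new wrapper from C++ follows the same pattern as the existing gated activations: the caller allocates out with the last dimension halved, since input interleaves gate and up along its last dimension ([..., 2 * d] in, [..., d] out). A hypothetical standalone caller, assuming libtorch and the compiled kernels are linked in (the forward declaration and the alpha/limit values below are assumptions, not from the diff):

// Hypothetical caller; assumes this translation unit links against libtorch
// and the compiled activation kernels. Not part of the patch.
#include <torch/torch.h>

// Forward declaration matching the wrapper added in this diff.
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
                       double alpha, double limit);

int main() {
  const int64_t num_tokens = 4, d = 8;
  // Gate/up interleaved along the last dim: shape [num_tokens, 2 * d].
  torch::Tensor input = torch::randn(
      {num_tokens, 2 * d},
      torch::dtype(torch::kFloat16).device(torch::kCUDA));
  torch::Tensor out = torch::empty({num_tokens, d}, input.options());
  // alpha/limit values are illustrative assumptions.
  swigluoai_and_mul(out, input, /*alpha=*/1.702, /*limit=*/7.0);
}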