
Commit 2e9b93f

Fix precompile
Signed-off-by: Benji Beck <[email protected]>
1 parent 6942388 commit 2e9b93f

2 files changed (+84, -100 lines)


csrc/layernorm_kernels.cu

Lines changed: 50 additions & 62 deletions
@@ -19,15 +19,15 @@ __device__ __forceinline__ T warp_sum(T v) {
 #ifdef __HIP_PLATFORM_AMD__
   const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
 #else
-  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
+  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
 #endif
   // Always reduce over 32 lanes to match downstream logic.
   constexpr int kWidth = 32;
   v += __shfl_down_sync(m, v, 16, kWidth);
-  v += __shfl_down_sync(m, v, 8, kWidth);
-  v += __shfl_down_sync(m, v, 4, kWidth);
-  v += __shfl_down_sync(m, v, 2, kWidth);
-  v += __shfl_down_sync(m, v, 1, kWidth);
+  v += __shfl_down_sync(m, v, 8, kWidth);
+  v += __shfl_down_sync(m, v, 4, kWidth);
+  v += __shfl_down_sync(m, v, 2, kWidth);
+  v += __shfl_down_sync(m, v, 1, kWidth);
   return v;
 }
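
For reference, the reduction being reformatted here sums a value across all 32 lanes of a warp in five shuffle steps, halving the offset each time so that lane 0 ends up holding the warp total. A minimal standalone sketch of the same pattern (float-only, CUDA mask only, written for illustration rather than copied from the file):

  // Sketch: 32-lane shuffle reduction equivalent to warp_sum above (CUDA path).
  __device__ __forceinline__ float warp_sum_f32(float v) {
    const unsigned m = 0xffffffffu;  // full-warp mask
    constexpr int kWidth = 32;
  #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
      // fold the upper half of the remaining lanes onto the lower half
      v += __shfl_down_sync(m, v, offset, kWidth);
    }
    return v;  // lane 0 holds the full warp sum
  }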

@@ -63,18 +63,17 @@ __device__ __forceinline__ void copy_row_to_shared_aligned(
   for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];
 
   // vector main
-  const int remain = n_elems - prefix;
+  const int remain = n_elems - prefix;
   const int main_elems = (remain / perVec) * perVec;
   if (main_elems > 0) {
     const uint4* __restrict__ vsrc =
         reinterpret_cast<const uint4*>(src + prefix);
 
 #if defined(__HIP_PLATFORM_AMD__)
     // ROCm: vector load from global, scalar 32-bit stores to shared
-    uint32_t* __restrict__ s32 =
-        reinterpret_cast<uint32_t*>(dst + prefix);
-    const int nvec = main_elems / perVec;  // 16B packets
-    constexpr int WORDS_PER_PKT = 16 / sizeof(uint32_t);  // = 4
+    uint32_t* __restrict__ s32 = reinterpret_cast<uint32_t*>(dst + prefix);
+    const int nvec = main_elems / perVec;  // 16B packets
+    constexpr int WORDS_PER_PKT = 16 / sizeof(uint32_t);  // = 4
     for (int v = tid; v < nvec; v += blockDim.x) {
       uint4 p = vsrc[v];
       const int base = v * WORDS_PER_PKT;
@@ -85,8 +84,7 @@ __device__ __forceinline__ void copy_row_to_shared_aligned(
     }
 #else
     // CUDA: vector load + vector store (fastest)
-    uint4* __restrict__ vdst =
-        reinterpret_cast<uint4*>(dst + prefix);
+    uint4* __restrict__ vdst = reinterpret_cast<uint4*>(dst + prefix);
     const int nvec = main_elems / perVec;
     for (int v = tid; v < nvec; v += blockDim.x) {
       uint4 p = vsrc[v];
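
The helper above stages a row into shared memory in 16-byte packets (uint4), with scalar handling for any unaligned prefix and tail; on ROCm each 16-byte packet is split into four 32-bit stores on the shared-memory side. A stripped-down sketch of the grid-stride packet loop, assuming already-aligned pointers (names here are illustrative, not taken from the file):

  // Sketch: each thread copies one 16 B packet per iteration.
  __device__ void copy_16B_packets(const uint4* __restrict__ vsrc,
                                   uint4* __restrict__ vdst, int nvec, int tid,
                                   int nthreads) {
    for (int v = tid; v < nvec; v += nthreads) {
      vdst[v] = vsrc[v];  // one vector load + one vector store
    }
  }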
@@ -101,12 +99,14 @@ __device__ __forceinline__ void copy_row_to_shared_aligned(
   __syncthreads();
 }
 
-// ---------------- vec/scalar ops (generic, used for all dtypes) ----------------
+// ---------------- vec/scalar ops (generic, used for all dtypes)
+// ----------------
 template <int V, typename T>
 struct VecMulNormWeight {
-  const vec_n_t<T, V>* __restrict__ wv;  // vector view of weight (aligned with in/out)
+  const vec_n_t<T, V>* __restrict__ wv;  // vector view of weight (aligned with
+                                         // in/out)
   float inv_rms;
-  int stride_vec;
+  int stride_vec;
   mutable int64_t vec_idx;
 
   __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
@@ -123,8 +123,8 @@ struct VecMulNormWeight {
 
 template <typename T>
 struct ScalarMulNormWeight {
-  const T* __restrict__ w_base;  // already offset by +prefix
-  T* __restrict__ out_base;      // out_row + prefix
+  const T* __restrict__ w_base;  // already offset by +prefix
+  T* __restrict__ out_base;      // out_row + prefix
   float inv_rms;
   __device__ __forceinline__ void operator()(T& dst, const T src) const {
     const int i = static_cast<int>(&dst - out_base);
@@ -139,12 +139,11 @@ __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
     const int64_t input_stride,
-    const scalar_t* __restrict__ weight,  // [hidden_size]
+    const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int /*num_tokens*/, const int hidden_size,
     int smem_elems) {
-
-  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
-  scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;
+  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
+  scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;
 
   // Optional cached-row (half) when host provisioned shmem
   extern __shared__ unsigned char smem_raw[];
187186

188187
acc_t total = acc_t(0);
189188
if (threadIdx.x < 32) {
190-
acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x] : acc_t(0);
189+
acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x]
190+
: acc_t(0);
191191
total = warp_sum<acc_t>(v);
192192
if (threadIdx.x == 0) warp_sums_sh[0] = total;
193193
}
194194
__syncthreads();
195195

196196
// compute inv_rms in float to match baseline epsilon semantics
197-
const float inv_rms =
198-
rsqrtf(static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) + epsilon);
197+
const float inv_rms = rsqrtf(
198+
static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) +
199+
epsilon);
199200

200201
// -------- Fast path: HS == blockDim.x (e.g., 1024) --------
201202
if (hidden_size == blockDim.x) {
@@ -210,10 +211,9 @@ __global__ void rms_norm_kernel(
210211
constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4; // 16B packets
211212
constexpr int WIDTH = V * sizeof(scalar_t);
212213

213-
const bool can_vec =
214-
(hidden_size % V == 0) &&
215-
same_phase(in_row, out_row, WIDTH) &&
216-
same_phase(in_row, weight, WIDTH);
214+
const bool can_vec = (hidden_size % V == 0) &&
215+
same_phase(in_row, out_row, WIDTH) &&
216+
same_phase(in_row, weight, WIDTH);
217217

218218
if (can_vec) {
219219
const uintptr_t addr = reinterpret_cast<uintptr_t>(in_row);
@@ -227,29 +227,24 @@ __global__ void rms_norm_kernel(
227227
}
228228

229229
// vector main
230-
const int remain = hidden_size - prefix;
230+
const int remain = hidden_size - prefix;
231231
const int main_len = (remain / V) * V;
232232
if (main_len > 0) {
233233
using VecT = vec_n_t<scalar_t, V>;
234234
const VecT* __restrict__ wv =
235235
reinterpret_cast<const VecT*>(weight + prefix);
236236

237-
VecMulNormWeight<V, scalar_t> vec_op{
238-
/*wv=*/ wv,
239-
/*inv_rms=*/ inv_rms,
240-
/*stride_vec=*/ (int)blockDim.x,
241-
/*vec_idx=*/ (int64_t)threadIdx.x
242-
};
243-
ScalarMulNormWeight<scalar_t> sca_op{
244-
/*w_base=*/ weight + prefix,
245-
/*out_base=*/ out_row + prefix,
246-
/*inv_rms=*/ inv_rms
247-
};
237+
VecMulNormWeight<V, scalar_t> vec_op{/*wv=*/wv,
238+
/*inv_rms=*/inv_rms,
239+
/*stride_vec=*/(int)blockDim.x,
240+
/*vec_idx=*/(int64_t)threadIdx.x};
241+
ScalarMulNormWeight<scalar_t> sca_op{/*w_base=*/weight + prefix,
242+
/*out_base=*/out_row + prefix,
243+
/*inv_rms=*/inv_rms};
248244

249245
const scalar_t* vin = use_cached ? (s_in + prefix) : (in_row + prefix);
250-
vectorize_with_alignment<V>(
251-
vin, out_row + prefix, main_len,
252-
threadIdx.x, blockDim.x, vec_op, sca_op);
246+
vectorize_with_alignment<V>(vin, out_row + prefix, main_len, threadIdx.x,
247+
blockDim.x, vec_op, sca_op);
253248
}
254249

255250
// scalar tail
@@ -269,7 +264,6 @@ __global__ void rms_norm_kernel(
   }
 }
 
-
 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
@@ -369,9 +363,9 @@ fused_add_rms_norm_kernel(
 }  // namespace vllm
 
 static inline int ln_block_threads_unified(int H) {
-  int threads = (H >= 1024) ? 256
-                : (H >= 512) ? 512
-                : std::min(1024, ((H + 31) / 32) * 32);
+  int threads = (H >= 1024)  ? 256
+                : (H >= 512) ? 512
+                             : std::min(1024, ((H + 31) / 32) * 32);
   return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
 }

@@ -383,8 +377,8 @@ void rms_norm(torch::Tensor& out,  // [..., hidden_size]
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
-  const int hidden_size = input.size(-1);
-  const int num_tokens = input.numel() / hidden_size;
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
   const int64_t in_stride = input.stride(-2);
 
   dim3 grid(num_tokens);
@@ -393,29 +387,23 @@ void rms_norm(torch::Tensor& out,  // [..., hidden_size]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  // Optional cached-row for FP16 (recommended). Kernel still works if this is 0.
+  // Optional cached-row for FP16 (recommended). Kernel still works if this is
+  // 0.
   size_t shmem_bytes = 0;
   int smem_elems = 0;
   if (input.scalar_type() == at::kHalf && hidden_size <= 4096) {
     shmem_bytes = static_cast<size_t>(hidden_size) * sizeof(at::Half);
-    smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
+    smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
   }
 
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t>
-        <<<grid, block, shmem_bytes, stream>>>(
-            out.data_ptr<scalar_t>(),
-            input.data_ptr<scalar_t>(),
-            in_stride,
-            weight.data_ptr<scalar_t>(),
-            static_cast<float>(epsilon),
-            num_tokens,
-            hidden_size,
-            smem_elems);
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, shmem_bytes, stream>>>(
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), in_stride,
+        weight.data_ptr<scalar_t>(), static_cast<float>(epsilon), num_tokens,
+        hidden_size, smem_elems);
   });
 }
 
-
 #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
   VLLM_DISPATCH_FLOATING_TYPES( \
       input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \

csrc/layernorm_quant_kernels.cu

Lines changed: 34 additions & 38 deletions
@@ -20,20 +20,19 @@
 
 namespace vllm {
 
-
 template <typename T>
 __device__ __forceinline__ T warp_sum(T v) {
 #ifdef __HIP_PLATFORM_AMD__
   const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
 #else
-  const unsigned m = 0xffffffffu;
+  const unsigned m = 0xffffffffu;
 #endif
   constexpr int kWidth = 32;  // keep reduction over 32 lanes everywhere
   v += __shfl_down_sync(m, v, 16, kWidth);
-  v += __shfl_down_sync(m, v, 8, kWidth);
-  v += __shfl_down_sync(m, v, 4, kWidth);
-  v += __shfl_down_sync(m, v, 2, kWidth);
-  v += __shfl_down_sync(m, v, 1, kWidth);
+  v += __shfl_down_sync(m, v, 8, kWidth);
+  v += __shfl_down_sync(m, v, 4, kWidth);
+  v += __shfl_down_sync(m, v, 2, kWidth);
+  v += __shfl_down_sync(m, v, 1, kWidth);
   return v;
 }

@@ -45,10 +44,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
     const int64_t input_stride,           // <-- int64_t
     const scalar_t* __restrict__ weight,  // [H]
     const float* __restrict__ scale,      // [1]
-    const float epsilon,
-    const int /*num_tokens*/,
-    const int hidden_size) {
-
+    const float epsilon, const int /*num_tokens*/, const int hidden_size) {
   const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
 
   using acc_t = float;
@@ -65,27 +61,29 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   __syncthreads();
 
   if (threadIdx.x < 32) {
-    acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x] : acc_t(0);
+    acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x]
+                                                     : acc_t(0);
     acc_t total = warp_sum<acc_t>(v);
     if (threadIdx.x == 0) warp_sums_sh[0] = total;
   }
   __syncthreads();
 
-  const float inv_rms =
-      rsqrtf(static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) + epsilon);
+  const float inv_rms = rsqrtf(
+      static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) +
+      epsilon);
 
   const float scale_inv = 1.0f / (*scale);
 
   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-    const float x_f = static_cast<float>(in_row[i]);
-    const scalar_t xn = static_cast<scalar_t>(x_f * inv_rms);  // fp32 normalize → cast to T
-    const scalar_t z = xn * weight[i];  // multiply in T
+    const float x_f = static_cast<float>(in_row[i]);
+    const scalar_t xn =
+        static_cast<scalar_t>(x_f * inv_rms);  // fp32 normalize → cast to T
+    const scalar_t z = xn * weight[i];         // multiply in T
     out[blockIdx.x * hidden_size + i] =
         scaled_fp8_conversion<true, fp8_type>(static_cast<float>(z), scale_inv);
   }
 }
 
-
 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
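
In the loop above each element is normalized in fp32, multiplied by the weight in the activation dtype, and then statically quantized with the single tensor-wide scale s (the kernel precomputes scale_inv = 1/s, which scaled_fp8_conversion applies before the fp8 cast), i.e. roughly:

  $\hat y_i = \mathrm{quant_{fp8}}\big((x_i \cdot \mathrm{inv\_rms})\, w_i \cdot s^{-1}\big)$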
@@ -198,13 +196,12 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 
 }  // namespace vllm
 
-
-
 // --- shared: match unfused launch exactly ---
 static inline int ln_block_threads_unified(int hidden_size) {
   int threads = (hidden_size >= 1024) ? 256
-                : (hidden_size >= 512) ? 512
-                : std::min(1024, ((hidden_size + 31) / 32) * 32);
+                : (hidden_size >= 512)
+                    ? 512
+                    : std::min(1024, ((hidden_size + 31) / 32) * 32);
   // warp-align and clamp to [128, 1024]
   threads = std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
   return threads;
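
This heuristic matches the unfused rms_norm launch in csrc/layernorm_kernels.cu: rows of 1024+ elements get 256-thread blocks, rows of 512-1023 get 512 threads, and smaller rows are rounded up to a whole warp and clamped to [128, 1024]. A host-only copy with a few hand-checked values (illustrative, not part of the commit):

  #include <algorithm>
  #include <cassert>

  // Host-only copy of the block-size heuristic, for a quick sanity check.
  static inline int ln_block_threads_unified_ref(int H) {
    int threads = (H >= 1024)  ? 256
                  : (H >= 512) ? 512
                               : std::min(1024, ((H + 31) / 32) * 32);
    return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
  }

  int main() {
    assert(ln_block_threads_unified_ref(4096) == 256);  // big rows -> 256 threads
    assert(ln_block_threads_unified_ref(768) == 512);
    assert(ln_block_threads_unified_ref(60) == 128);    // 64 after warp rounding, clamped up
    return 0;
  }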
@@ -219,30 +216,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [T, H]
   TORCH_CHECK(weight.is_contiguous());
   TORCH_CHECK(input.stride(-1) == 1, "last dim must be contiguous");
 
-  const int hidden_size = input.size(-1);
-  const int64_t input_stride = input.stride(-2);  // row stride (== last_dim when 2D)
-  const int num_tokens = input.numel() / hidden_size;
+  const int hidden_size = input.size(-1);
+  const int64_t input_stride =
+      input.stride(-2);  // row stride (== last_dim when 2D)
+  const int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
   dim3 block(ln_block_threads_unified(hidden_size));  // <-- match unfused
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
-    VLLM_DISPATCH_FP8_TYPES(out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
-      vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
-          <<<grid, block, 0, stream>>>(
-              out.data_ptr<fp8_t>(),
-              input.data_ptr<scalar_t>(),
-              input_stride,
-              weight.data_ptr<scalar_t>(),
-              scale.data_ptr<float>(),
-              static_cast<float>(epsilon),
-              num_tokens,
-              hidden_size);
-    });
-  });
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
+              vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      input_stride, weight.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), static_cast<float>(epsilon),
+                      num_tokens, hidden_size);
+            });
+      });
   // TORCH_CUDA_KERNEL_LAUNCH_CHECK();
 }
