[WIP][Kernel] Add Helion kernel for rms_norm_dynamic_per_token_quant #34432
Conversation
Signed-off-by: Sean Chen <[email protected]>
Code Review
This pull request introduces a new Helion kernel for rms_norm_dynamic_per_token_quant, along with corresponding unit tests. The implementation correctly follows the logic of the existing CUDA version. However, I've identified a couple of incorrect comments that should be fixed for clarity. More importantly, there is significant redundant computation within the kernel's main loops, which likely impacts performance and could explain the regressions noted in the benchmarks. I've provided a detailed suggestion for refactoring to improve efficiency.
```python
def rms_norm_dynamic_per_token_quant(
    output: torch.Tensor,  # [num_tokens, hidden_size]
    input: torch.Tensor,  # [num_tokens, hidden_size]
    weight: torch.Tensor,  # [num_tokens]
```
```python
num_tokens, hidden_size = input.shape
hl.specialize(hidden_size)

# only support fp8 quant for now
```
```python
for tile_m in hl.tile(num_tokens, block_size=1):
    rms = hl.zeros([tile_m], dtype=torch.float32)
    for tile_n in hl.tile(hidden_size):
        x_blk = input[tile_m, tile_n].to(torch.float32)
        if residual is not None:
            x_blk = x_blk + residual[tile_m, tile_n]
        rms = rms + x_blk.pow(2).sum(dim=-1)

    rms = torch.rsqrt(rms * (1.0 / hidden_size) + epsilon)
    s_blk = hl.zeros([tile_m], dtype=torch.float32)

    for tile_n in hl.tile(hidden_size):
        x_blk = input[tile_m, tile_n].to(torch.float32)
        if residual is not None:
            x_blk = x_blk + residual[tile_m, tile_n]
        x_blk = (x_blk * rms[:, None]).to(input.dtype) * weight[None, tile_n]
        tmp_blk = torch.amax(torch.abs(x_blk), dim=-1).to(torch.float32)
        s_blk = torch.maximum(s_blk, tmp_blk)

    if scale_ub is not None:
        scale_ub_s = hl.load(scale_ub, [])
        s_blk = s_blk.clamp(max=scale_ub_s)
    s_blk = s_blk * (1.0 / qtype_max)
    s_blk = s_blk.clamp(min=min_scaling_factor)
    scale[tile_m, 0] = s_blk

    for tile_n in hl.tile(hidden_size):
        x_blk = input[tile_m, tile_n].to(torch.float32)
        if residual is not None:
            x_blk = x_blk + residual[tile_m, tile_n]
            residual[tile_m, tile_n] = x_blk.to(residual.dtype)
        x_blk = (x_blk * rms[:, None]).to(input.dtype) * weight[None, tile_n]
        if quant_dtype == torch.int8:
            s_inv_blk = 1.0 / s_blk[:, None]
            y_blk = x_blk * s_inv_blk
            y_blk = y_blk.round()
        else:
            y_blk = x_blk / s_blk[:, None]

        output[tile_m, tile_n] = y_blk.clamp(qtype_traits_min, qtype_traits_max).to(
            output.dtype
        )
```
There is significant redundant computation across the three loops over `hidden_size`:

- The value of `input + residual` is computed in each of the three loops.
- The normalized and weighted value `(input + residual) * rms * weight` is computed in both the second and third loops.

This re-computation can be inefficient and might be the cause of the performance regressions observed in the benchmarks for certain input shapes.

While this three-pass approach (compute RMS, compute scales, quantize) is common in hand-written CUDA kernels to manage register pressure and shared memory, the amount of re-computation here seems excessive.
Consider refactoring to reduce this redundancy. One approach could be:

- After computing `rms`, use a single loop over `hidden_size` to compute the normalized and weighted values: `norm_x = (input + residual) * rms * weight`.
- Store these `norm_x` values in a temporary buffer (e.g., of size `hidden_size`).
- Compute the quantization scale `s_blk` from this temporary buffer (e.g., using `torch.amax`).
- Perform the final quantization and store the output by iterating over the temporary buffer.

This introduces a memory-compute trade-off by using an intermediate buffer, but it could lead to better performance by avoiding expensive re-computations. It's worth exploring whether this improves performance, especially for the shapes that currently show regressions; a rough sketch of the suggested data flow is given below.
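For illustration only, here is a minimal PyTorch-level sketch of that data flow: whole-tensor ops stand in for the Helion tile loops, the function name is made up, only the fp8 path is shown, and the `scale_ub` clamp is omitted. It is not the PR's code, just the "compute `norm_x` once, reuse it for scale and quantization" idea.

```python
import torch

def rms_norm_quant_sketch(
    input: torch.Tensor,   # [num_tokens, hidden_size]
    weight: torch.Tensor,  # [hidden_size]
    epsilon: float,
    qtype_max: float,
    min_scaling_factor: float,
    residual: torch.Tensor | None = None,
):
    # Compute (input + residual) once and reuse it for every pass.
    x = input.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)

    # Pass 1: per-token RMS.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)

    # Single normalization into a buffer; this is the `norm_x` the review suggests keeping.
    norm_x = (x * rms).to(input.dtype).to(torch.float32) * weight.to(torch.float32)

    # Pass 2: per-token dynamic scale from the buffered values.
    scale = norm_x.abs().amax(dim=-1, keepdim=True) / qtype_max
    scale = scale.clamp(min=min_scaling_factor)

    # Pass 3: quantize from the same buffer (fp8 path shown; the int8 path would round).
    fp8 = torch.float8_e4m3fn
    y = (norm_x / scale).clamp(torch.finfo(fp8).min, torch.finfo(fp8).max)
    return y.to(fp8), scale
```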
Hi @xiaohongchen1991, the pre-commit checks have failed. Please run:

```bash
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Purpose
This PR adds a Helion kernel for the `rms_norm_dynamic_per_token_quant` operation. It follows the implementation of the existing vLLM CUDA version. This is a subtask of #32962.
Kernel level benchmark
Environment
Python: 3.12.3
PyTorch: 2.9.0+cu129
CUDA: 12.9
Helion: 0.2.10
Triton: 3.5.0
GPU: NVIDIA RTX 5090 (single GPU)
Benchmark Setup
Latency measure: `triton.testing.do_bench_cudagraph(rep=1000)`
Memory measure: `torch.cuda.reset_peak_memory_stats()`
Baseline: `torch.ops._C.rms_norm_dynamic_per_token_quant`
(a measurement sketch is shown after the autotuning notes below)
Autotuning Setup
Default "full" autotuning effort, i.e. `LFBOPatternSearch` with `initial_population=FROM_RANDOM`, `copies=5`, `max_generations=20`.
Test Plan
Unit tests to cover correctness.
Test Result
All unit test cases covering correctness passed.
Kernel level benchmarking
See the results above. The regression on certain input shapes may be due to a sub-optimal config from autotuning; further root-causing is pending.
Model level benchmarking
TODO