Conversation

@xiaohongchen1991 (Contributor) commented Feb 12, 2026

Purpose

This PR adds a Helion kernel for the rms_norm_dynamic_per_token_quant operation. It follows the implementation of the existing vLLM CUDA (torch.ops._C) version.

This is a subtask for #32962.
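
For context, here is a minimal plain-PyTorch sketch of what this op computes (RMS-normalize each token row, then quantize with a per-token dynamic scale). It is illustrative only: it omits the residual, scale_ub, and minimum-scale handling of the real kernel, and the function name is hypothetical.

import torch

def rms_norm_dynamic_per_token_quant_ref(
    x: torch.Tensor,           # [num_tokens, hidden_size]
    weight: torch.Tensor,      # [hidden_size]
    epsilon: float,
    quant_dtype: torch.dtype,  # torch.float8_e4m3fn or torch.int8
):
    # RMS-normalize each token row in fp32, then apply the weight.
    xf = x.float()
    rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    normed = (xf * rms).to(x.dtype) * weight
    # Per-token dynamic scale from the row-wise absolute maximum.
    info = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 else torch.finfo(quant_dtype)
    scale = normed.abs().amax(dim=-1, keepdim=True).float() / info.max
    q = normed.float() / scale
    if quant_dtype == torch.int8:
        q = q.round()
    return q.clamp(info.min, info.max).to(quant_dtype), scale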

Kernel level benchmark

Environment
  • Python: 3.12.3
  • PyTorch: 2.9.0+cu129
  • CUDA: 12.9
  • Helion: 0.2.10
  • Triton: 3.5.0
  • GPU: NVIDIA RTX 5090 (single GPU)

Benchmark Setup
  • Latency measure: triton.testing.do_bench_cudagraph(rep=1000)
  • Memory measure: torch.cuda.reset_peak_memory_stats()
  • Baseline: torch.ops._C.rms_norm_dynamic_per_token_quant
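
Put together, these two measurements amount to something like the helper below (an illustrative sketch, not the PR's benchmark script; measure is a hypothetical name, and reading the peak via torch.cuda.max_memory_allocated() after the reset is an assumption about how the memory number is collected):

import torch
import triton.testing

def measure(fn):
    # Latency via CUDA-graph benchmarking, per the setup above.
    latency_ms = triton.testing.do_bench_cudagraph(fn, rep=1000)
    # Peak memory: reset the stats, run once, read the high-water mark.
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    return latency_ms, peak_mb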

Autotuning Setup
Default "full" autotuning effort. i.e. LFBOPatternSearch with initial_population=FROM_RANDOM, copies=5, max_generations=20.

case                             | baseline_ms | kernel_ms | speedup(x) | baseline_peak(MB) | kernel_peak(MB) | mem_improve(x)
---------------------------------+-------------+-----------+------------+-------------------+-----------------+---------------
num_tokens_1_hidden_size_2048    | 0.005       | 0.006     | 0.830      | 0.02              | 0.02            | 1.000
num_tokens_1_hidden_size_4096    | 0.005       | 0.005     | 1.082      | 0.04              | 0.04            | 1.000
num_tokens_1_hidden_size_8192    | 0.006       | 0.009     | 0.663      | 0.08              | 0.08            | 1.000
num_tokens_2_hidden_size_2048    | 0.005       | 0.006     | 0.779      | 0.04              | 0.04            | 1.000
num_tokens_2_hidden_size_4096    | 0.005       | 0.005     | 1.015      | 0.08              | 0.08            | 1.000
num_tokens_2_hidden_size_8192    | 0.006       | 0.007     | 0.932      | 0.15              | 0.15            | 1.000
num_tokens_4_hidden_size_2048    | 0.007       | 0.007     | 1.004      | 0.07              | 0.07            | 1.000
num_tokens_4_hidden_size_4096    | 0.005       | 0.005     | 1.031      | 0.14              | 0.14            | 1.000
num_tokens_4_hidden_size_8192    | 0.006       | 0.007     | 0.930      | 0.28              | 0.28            | 1.000
num_tokens_8_hidden_size_2048    | 0.005       | 0.007     | 0.768      | 0.14              | 0.14            | 1.000
num_tokens_8_hidden_size_4096    | 0.005       | 0.005     | 1.037      | 0.27              | 0.27            | 1.000
num_tokens_8_hidden_size_8192    | 0.006       | 0.007     | 0.946      | 0.54              | 0.54            | 1.000
num_tokens_16_hidden_size_2048   | 0.005       | 0.007     | 0.772      | 0.27              | 0.27            | 1.000
num_tokens_16_hidden_size_4096   | 0.006       | 0.005     | 1.075      | 0.54              | 0.54            | 1.000
num_tokens_16_hidden_size_8192   | 0.007       | 0.007     | 0.994      | 1.07              | 1.07            | 1.000
num_tokens_32_hidden_size_2048   | 0.005       | 0.007     | 0.772      | 0.53              | 0.53            | 1.000
num_tokens_32_hidden_size_4096   | 0.006       | 0.005     | 1.074      | 1.06              | 1.06            | 1.000
num_tokens_32_hidden_size_8192   | 0.007       | 0.007     | 0.993      | 2.12              | 2.12            | 1.000
num_tokens_64_hidden_size_2048   | 0.005       | 0.007     | 0.774      | 1.06              | 1.06            | 1.000
num_tokens_64_hidden_size_4096   | 0.006       | 0.006     | 1.010      | 2.11              | 2.11            | 1.000
num_tokens_64_hidden_size_8192   | 0.007       | 0.007     | 0.949      | 4.21              | 4.21            | 1.000
num_tokens_128_hidden_size_2048  | 0.006       | 0.007     | 0.798      | 2.10              | 2.10            | 1.000
num_tokens_128_hidden_size_4096  | 0.006       | 0.006     | 0.997      | 4.21              | 4.21            | 1.000
num_tokens_128_hidden_size_8192  | 0.008       | 0.008     | 0.996      | 8.41              | 8.41            | 1.000
num_tokens_256_hidden_size_2048  | 0.007       | 0.008     | 0.956      | 4.20              | 4.20            | 1.000
num_tokens_256_hidden_size_4096  | 0.009       | 0.007     | 1.203      | 8.40              | 8.40            | 1.000
num_tokens_256_hidden_size_8192  | 0.012       | 0.011     | 1.117      | 16.80             | 16.80           | 1.000
num_tokens_512_hidden_size_2048  | 0.010       | 0.008     | 1.263      | 8.40              | 8.40            | 1.000
num_tokens_512_hidden_size_4096  | 0.013       | 0.010     | 1.364      | 16.79             | 16.79           | 1.000
num_tokens_512_hidden_size_8192  | 0.020       | 0.017     | 1.147      | 33.58             | 33.58           | 1.000
num_tokens_1024_hidden_size_2048 | 0.016       | 0.010     | 1.571      | 16.79             | 16.79           | 1.000
num_tokens_1024_hidden_size_4096 | 0.022       | 0.017     | 1.294      | 33.57             | 33.57           | 1.000
num_tokens_1024_hidden_size_8192 | 0.034       | 0.033     | 1.035      | 67.13             | 67.13           | 1.000
num_tokens_2048_hidden_size_2048 | 0.029       | 0.016     | 1.798      | 33.58             | 33.58           | 1.000
num_tokens_2048_hidden_size_4096 | 0.039       | 0.031     | 1.278      | 67.13             | 67.13           | 1.000
num_tokens_2048_hidden_size_8192 | 0.092       | 0.080     | 1.147      | 134.25            | 134.25          | 1.000
num_tokens_4096_hidden_size_2048 | 0.053       | 0.033     | 1.596      | 67.15             | 67.15           | 1.000
num_tokens_4096_hidden_size_4096 | 0.102       | 0.076     | 1.349      | 134.26            | 134.26          | 1.000
num_tokens_4096_hidden_size_8192 | 0.259       | 0.266     | 0.975      | 268.49            | 268.49          | 1.000

Test Plan

  • Test correctness
    • Added unit tests to cover correctness
  • Benchmark performance at the kernel level with different kernel implementations
  • Benchmark performance at the model level

Test Result

Unit tests for correctness

All unit test cases covering correctness passed.

pytest tests/kernels/helion/test_rms_norm_dynamic_per_token_quant.py

Kernel level benchmarking

See the results above. The regression on certain input shapes may be due to sub-optimal configs from autotuning; further root-causing is pending.

Model level benchmarking

TODO

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a new Helion kernel for rms_norm_dynamic_per_token_quant, along with corresponding unit tests. The implementation correctly follows the logic of the existing CUDA version. However, I've identified a couple of incorrect comments that should be fixed for clarity. More importantly, there is significant redundant computation within the kernel's main loops, which likely impacts performance and could explain the regressions noted in the benchmarks. I've provided a detailed suggestion for refactoring to improve efficiency.

def rms_norm_dynamic_per_token_quant(
    output: torch.Tensor,  # [num_tokens, hidden_size]
    input: torch.Tensor,  # [num_tokens, hidden_size]
    weight: torch.Tensor,  # [num_tokens]
Severity: high

The type hint comment for weight is incorrect. It should be [hidden_size], not [num_tokens]. The weight tensor is indexed by tile_n, which iterates over hidden_size in the loops below.

Suggested change:
-    weight: torch.Tensor,  # [num_tokens]
+    weight: torch.Tensor,  # [hidden_size]

num_tokens, hidden_size = input.shape
hl.specialize(hidden_size)

# only support fp8 quant for now

Severity: high

This comment is outdated. The kernel also supports torch.int8 as checked on line 59.

Suggested change:
- # only support fp8 quant for now
+ # supports fp8 and int8 quantization
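
(For reference, the limits that qtype_max and qtype_traits_min/qtype_traits_max in the kernel presumably encode can be read from torch.iinfo/torch.finfo; the helper below is an illustrative sketch, not code from this PR:)

import torch

def qtype_limits(quant_dtype: torch.dtype):
    # int8 uses integer limits; fp8 variants use floating-point limits.
    if quant_dtype == torch.int8:
        info = torch.iinfo(torch.int8)   # min = -128, max = 127
    else:
        info = torch.finfo(quant_dtype)  # e.g. torch.float8_e4m3fn: max = 448.0
    return float(info.min), float(info.max)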

Comment on lines +83 to +124
for tile_m in hl.tile(num_tokens, block_size=1):
    rms = hl.zeros([tile_m], dtype=torch.float32)
    for tile_n in hl.tile(hidden_size):
        x_blk = input[tile_m, tile_n].to(torch.float32)
        if residual is not None:
            x_blk = x_blk + residual[tile_m, tile_n]
        rms = rms + x_blk.pow(2).sum(dim=-1)

    rms = torch.rsqrt(rms * (1.0 / hidden_size) + epsilon)
    s_blk = hl.zeros([tile_m], dtype=torch.float32)

    for tile_n in hl.tile(hidden_size):
        x_blk = input[tile_m, tile_n].to(torch.float32)
        if residual is not None:
            x_blk = x_blk + residual[tile_m, tile_n]
        x_blk = (x_blk * rms[:, None]).to(input.dtype) * weight[None, tile_n]
        tmp_blk = torch.amax(torch.abs(x_blk), dim=-1).to(torch.float32)
        s_blk = torch.maximum(s_blk, tmp_blk)

    if scale_ub is not None:
        scale_ub_s = hl.load(scale_ub, [])
        s_blk = s_blk.clamp(max=scale_ub_s)
    s_blk = s_blk * (1.0 / qtype_max)
    s_blk = s_blk.clamp(min=min_scaling_factor)
    scale[tile_m, 0] = s_blk

    for tile_n in hl.tile(hidden_size):
        x_blk = input[tile_m, tile_n].to(torch.float32)
        if residual is not None:
            x_blk = x_blk + residual[tile_m, tile_n]
            residual[tile_m, tile_n] = x_blk.to(residual.dtype)
        x_blk = (x_blk * rms[:, None]).to(input.dtype) * weight[None, tile_n]
        if quant_dtype == torch.int8:
            s_inv_blk = 1.0 / s_blk[:, None]
            y_blk = x_blk * s_inv_blk
            y_blk = y_blk.round()
        else:
            y_blk = x_blk / s_blk[:, None]

        output[tile_m, tile_n] = y_blk.clamp(qtype_traits_min, qtype_traits_max).to(
            output.dtype
        )

Severity: high

There is significant redundant computation across the three loops over hidden_size.

  1. The value of input + residual is computed in each of the three loops.
  2. The normalized and weighted value (input + residual) * rms * weight is computed in both the second and third loops.

This re-computation can be inefficient and might be the cause of the performance regressions observed in the benchmarks for certain input shapes.

While this three-pass approach (compute RMS, compute scales, quantize) is common in hand-written CUDA kernels to manage register pressure and shared memory, the amount of re-computation here seems excessive.

Consider refactoring to reduce this redundancy. One approach could be:

  1. After computing rms, use a single loop over hidden_size to compute the normalized and weighted values: norm_x = (input + residual) * rms * weight.
  2. Store these norm_x values in a temporary buffer (e.g., of size hidden_size).
  3. Compute the quantization scale s_blk from this temporary buffer (e.g., using torch.amax).
  4. Perform the final quantization and store the output by iterating over the temporary buffer.

This introduces a memory-compute trade-off by using an intermediate buffer, but it could lead to better performance by avoiding expensive re-computations. It's worth exploring if this improves performance, especially for the shapes that currently show regressions.
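
To make the suggested data flow concrete, the same computation in plain PyTorch (eliding the tiling; variable names mirror the kernel, but this is a sketch of the proposal, not a drop-in replacement) would materialize norm_x once and derive both the scale and the quantized output from it:

xf = input.float()
if residual is not None:
    xf = xf + residual.float()
    residual.copy_(xf.to(residual.dtype))
rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + epsilon)
norm_x = (xf * rms).to(input.dtype) * weight          # computed once, buffered
s = norm_x.abs().amax(dim=-1, keepdim=True).float()   # scale read from the buffer
if scale_ub is not None:
    s = s.clamp(max=scale_ub.item())
s = (s * (1.0 / qtype_max)).clamp(min=min_scaling_factor)
q = norm_x.float() / s
if quant_dtype == torch.int8:
    q = q.round()
output.copy_(q.clamp(qtype_traits_min, qtype_traits_max).to(output.dtype))
scale.copy_(s)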

@mergify bot commented Feb 12, 2026

Hi @xiaohongchen1991, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@xiaohongchen1991 marked this pull request as draft February 12, 2026 14:40