
Commit b3311b4

Add rms_norm related kernels. (sgl-project#4)
* Add rms_norm and fused_add_rms_norm kernels (reference semantics sketched below).
* Add gemma_rms_norm and gemma_fused_add_rms_norm.
* Move the device check to utils.py.
* Add at::SGLXPUNorm for the rms_norm-related ops.
* Use a template parameter pack to replace get_update_vec_size (see the dispatch sketch below).
* Change the max work-group size on BMG from 1024 to 512; replace SIMD with NUM_REDUCE_STAGES.
* Replace at::Tensor with torch::Tensor.
* Replace at::SGLXPUNorm with at::native::xpu; remove redundant comments.
* Refactor preferred_vector_width; remove 'default'.
* Refactor '_check_layer_norm_inputs'.
* Refactor the template.
* Replace dpcpp with sycl.
* Refactor SYCL_DISPATCH_FLOATING_TYPES (see the macro sketch below).
* Weight has only one dtype.
* Code style.
* Replace std::exp with sycl::exp (and other ops) when dtype is bf16 (see the bf16 sketch below).
1 parent f5fe3af · commit b3311b4
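For context on the first bullet: rms_norm normalizes each row by its root-mean-square and scales by a learned weight, and fused_add_rms_norm folds a residual add into the same pass. A minimal scalar C++ sketch of the intended semantics follows; the function names, the in-place convention (sum kept in residual, normalized output written back into x), and the eps handling are assumptions for illustration, not the kernel's actual code.

#include <cmath>
#include <cstddef>

// Reference semantics: y[i] = x[i] / rms(x) * weight[i], where
// rms(x) = sqrt(mean(x^2) + eps). The real kernel vectorizes this
// per row inside a SYCL work-group; this loop is only the math.
void rms_norm_row(const float* x, const float* weight, float* y,
                  std::size_t hidden, float eps) {
  float sum_sq = 0.0f;
  for (std::size_t i = 0; i < hidden; ++i) sum_sq += x[i] * x[i];
  const float inv_rms = 1.0f / std::sqrt(sum_sq / hidden + eps);
  for (std::size_t i = 0; i < hidden; ++i) y[i] = x[i] * inv_rms * weight[i];
}

// fused_add_rms_norm: add the residual first, keep the sum in
// residual, and write the normalized result back into x.
void fused_add_rms_norm_row(float* x, float* residual, const float* weight,
                            std::size_t hidden, float eps) {
  for (std::size_t i = 0; i < hidden; ++i) residual[i] += x[i];
  rms_norm_row(residual, weight, x, hidden, eps);
}

If the gemma_* variants follow the usual Gemma convention, the only difference is that they scale by (1.0f + weight[i]) rather than weight[i].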
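For the template-parameter-pack bullet: a common way to drop a runtime get_update_vec_size()-style helper is to instantiate the kernel once per candidate vector width and let a pack expansion pick the first width that divides the hidden size. All names below are illustrative; this is a sketch of the pattern, not the commit's code.

#include <cstdint>
#include <iostream>

// Stand-in for the real kernel launch, templated on a
// compile-time vector width.
template <int VecSize>
void launch_rms_norm(std::int64_t hidden) {
  std::cout << "vec_size=" << VecSize << " hidden=" << hidden << "\n";
}

// Expand a pack of candidate widths, largest first; the first one
// that divides the hidden size evenly wins. The comma fold
// short-circuits via the `done` flag.
template <int... VecSizes>
void dispatch_vec_size(std::int64_t hidden) {
  bool done = false;
  ((!done && hidden % VecSizes == 0
        ? (launch_rms_norm<VecSizes>(hidden), done = true)
        : false),
   ...);
}

int main() {
  dispatch_vec_size<8, 4, 2, 1>(4096);  // prints vec_size=8
  dispatch_vec_size<8, 4, 2, 1>(4100);  // prints vec_size=4
}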
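For the SYCL_DISPATCH_FLOATING_TYPES bullet: dispatch macros of this kind are usually modeled on PyTorch's AT_DISPATCH_FLOATING_TYPES, expanding a callable once per supported dtype so the body can use a scalar_t alias. A plausible shape, not the commit's actual definition:

#include <torch/torch.h>

// Illustrative sketch modeled on AT_DISPATCH_FLOATING_TYPES.
#define SYCL_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)      \
  [&] {                                                    \
    switch (TYPE) {                                        \
      case at::ScalarType::Float: {                        \
        using scalar_t = float;                            \
        return __VA_ARGS__();                              \
      }                                                    \
      case at::ScalarType::Half: {                         \
        using scalar_t = at::Half;                         \
        return __VA_ARGS__();                              \
      }                                                    \
      case at::ScalarType::BFloat16: {                     \
        using scalar_t = at::BFloat16;                     \
        return __VA_ARGS__();                              \
      }                                                    \
      default:                                             \
        TORCH_CHECK(false, NAME, ": unsupported dtype");   \
    }                                                      \
  }()

// Typical use (hypothetical kernel name):
//   SYCL_DISPATCH_FLOATING_TYPES(out.scalar_type(), "rms_norm", [&] {
//     rms_norm_kernel<scalar_t>(/* ... */);
//   });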
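For the last bullet: <cmath> functions such as std::exp are not guaranteed to be available (or fast) in SYCL device code, so device paths call the sycl:: built-ins instead, and bf16 values take a round-trip through float since there is no native bf16 exp overload. A sketch of the idiom, assuming DPC++'s sycl::ext::oneapi::bfloat16 extension (the helper name is hypothetical):

#include <sycl/sycl.hpp>

// In device code, prefer the SYCL built-in math functions.
template <typename T>
inline T device_exp(T x) {
  return sycl::exp(x);
}

// bf16 has no native exp overload: widen to float, compute, and
// narrow back. Other transcendental ops follow the same pattern.
inline sycl::ext::oneapi::bfloat16 device_exp(sycl::ext::oneapi::bfloat16 x) {
  return sycl::ext::oneapi::bfloat16(sycl::exp(static_cast<float>(x)));
}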

38 files changed: +2805 −97 lines

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -19,4 +19,4 @@ jobs:
       pre-commit install

       - name: Linting
-        run: pre-commit run --all-files --show-diff-on-failure
+        run: pre-commit run --all-files --show-diff-on-failure

(The deleted and re-added lines render identically; the change is likely whitespace-only.)

.isort.cfg

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+[settings]
+profile=black
+known_first_party=sglang

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ repos:
     rev: v0.11.7
     hooks:
       - id: ruff
-        args: [--select=F401, --fixable=F401]
+        args: [--select=F401, --fix] #able=F401]
         files: ^(benchmark/|docs/|examples/)
         exclude: \.ipynb$
   - repo: https://github.com/psf/black
@@ -38,7 +38,7 @@ repos:
     hooks:
       - id: codespell
         additional_dependencies: ['tomli']
-        args: ['--toml', 'python/pyproject.toml', '-L', 'cann']
+        args: ['--toml', 'pyproject.toml', '-L', 'cann']
         exclude: |
           (?x)^(
             test/srt/test_reasoning_parser\.py|

benchmark/bench_awq_dequant.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 import itertools
-from typing import List, Tuple
+from typing import Tuple

 import torch
 import triton

benchmark/bench_cutlass_mla.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 import argparse
-import copy
 import itertools

 import torch

benchmark/bench_fp8_blockwise_gemm.py

Lines changed: 2 additions & 2 deletions
@@ -1,16 +1,16 @@
 import argparse
-import copy
 import itertools

 import deep_gemm
 import torch
 import triton
 from deep_gemm import get_col_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
+from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
 )
-from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


 def get_weight_shapes(args):

benchmark/bench_lightning_attention_decode.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-import itertools
 import math

 import torch

benchmark/bench_moe_ep_post_reorder.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 import torch
 import triton
 from sgl_kernel import ep_moe_post_reorder
+
 from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel

 batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]

benchmark/bench_moe_ep_pre_reorder.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 import torch
 import triton
 from sgl_kernel import ep_moe_pre_reorder
+
 from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel

 batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]

benchmark/bench_moe_fused_gate.py

Lines changed: 1 addition & 4 deletions
@@ -1,10 +1,7 @@
-import itertools
-import math
-
 import torch
 import triton
-import triton.language as tl
 from sgl_kernel import moe_fused_gate
+
 from sglang.srt.layers.moe.topk import biased_grouped_topk
