
Commit b76753f

[Bugfix][Kernel] Support partial rotary embedding for MRoPE triton kernel (#22593)
Signed-off-by: Isotr0py <[email protected]>
1 parent b81fe83 commit b76753f
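In a nutshell: the Triton MRoPE kernel derived its cos/sin strides and masks from the full head size, which is only correct when the whole head is rotated. Models with a partial rotary embedding, such as GLM-4.1V added to the tests below, rotate only the first rotary_dim = head_dim * partial_rotary_factor channels of each head and pass the rest through unchanged. A minimal NeoX-style sketch of that behavior for orientation; the function name and shapes are illustrative, not vLLM's API:

import torch

def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       rotary_dim: int) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_size]
    # cos/sin: [num_tokens, rotary_dim // 2], already gathered per position
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)                  # NeoX-style half split
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over heads
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # Channels past rotary_dim pass through unrotated; the previous kernel
    # sized its cos/sin strides by head_size, which mis-indexed whenever
    # rotary_dim < head_size.
    return torch.cat((rotated, x_pass), dim=-1)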

File tree

tests/kernels/core/test_mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py

2 files changed: +30 -18 lines changed

tests/kernels/test_mrope.py renamed to tests/kernels/core/test_mrope.py

Lines changed: 14 additions & 6 deletions
@@ -42,12 +42,13 @@ def unroll_model_tp_dict(model_tp_dict):
 model_tp_dict = {
     "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
     "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
-    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2]
+    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
+    "zai-org/GLM-4.1V-9B-Thinking": [1, 2],
 }

 # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
 dtype_atol_rtol_list = [
-    [torch.bfloat16, 1e-5, 1.6e-2],
+    [torch.bfloat16, 1e-2, 1.6e-2],
 ]

 num_tokens_list = [11, 8192]
@@ -73,10 +74,12 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):

     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
+    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+    rotary_dim = int(head_dim * partial_rotary_factor)

     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=head_dim,
+        rotary_dim=rotary_dim,
         max_position=max_position,
         base=rope_theta,
         is_neox_style=is_neox_style,
@@ -110,7 +113,10 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
                     reason="Skipping CUDA/ROCm only tests.")
 @pytest.mark.parametrize(
     "model_name, tp_size",
-    unroll_model_tp_dict({"Qwen/Qwen2-VL-7B-Instruct": [1, 2]}))
+    unroll_model_tp_dict({
+        "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
+        "zai-org/GLM-4.1V-9B-Thinking": [1, 2]
+    }))
 @pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
 @pytest.mark.parametrize("num_tokens", [4])
 def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
@@ -126,10 +132,12 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
     is_neox_style = True
     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
+    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+    rotary_dim = int(head_dim * partial_rotary_factor)

     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=head_dim,
+        rotary_dim=rotary_dim,
         max_position=max_position,
         base=rope_theta,
         is_neox_style=is_neox_style,
@@ -145,7 +153,7 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
     # Create a wrapper that makes the in-place function appear functional
     def functional_forward_cuda(pos, q, k):
         """Wrapper that converts in-place operation to functional style
-
+
         CUDA Graph does not support in-place operations.
         This wrapper creates working copies of the
         input tensors and modifies them.
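The key change in this file is that rotary_dim is now derived from the model config rather than assumed equal to head_dim; configs without partial rotary simply omit the attribute, so the factor defaults to 1.0 and full-rotary models are unaffected. A hedged end-to-end sketch of the pattern (the AutoConfig/get_text_config usage and the rope_scaling dict shape are assumptions for illustration; the real test reads these from the model's own config):

from transformers import AutoConfig

from vllm.model_executor.layers.rotary_embedding import get_rope

config = AutoConfig.from_pretrained(
    "zai-org/GLM-4.1V-9B-Thinking").get_text_config()
head_dim = config.hidden_size // config.num_attention_heads
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)  # < head_dim when partial

mrope = get_rope(
    head_size=head_dim,
    rotary_dim=rotary_dim,
    max_position=config.max_position_embeddings,
    base=config.rope_theta,
    is_neox_style=True,
    # MRoPE is selected via the mrope_section entry in rope_scaling.
    rope_scaling={"rope_type": "default",
                  "mrope_section": config.rope_scaling["mrope_section"]},
)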

vllm/model_executor/layers/rotary_embedding/mrope.py

Lines changed: 16 additions & 12 deletions
@@ -25,6 +25,7 @@ def _triton_qwen2vl_mrope_forward(
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
+    rd: tl.constexpr,
     pad_n_qh: tl.constexpr,
     pad_n_kh: tl.constexpr,
     pad_hd: tl.constexpr,
@@ -51,19 +52,19 @@ def _triton_qwen2vl_mrope_forward(
     h_end = t_end + mrope_section_h

     # Updated stride calculation for half head_dim
-    half_hd = hd // 2
-    t_cos = cos + pid * half_hd
-    h_cos = t_cos + num_tokens * half_hd
-    w_cos = h_cos + num_tokens * half_hd
-    t_sin = sin + pid * half_hd
-    h_sin = t_sin + num_tokens * half_hd
-    w_sin = h_sin + num_tokens * half_hd
+    half_rd = rd // 2
+    t_cos = cos + pid * half_rd
+    h_cos = t_cos + num_tokens * half_rd
+    w_cos = h_cos + num_tokens * half_rd
+    t_sin = sin + pid * half_rd
+    h_sin = t_sin + num_tokens * half_rd
+    w_sin = h_sin + num_tokens * half_rd

     # Updated offsets for half head_dim
     cos_offsets = tl.arange(0, pad_hd // 2)
     t_mask = cos_offsets < t_end
     h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
-    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_hd)
+    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

     t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
     h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
@@ -85,9 +86,9 @@ def _triton_qwen2vl_mrope_forward(
     first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
         0, pad_hd // 2)[None, :]
     first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
-        0, pad_hd // 2)[None, :] < hd // 2)
+        0, pad_hd // 2)[None, :] < rd // 2)
     first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
-        0, pad_hd // 2)[None, :] < hd // 2)
+        0, pad_hd // 2)[None, :] < rd // 2)

     q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
                        mask=first_q_mask,
@@ -97,8 +98,8 @@ def _triton_qwen2vl_mrope_forward(
                        other=0).to(sin_row.dtype)

     # right half of the head
-    second_half_q_offsets = first_half_q_offsets + (hd // 2)
-    second_half_k_offsets = first_half_k_offsets + (hd // 2)
+    second_half_q_offsets = first_half_q_offsets + (rd // 2)
+    second_half_k_offsets = first_half_k_offsets + (rd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask

@@ -130,6 +131,7 @@ def triton_mrope(
     sin: torch.Tensor,
     mrope_section: list[int],
     head_size: int,
+    rotary_dim: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Qwen2VL mrope kernel.
@@ -166,6 +168,7 @@ def triton_mrope(
         n_q_head,
         n_kv_head,
         head_size,
+        rotary_dim,
         pad_n_q_head,
         pad_n_kv_head,
         pad_hd,
@@ -300,6 +303,7 @@ def forward_cuda(
             sin,
             self.mrope_section,
             self.head_size,
+            self.rotary_dim,
         )

         return q.reshape(query_shape), k.reshape(key_shape)
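Why threading rotary_dim (rd) through the kernel matters: the cos/sin tables are stacked as three contiguous (t, h, w) sections of shape [num_tokens, rotary_dim // 2], so every pointer step must use rd // 2. Transcribing the stride hunk above (@@ -51,19 +52,19 @@) into plain Python under assumed sizes shows how the old hd-based stride mis-indexed:

num_tokens, head_size, rotary_dim = 4, 128, 64  # assumed sizes (factor 0.5)
half_rd = rotary_dim // 2  # actual cos/sin row length: 32
half_hd = head_size // 2   # stride the old kernel used: 64

for pid in range(num_tokens):             # one Triton program per token
    t_cos = pid * half_rd                 # this token's temporal row
    h_cos = t_cos + num_tokens * half_rd  # height section follows all t rows
    w_cos = h_cos + num_tokens * half_rd  # width section follows all h rows
    # Old math: pid * half_hd == 2 * pid * half_rd, i.e. token `pid` read
    # the row belonging to token 2*pid and soon walked past the t section
    # into h/w rows, so every token but the first saw wrong cos/sin values.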
