Commit c0f0b70

[core] Support capture custom ops into aclgraph (#2113)
### What this PR does / why we need it?
Thanks to PR #426, vllm-ascend supports aclgraph inference to reduce the host overhead. However, the capability of aclgraph relies strongly on the functionality provided by `torch.compile`, the key feature of torch 2.x. Therefore, capturing a custom op into aclgraph is only possible when it can be recognized and captured by `torch.compile`. In this PR, we register meta implementations for the current custom ops to enable fx graph capture. With those in place, inserting the custom ops into aclgraph becomes natural for the Ascend runtime.

### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
Tested in unittest: we integrate the `rotary_embedding` op into a small custom model, then use `torch.compile` and aclgraph to capture and replay it, verifying its functionality.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@1b99028

---------

Signed-off-by: ganyi <[email protected]>
1 parent 1ab1541 commit c0f0b70
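The mechanism this PR relies on, as a minimal self-contained sketch (the `my_ops::scale` namespace, op, and CPU stand-in kernel are hypothetical illustrations, not part of vllm-ascend): without a Meta-key implementation, `torch.compile` cannot infer output shapes for a custom op and has to break the graph; with one, the op is captured into the fx graph.

```python
import torch
from torch.library import Library

# Hypothetical namespace/op for illustration only.
lib = Library("my_ops", "DEF")
lib.define("scale(Tensor x, float factor) -> Tensor")


def scale_cpu(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real computation; stands in for an actual NPU kernel.
    return x * factor


def scale_meta(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape/dtype inference only: no computation, no device memory.
    return torch.empty_like(x)


lib.impl("scale", scale_cpu, "CPU")
lib.impl("scale", scale_meta, "Meta")


@torch.compile(fullgraph=True)  # fullgraph=True errors out if the op cannot be traced
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.my_ops.scale(x, 2.0)


print(f(torch.ones(4)))  # tensor([2., 2., 2., 2.])
```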

File tree

6 files changed: +332 −13 lines

- csrc/torch_binding.cpp
- csrc/torch_binding_meta.cpp
- csrc/utils.h
- tests/e2e/singlecard/ops/test_rotary_embedding.py
- vllm_ascend/meta_registration.py
- vllm_ascend/utils.py

csrc/torch_binding.cpp

Lines changed: 11 additions & 0 deletions
```diff
@@ -27,6 +27,17 @@
 
 namespace vllm_ascend {
 
+AscendType get_dtype_from_torch(at::ScalarType scalarType)
+{
+    if (scalarType == at::ScalarType::Float) {
+        return AscendType::FP32;
+    } else if (scalarType == at::ScalarType::BFloat16) {
+        return AscendType::BF16;
+    } else {
+        return AscendType::FP16;
+    }
+}
+
 std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
                                                     int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
 {
```

csrc/torch_binding_meta.cpp

Lines changed: 86 additions & 0 deletions
```cpp
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/npu/Module.h>
#include "utils.h"
/*
 * How to write a meta implementation for a custom operator (meta kernel):
 *
 * Meta implementations are used for shape and dtype inference, tracing, and export.
 * They do NOT perform any real computation or allocate device memory.
 * Instead, they return empty tensors with the correct shapes, dtypes, and device types.
 *
 * Steps to write a meta implementation:
 * 1. The function signature should match the operator's schema, but only use the arguments
 *    necessary to infer output shapes and dtypes.
 * 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes.
 * 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype.
 * 4. Do NOT perform any real computation or data movement.
 * 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar.
 *
 * Example:
 *   std::tuple<at::Tensor, at::Tensor> my_op_meta(
 *       at::Tensor &input, int64_t some_param) {
 *     // Infer output shape based on input and parameters
 *     auto out_shape = ...;
 *     at::Tensor out = at::empty_symint(out_shape, input.options());
 *     // Return empty tensor(s) with correct shape/dtype
 *     return {out, ...};
 *   }
 *
 * See below for real examples.
 */

namespace vllm_ascend {
namespace meta {

std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
    at::Tensor &positions,
    at::Tensor &query,
    at::Tensor &key,
    int64_t head_size,
    at::Tensor &cos_sin_cache,
    bool is_neox) {
    auto num_tokens = positions.sym_numel();
    auto query_hidden_size = query.sym_numel() / num_tokens;
    auto key_hidden_size = key.sym_numel() / num_tokens;

    auto num_heads = query_hidden_size / head_size;
    auto num_kv_heads = key_hidden_size / head_size;
    at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
    at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());

    return {query_dst, key_dst};
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
    at::Tensor &input,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index) {

    at::Tensor masked_input = at::empty_like(input);
    at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));

    return {masked_input, mask};
}

} // namespace meta
} // namespace vllm_ascend

namespace {
// Register the meta implementations of the custom kernels for symbolic tracing; this also
// allows the custom kernels to be captured into aclgraph.
TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
    // Rotary embedding meta implementation
    ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
    // Masked input and mask meta implementation
    ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
}
}
```
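A quick way to sanity-check a meta implementation like this from Python is to invoke the op under fake tensors, where only the Meta kernel can run. A sketch, assuming a vllm-ascend build whose `_C` extension is importable:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import vllm_ascend.vllm_ascend_C  # noqa: F401  # registers the _C ops and their meta impls

with FakeTensorMode():
    positions = torch.empty(10, dtype=torch.int64)
    query = torch.empty(10, 17 * 64, dtype=torch.half)
    key = torch.empty(10, 17 * 64, dtype=torch.half)
    cos_sin_cache = torch.empty(8192, 64, dtype=torch.half)
    # Dispatches to rotary_embedding_meta: shapes/dtypes only, no NPU kernel.
    q_out, k_out = torch.ops._C.rotary_embedding(positions, query, key, 64,
                                                 cos_sin_cache, True)
    assert q_out.shape == (10, 17, 64) and k_out.shape == (10, 17, 64)
```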

csrc/utils.h

Lines changed: 0 additions & 12 deletions
```diff
@@ -29,15 +29,3 @@
 }
 
 
-namespace vllm_ascend {
-AscendType get_dtype_from_torch(at::ScalarType scalarType)
-{
-    if (scalarType == at::ScalarType::Float) {
-        return AscendType::FP32;
-    } else if (scalarType == at::ScalarType::BFloat16) {
-        return AscendType::BF16;
-    } else {
-        return AscendType::FP16;
-    }
-}
-} // namespace vllm_ascend
```

tests/e2e/singlecard/ops/test_rotary_embedding.py

Lines changed: 145 additions & 1 deletion
```diff
@@ -17,11 +17,12 @@
 # Only Neox style true scenario is supported for now
 IS_NEOX_STYLE = [True]
 DTYPES = [torch.half]
-HEAD_SIZES = [64, 96, 128, 256]
+HEAD_SIZES = [64, 64, 96, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [17]  # Arbitrary values for testing
 BATCH_SIZES = [5]  # Arbitrary values for testing
 SEQ_LENS = [11, 4096]  # Arbitrary values for testing
+NUM_TOKENS = [10, 21]
 SEEDS = [0]
 DEVICES = [f"npu:{0}"]
 # Set tolerance to 1 for quant ops
@@ -198,3 +199,146 @@ def test_rotary_embedding_quant_with_leading_dim(
                             ref_key,
                             atol=DEFAULT_ATOL,
                             rtol=DEFAULT_RTOL)
+
+
+class ModelwithRotaryEmbedding(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ) -> None:
+        super().__init__()
+        self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
+        self.rope = RotaryEmbedding(
+            head_size=head_size,
+            rotary_dim=rotary_dim,
+            max_position_embeddings=max_position_embeddings,
+            base=base,
+            is_neox_style=is_neox_style,
+            dtype=dtype,
+        )
+        self.o_proj = nn.Linear(num_heads * head_size, hidden_size)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # simulate a simple attention layer to test whether it can be
+        # seamlessly captured into aclgraph
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(3, dim=-1)
+        query, key = torch.ops._C.rotary_embedding(
+            positions,
+            q,
+            k,
+            self.rope.head_size,
+            self.rope.cos_sin_cache,
+            self.rope.is_neox_style,
+        )
+        query = query.view(q.shape)
+        key = key.view(k.shape)
+        o = self.o_proj(query)
+        return o
+
+
+# The first graph replay seems to have accuracy issues when pytest is run
+# directly on the ops folder; add a warmup graph replay as a workaround.
+ACL_GRAPH_FIRST_RUN = True
+
+
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_capture_rotary_embedding_in_aclgraph(
+    is_neox_style: bool,
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position_embeddings: int = 8192,
+    base: int = 10000,
+):
+    """Test if the rotary embedding can be captured in aclgraph."""
+    torch.manual_seed(seed)
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    model = ModelwithRotaryEmbedding(
+        hidden_size=num_heads * head_size,
+        num_heads=num_heads,
+        head_size=head_size,
+        rotary_dim=rotary_dim,
+        max_position_embeddings=max_position_embeddings,
+        base=base,
+        is_neox_style=is_neox_style,
+        dtype=dtype,
+    )
+
+    def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
+        # Validate that the rotary_embedding custom kernel is indeed inside
+        # the graph via string match
+        graph = str(gm.graph)
+        assert "_C.rotary_embedding" in graph
+        return gm
+
+    static_positions = torch.randint(0, max_position_embeddings,
+                                     (num_tokens, ))
+    static_hidden_states = torch.randn(num_tokens,
+                                       num_heads * head_size,
+                                       dtype=dtype,
+                                       device="npu")
+    compiled_model = torch.compile(model, backend=custom_op_checking_backend)
+    stream = torch.npu.Stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        # warm up the fx graph before capture
+        for i in range(3):
+            static_output = compiled_model(static_positions,
+                                           static_hidden_states,
+                                           offsets=None)
+    stream.wait_stream(torch.npu.current_stream())
+
+    aclgraph = torch.npu.NPUGraph()
+
+    with torch.npu.graph(aclgraph):
+        # Capture the model in aclgraph.
+        static_output = compiled_model(static_positions, static_hidden_states)
+    # Fill the static input buffers with new random data, then replay the
+    # captured graph.
+    random_filled_positions = torch.randint(0,
+                                            max_position_embeddings,
+                                            (num_tokens, ),
+                                            device="npu")
+    random_filled_hidden_states = torch.randn(num_tokens,
+                                              num_heads * head_size,
+                                              dtype=dtype,
+                                              device="npu")
+    static_positions.copy_(random_filled_positions)
+    static_hidden_states.copy_(random_filled_hidden_states)
+
+    aclgraph.replay()
+    global ACL_GRAPH_FIRST_RUN
+    if ACL_GRAPH_FIRST_RUN:
+        ACL_GRAPH_FIRST_RUN = False
+        return
+    output_reference = model(static_positions, static_hidden_states)
+    torch.testing.assert_close(static_output,
+                               output_reference,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
```
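Note the capture/replay idiom in this test: like CUDA Graphs, aclgraph replays fixed device memory addresses, so the inputs are allocated once as static buffers, the compiled model is captured while reading from them, and fresh data is later fed by `copy_`-ing into those same buffers before `aclgraph.replay()`.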

vllm_ascend/meta_registration.py

Lines changed: 86 additions & 0 deletions
```python
import torch
from torch.library import Library

# This file provides a template and registration utilities for writing "meta"
# implementations of custom operators in Python for the vllm_ascend project.
#
# We offer two ways to implement meta implementations for custom ops:
# 1. Python meta implementation (as shown in this file): Write a Python function that
#    takes the same arguments as your operator and returns empty tensors with the correct
#    shapes and dtypes. This is useful for rapid prototyping and for ops that are only
#    used in Python.
# 2. C++ meta implementation: You can also implement the meta function in C++ for better
#    performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp`
#    for examples of C++ meta implementations and how to register them.
#
# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which
# is essential for supporting `torch.compile` and aclgraph.

# How to add a new meta implementation in Python:
# -----------------------------------------------
# 1. Write a Python function that takes the same arguments as your operator, and returns
#    empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes
#    and dtypes. Do NOT perform any real computation or allocate device memory.
#
# 2. Register your meta function using `register_meta_if_necessary`, providing:
#    - The namespace (usually "_C" for custom ops)
#    - The operator name (as registered in C++)
#    - The Python meta function
#    - (Optional) The overload name, if your op has overloads
#
# 3. The registration utility checks whether a meta implementation already exists for
#    your op, and only registers if necessary. This avoids duplicate registrations.
#
# 4. Example meta implementations are provided below for rotary_embedding and
#    get_masked_input_and_mask.
#
# 5. When developing new custom ops, always provide a meta implementation so the op
#    supports tracing, export, and shape inference in PyTorch and vLLM, and can be
#    captured by `torch.compile` and aclgraph.
#
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors

lib = Library("_C", "IMPL")


def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
    if overload != "":
        op_name = op_name + "." + overload
    schema_to_find = ns + "::" + op_name
    meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
        "Meta")
    if schema_to_find in meta_impl_list:
        return
    lib.impl(op_name, fn, "Meta")


def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
                          key: torch.Tensor, head_size: int,
                          cos_sin_cache: torch.Tensor, is_neox: bool):

    num_tokens = positions.numel()
    query_hidden_size = query.numel() // num_tokens
    key_hidden_size = key.numel() // num_tokens
    num_heads = query_hidden_size // head_size
    num_kv_heads = key_hidden_size // head_size

    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
    return query_dst, key_dst


def get_masked_input_and_mask_meta(input: torch.Tensor,
                                   org_vocab_start_index: int,
                                   org_vocab_end_index: int,
                                   num_org_vocab_padding: int,
                                   added_vocab_start_index: int,
                                   added_vocab_end_index: int):

    masked_input = torch.empty_like(input)
    mask = torch.empty_like(input).to(torch.bool)

    return masked_input, mask


register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C", "get_masked_input_and_mask",
                           get_masked_input_and_mask_meta)
```
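Following the recipe above, wiring up a meta implementation for a future kernel stays mechanical. A sketch for a hypothetical `_C::swiglu` op whose output halves the last dimension (the op name and its shape contract are illustrative assumptions, not an existing vllm-ascend kernel):

```python
def swiglu_meta(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical contract: [..., 2 * d] -> [..., d]; empty output, no compute.
    return x.new_empty((*x.shape[:-1], x.shape[-1] // 2))


register_meta_if_necessary("_C", "swiglu", swiglu_meta)
```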

vllm_ascend/utils.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -214,8 +214,12 @@ def enable_custom_op():
     if _CUSTOM_OP_ENABLED is not None:
         return _CUSTOM_OP_ENABLED
     try:
+        # isort: off
         # register custom ops into torch_library here
         import vllm_ascend.vllm_ascend_C  # type: ignore # noqa: F401
+        # register the meta implementation for custom kernel if necessary
+        import vllm_ascend.meta_registration  # type: ignore # noqa: F401
+        # isort: on
         _CUSTOM_OP_ENABLED = True
     except ImportError:
         _CUSTOM_OP_ENABLED = False
```
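The `# isort: off` / `# isort: on` guard pins the import order: `vllm_ascend_C` must come first so the `_C` op schemas already exist when `meta_registration` attaches its Meta-key implementations to them. A sketch of the intended call site (`enable_custom_op` is the existing helper shown above; the surrounding usage is an assumption):

```python
from vllm_ascend.utils import enable_custom_op

if enable_custom_op():
    # torch.ops._C.* kernels and their meta implementations are registered,
    # so torch.compile can now trace models that call them.
    ...
```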
