Skip to content

Commit c371e38

Browse files
authored
update rewriting for qwen (#293)
* update rewriting * fix * changes
1 parent 34e801c commit c371e38

File tree

4 files changed

+26
-59
lines changed

4 files changed

+26
-59
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.2
55
+++++
66

7+
* :pr:`293`: second series of patches
78
* :pr:`292`: new patches for Qwen models
89

910
0.8.1

_unittests/ut_tasks/try_export.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def _config_reduction(config, task):
6464
print(f"-- processor={type(processor)}")
6565

6666
inputs = dict(
67-
hidden_states=torch.rand((14308, 1176), dtype=torch.float32),
68-
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64),
67+
hidden_states=torch.rand((1292, 1176), dtype=torch.float32),
68+
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64),
6969
)
7070

7171
print(f"-- inputs: {self.string_type(inputs, with_shape=True)}")
@@ -89,7 +89,7 @@ def _config_reduction(config, task):
8989
export_inputs = inputs
9090
print()
9191
with torch_export_patches(
92-
patch_torch=True,
92+
patch_torch=False,
9393
patch_sympy=False,
9494
patch_transformers=True,
9595
verbose=1,

_unittests/ut_tasks/try_tasks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,8 @@ def config_reduction(config, task):
10441044
"content": [
10451045
{
10461046
"type": "image",
1047-
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
1047+
# "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
1048+
"image": "https://github.com/sdpython/teachpyx/blob/main/_doc/practice/tds-base/int.png?raw=true",
10481049
},
10491050
{"type": "text", "text": "Describe this image."},
10501051
],

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 20 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2255,52 +2255,6 @@ def forward(
22552255
hidden_states = hidden_states[reverse_indices, :]
22562256
return hidden_states
22572257

2258-
@torch.library.custom_op("custom::qwen25_attention", mutates_args={})
2259-
def qwen25_attention(
2260-
query_states: torch.Tensor,
2261-
key_states: torch.Tensor,
2262-
value_states: torch.Tensor,
2263-
cu_seqlens: torch.Tensor,
2264-
_cu_seqlens: torch.Tensor,
2265-
max_seqlen: torch.Tensor,
2266-
_max_seqlen: torch.Tensor,
2267-
scale: torch.Tensor,
2268-
) -> torch.Tensor:
2269-
return torch.empty(
2270-
key_states.shape[0],
2271-
value_states.shape[1],
2272-
max_seqlen,
2273-
value_states.shape[-1],
2274-
dtype=query_states.dtype,
2275-
device=query_states.device,
2276-
)
2277-
2278-
def make_undefined_dimension(i: int) -> torch.SymInt:
2279-
t = torch.ones((i * 2,))
2280-
t[:i] = 0
2281-
res = torch.nonzero(t).shape[0]
2282-
return res
2283-
2284-
@qwen25_attention.register_fake
2285-
def qwen25_attention_shape(
2286-
query_states,
2287-
key_states,
2288-
value_states,
2289-
cu_seqlens,
2290-
_cu_seqlens,
2291-
max_seqlen,
2292-
_max_seqlen,
2293-
scale,
2294-
):
2295-
return torch.empty(
2296-
key_states.shape[0],
2297-
value_states.shape[1],
2298-
max_seqlen, # make_undefined_dimension(max_seqlen), new dimension does not work
2299-
value_states.shape[-1],
2300-
dtype=query_states.dtype,
2301-
device=query_states.device,
2302-
)
2303-
23042258
class patched_Qwen2_5_VLVisionAttention:
23052259
_PATCHES_ = ["forward"]
23062260
_PATCHED_CLASS_ = (
@@ -2350,15 +2304,26 @@ def forward(
23502304
or torch.compiler.is_exporting()
23512305
):
23522306
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
2353-
attn_output = torch.ops.custom.qwen25_attention(
2354-
query_states,
2355-
key_states,
2356-
value_states,
2357-
cu_seqlens,
2358-
cu_seqlens,
2359-
max_seqlen,
2360-
max_seqlen,
2361-
torch.tensor(self.scaling, dtype=torch.float32),
2307+
attn_output = torch.onnx.ops.symbolic(
2308+
"custom::qwen25_attention",
2309+
(
2310+
query_states,
2311+
key_states,
2312+
value_states,
2313+
cu_seqlens,
2314+
cu_seqlens,
2315+
max_seqlen,
2316+
max_seqlen,
2317+
torch.tensor(self.scaling, dtype=torch.float32),
2318+
),
2319+
dtype=query_states.dtype,
2320+
shape=(
2321+
key_states.shape[0],
2322+
value_states.shape[1],
2323+
max_seqlen,
2324+
value_states.shape[-1],
2325+
),
2326+
version=1,
23622327
)
23632328
elif self.config._attn_implementation == "flash_attention_2":
23642329
# Flash Attention 2: Use cu_seqlens for variable length attention

0 commit comments

Comments
 (0)