Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ jobs:
run: python -m pip freeze

- name: qwen2.5_vl_instruct
run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 QWEN25ATTENTION=BIGMASK TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual
run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual

1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.3
+++++

* :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
* :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator
* :pr:`323`: drops torch 2.8 on CI
* :pr:`322`: support rerunning onnx kernels with torch intermediate results in side-by-side
Expand Down
69 changes: 64 additions & 5 deletions _unittests/ut_torch_export_patches/test_patch_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,9 @@ def test_qwen2_5_vl_vision_attention_iteration(self):

@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
@requires_cuda()
def test_plug_packed_multi_head_attention_qwen25(self):
def test_plug_packed_multi_head_attention_qwen25_packed(self):
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_versatile,
qwen_sdpa_attention_packed_versatile,
)

inputs = (
Expand Down Expand Up @@ -563,20 +563,79 @@ def test_plug_packed_multi_head_attention_qwen25(self):
).to("cuda"),
)

results = qwen_sdpa_attention_versatile.verify(
results = qwen_sdpa_attention_packed_versatile.verify(
*inputs, scaling=0.5, num_heads=16
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
self.assertLess(results.diffs[0]["abs"], 0.01)

results = qwen_sdpa_attention_packed_versatile.verify(
*inputs, scaling=0.11180339887498948, num_heads=16
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
self.assertLess(results.diffs[0]["abs"], 0.01)

@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_loopmha_versatile,
)

inputs = (
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.tensor(
[
0,
64,
128,
192,
256,
304,
368,
432,
496,
560,
608,
672,
736,
800,
864,
912,
976,
1040,
1104,
1168,
1216,
1232,
1248,
1264,
1280,
1292,
],
dtype=torch.int64,
),
)

results = qwen_sdpa_attention_loopmha_versatile.verify(
*inputs,
scaling=0.5,
num_heads=16,
dump_onnx_model=self.get_dump_file(
"test_plug_packed_multi_head_attention_qwen25.onnx"
"test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
),
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
self.assertEqual(len(results.eager_outputs), len(results.diffs))
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
self.assertLess(results.diffs[0]["abs"], 0.01)

results = qwen_sdpa_attention_versatile.verify(
results = qwen_sdpa_attention_loopmha_versatile.verify(
*inputs, scaling=0.11180339887498948, num_heads=16
)
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
Expand Down
13 changes: 11 additions & 2 deletions onnx_diagnostic/reference/ort_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
from onnx import (
AttributeProto,
Expand Down Expand Up @@ -388,7 +388,7 @@ def _make_model_proto(
onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
opsets = {d.domain: d.version for d in onx.opset_import}
add = {}
for node in nodes:
for node in self.enumerate_nodes(onx.graph.node):
if node.domain and node.domain not in opsets and node.domain not in add:
add[node.domain] = 1
onx.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()])
Expand All @@ -402,6 +402,15 @@ def _make_model_outputs(
) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]

def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
"Enumerates nodes recursively."
for node in nodes:
if node.op_type in {"Scan", "If", "Loop"}:
for att in node.attribute:
if att.type == AttributeProto.GRAPH:
yield from self.enumerate_nodes(att.g.node)
yield node

@classmethod
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ def LoopMHAAttention(
query_3d = op.Reshape(query_transposed, to_3d_shape)
value_3d = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
key_3d = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
cu_seqlens = op.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
num_patches = op.Size(cu_seqlens) - 1
seq_axis = op.Constant(value_ints=[1])
seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
attn_output = op.Slice(value_3d, [0], [0], seq_axis)
for i in range(num_patches):
i_1d = op.Reshape(i, [1])
for i_patch in range(num_patches):
i_1d = op.Reshape(i_patch, [1])
i_plus_1_1d = i_1d + 1
start = op.Gather(cu_seqlens, i_1d, axis=0)
end = op.Gather(cu_seqlens, i_plus_1_1d, axis=0)
Expand All @@ -62,6 +63,14 @@ def LoopMHAAttention(
attn_output_4d = op.Reshape(attn_output, output_shape)
return attn_output_4d

def _add_com_microsoft_opset(function_proto):
opsets = {d.domain: d.version for d in function_proto.opset_import}
if "com.microsoft" not in opsets:
d = function_proto.opset_import.add()
d.domain = "com.microsoft"
d.version = 1
return function_proto

@onnxscript.script(opset=onnx_plugs_op)
def PackedAttention(
query,
Expand Down Expand Up @@ -143,20 +152,35 @@ def qwen_sdpa_attention(
return attn_output

# not ideal
qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
qwen_sdpa_attention,
lambda qs, *args, **kwargs: torch.empty(
(qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
dtype=qs.dtype,
device=qs.device,
),
_add_com_microsoft_opset(PackedAttention.to_function_proto()),
n_inputs=4,
n_outputs=1,
kwargs=dict(scaling=0.11180339887498948, num_heads=16),
name="qwen_sdpa_attention_packed",
)
PLUGS.append(qwen_sdpa_attention_packed_versatile)

qwen_sdpa_attention_loopmha_versatile = EagerDirectReplacementWithOnnx(
qwen_sdpa_attention,
lambda qs, *args, **kwargs: torch.empty(
(qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
dtype=qs.dtype,
device=qs.device,
),
PackedAttention.to_function_proto(),
_add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
n_inputs=4,
n_outputs=1,
kwargs=dict(scaling=0.11180339887498948, num_heads=16),
name="qwen_sdpa_attention",
name="qwen_sdpa_attention_loopmha",
)
PLUGS.append(qwen_sdpa_attention_versatile)
PLUGS.append(qwen_sdpa_attention_loopmha_versatile)

class patched_Qwen2_5_VLForConditionalGeneration:
_PATCHES_ = ["prepare_inputs_for_generation"]
Expand Down Expand Up @@ -496,8 +520,8 @@ def forward(
or attention_interface is patched_sdpa_attention_forward
)
attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
if is_sdpa and attention_strategy == "PACKED":
attn_output = qwen_sdpa_attention_versatile(
if is_sdpa and attention_strategy in "PACKED":
attn_output = qwen_sdpa_attention_packed_versatile(
query_states,
key_states,
value_states,
Expand Down Expand Up @@ -530,37 +554,37 @@ def forward(
version=1,
)
elif is_sdpa and attention_strategy == "LOOPMHA":
attn_output = qwen_sdpa_attention_loopmha_versatile(
query_states,
key_states,
value_states,
cu_seqlens,
self.scaling,
self.num_heads,
)

def _iteration(start_end, query_states, key_states, value_states):
return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
self,
start_end,
query_states,
key_states,
value_states,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
)

starts = cu_seqlens[:-1]
ends = cu_seqlens[1:]
# cu_seqlens = [0, 10, 14, 27]
# starts: [0, 10, 14]
# ends: [10, 14, 17]
# starts_ends: [[0, 10], [10, 14], [14, 27]]
starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
attn_outputs = [
_iteration(start_end, query_states, key_states, value_states)
for start_end in starts_ends
]
# attn_outputs = torch._higher_order_ops.while_loop(
# attn_outputs = torch.ops.higher_order.while_loop(
# (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
# _iteration,
# (torch.tensor(0),
# starts_ends, query_states, key_states, value_states), tuple(),
# )
attn_output = torch.cat(attn_outputs, dim=1)
# to rewrite later with a for loop
# def _iteration(start_end, query_states, key_states, value_states):
# return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
# self,
# start_end,
# query_states,
# key_states,
# value_states,
# scaling=self.scaling,
# dropout=0.0 if not self.training else self.attention_dropout,
# )

# starts = cu_seqlens[:-1]
# ends = cu_seqlens[1:]
# torch._check(starts.shape[0] > 0)
# torch._check(ends.shape[0] > 0)
# starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
# attn_outputs = [
# _iteration(start_end, query_states, key_states, value_states)
# for start_end in starts_ends
# ]
# attn_output = torch.cat(attn_outputs, dim=1)
elif is_sdpa and attention_strategy == "BIGMASK":
# make square mask
indices = torch.arange(
Expand Down
Loading