
Commit 0d3bd28

Diagonal Mask (#295)
* diagonal mask for attention
* doc
* fix
* fix
1 parent 8215a72 commit 0d3bd28

File tree

4 files changed: +125 -28 lines changed


CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ Change Logs
 0.8.2
 +++++
 
-* :pr:`292`, :pr:`293`, :pr:`294`: new patches for Qwen models
+* :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models
 
 0.8.1
 +++++

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 69 additions & 3 deletions
@@ -4,7 +4,13 @@
 import transformers.integrations.sdpa_attention as sdpa_attention
 import onnx
 import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
-from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, ignore_warnings
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    requires_transformers,
+    requires_torch,
+    ignore_warnings,
+    has_onnxscript,
+)
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting
 from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
 from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration
@@ -398,9 +404,69 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self):
             _is_torchdynamo_exporting()
         ), f"exporting is not set to true? {torch.compiler.is_exporting_flag}"
         got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
-        self.assertEqualArray(expected, got)
+        self.assertEqualArray(expected, got, atol=1e-5)
+
+        class Model(patched_class):
+            def forward(
+                self,
+                hidden_states: torch.Tensor,
+                cu_seqlens: torch.Tensor,
+                rotary_pos_emb: torch.Tensor | None = None,
+                position_embeddings1: torch.Tensor | None = None,
+                position_embeddings2: torch.Tensor | None = None,
+                **kwargs,
+            ) -> torch.Tensor:
+                return patched_Qwen2_5_VLVisionAttention.forward(
+                    self,
+                    hidden_states,
+                    cu_seqlens,
+                    rotary_pos_emb=rotary_pos_emb,
+                    position_embeddings=(position_embeddings1, position_embeddings2),
+                    **kwargs,
+                )
+
+        instance = Model(config.vision_config)
+        instance.eval()
+
+        ds = dict(
+            hidden_states={0: "d1"},
+            cu_seqlens={0: "d3"},
+            position_embeddings1={0: "d1"},
+            position_embeddings2={0: "d1"},
+        )
+        inputs.update(
+            dict(
+                position_embeddings1=inputs["position_embeddings"][0],
+                position_embeddings2=inputs["position_embeddings"][1],
+            )
+        )
+        del inputs["position_embeddings"]
+        for exporter in ("custom", "onnx-dynamo"):
+            # onnx-dynamo needs OpOverload(op='aten.sym_storage_offset') (transformers>=5.0?)
+            if exporter == "onnx-dynamo" and not has_onnxscript("0.5.7"):
+                raise unittest.SkipTest("needs onnxscript>=0.5.7")
+            filename = self.get_dump_file(
+                f"test_patched_qwen2_5_vl_vision_attention_forward.{exporter}.onnx"
+            )
+            to_onnx(
+                instance,
+                kwargs=inputs,
+                dynamic_shapes=ds,
+                exporter=exporter,
+                filename=filename,
+            )
+            # exporter_kwargs={"report":True} if exporter != "custom" else {}
+            self.assert_onnx_disc(
+                f"test_patched_qwen2_5_vl_vision_attention_forward-{exporter}",
+                onnx.load(filename),
+                instance,
+                inputs,
+                atol=1e-3,
+                rtol=1,
+            )
 
-    @requires_transformers("5.0")
+    @requires_transformers("4.99")
+    @requires_torch("2.9.99")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_qwen2_5_vl_vision_attention_iteration(self):
         from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (

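The test above works around the fact that dynamic shapes are easiest to attach to flat tensor arguments: the `position_embeddings` tuple is split into `position_embeddings1` and `position_embeddings2`, and a thin wrapper rebuilds the tuple before calling the patched forward. Below is a minimal, self-contained sketch of that wrapping pattern with `torch.export`; the `Inner` and `FlatWrapper` modules, the shapes, and the dimension name `d1` are made up for illustration and are not part of the commit.

import torch


class Inner(torch.nn.Module):
    # Stand-in for a forward that expects a tuple argument.
    def forward(self, hidden_states, position_embeddings):
        cos, sin = position_embeddings
        return hidden_states * cos + sin


class FlatWrapper(Inner):
    # Rebuild the tuple from two flat tensors so each one can carry a dynamic dim.
    def forward(self, hidden_states, position_embeddings1, position_embeddings2):
        return Inner.forward(
            self, hidden_states, (position_embeddings1, position_embeddings2)
        )


x = torch.randn(6, 4)
cos, sin = torch.randn(6, 4), torch.randn(6, 4)
d1 = torch.export.Dim("d1")
ds = {
    "hidden_states": {0: d1},
    "position_embeddings1": {0: d1},
    "position_embeddings2": {0: d1},
}
ep = torch.export.export(FlatWrapper().eval(), (x, cos, sin), dynamic_shapes=ds)
print(ep)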
onnx_diagnostic/export/api.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ def to_onnx(
         )
         ort_fusions.optimize_for_ort(epo.model)
         if filename:
-            epo.save(filename)
+            epo.save(filename, external_data=True)
         return epo
 
     if exporter == "modelbuilder":

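The `external_data=True` change presumably makes the saved ONNX file store its initializers in a side file, which is how models larger than the 2 GB protobuf limit are written. For reference, a minimal sketch of the same effect with the plain `onnx` API; the file names are placeholders, not from the commit.

import onnx

# Rewrite an existing model so its weights live in a separate data file.
model = onnx.load("model.onnx")  # placeholder path
onnx.save_model(
    model,
    "model.external.onnx",
    save_as_external_data=True,    # move initializers out of the protobuf
    all_tensors_to_one_file=True,  # one side file instead of one file per tensor
    location="model.external.onnx.data",
    size_threshold=1024,           # only tensors of at least 1 KiB are externalized
)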
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 54 additions & 23 deletions
@@ -2012,6 +2012,8 @@ def forward(
 if patch_qwen2_5:
     import torch.nn.functional as F
 
+    use_loop_for_attention_in_qwen_2_5 = False
+
     class patched_Qwen2_5_VLForConditionalGeneration:
         _PATCHES_ = ["prepare_inputs_for_generation"]
         _PATCHED_CLASS_ = (
@@ -2392,36 +2394,65 @@ def forward(
             ):
                 attention_interface = patched_sdpa_attention_forward
 
-                def _iteration(start_end, query_states, key_states, value_states):
-                    return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
+                if use_loop_for_attention_in_qwen_2_5:
+
+                    def _iteration(start_end, query_states, key_states, value_states):
+                        return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
+                            self,
+                            start_end,
+                            query_states,
+                            key_states,
+                            value_states,
+                            scaling=self.scaling,
+                            dropout=0.0 if not self.training else self.attention_dropout,
+                        )
+
+                    starts = cu_seqlens[:-1]
+                    ends = cu_seqlens[1:]
+                    # cu_seqlens = [0, 10, 14, 27]
+                    # starts: [0, 10, 14]
+                    # ends: [10, 14, 27]
+                    # starts_ends: [[0, 10], [10, 14], [14, 27]]
+                    starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
+                    attn_outputs = [
+                        _iteration(start_end, query_states, key_states, value_states)
+                        for start_end in starts_ends
+                    ]
+                    # attn_outputs = torch._higher_order_ops.while_loop(
+                    # attn_outputs = torch.ops.higher_order.while_loop(
+                    #     (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
+                    #     _iteration,
+                    #     (torch.tensor(0),
+                    #      starts_ends, query_states, key_states, value_states), tuple(),
+                    # )
+                    attn_output = torch.cat(attn_outputs, dim=1)
+                else:
+                    # make square mask
+                    indices = torch.arange(
+                        cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+                    )
+                    dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                        cu_seqlens.dtype
+                    )
+                    dot = dot.sum(dim=0)
+                    mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+                    bool_mask = mask == 0
+                    bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+                    torch._check(bool_mask.shape[2] == key_states.shape[2])
+                    torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+                    attn_output, _ = attention_interface(
                         self,
-                        start_end,
                         query_states,
                         key_states,
                         value_states,
+                        attention_mask=bool_mask,
                         scaling=self.scaling,
                         dropout=0.0 if not self.training else self.attention_dropout,
+                        is_causal=False,
+                        **kwargs,
                     )
-
-                starts = cu_seqlens[:-1]
-                ends = cu_seqlens[1:]
-                # cu_seqlens = [0, 10, 14, 27]
-                # starts: [0, 10, 14]
-                # ends: [10, 14, 17]
-                # starts_ends: [[0, 10], [10, 14], [14, 27]]
-                starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
-                attn_outputs = [
-                    _iteration(start_end, query_states, key_states, value_states)
-                    for start_end in starts_ends
-                ]
-                # attn_outputs = torch._higher_order_ops.while_loop(
-                # attn_outputs = torch.ops.higher_order.while_loop(
-                #     (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
-                #     _iteration,
-                #     (torch.tensor(0),
-                #      starts_ends, query_states, key_states, value_states), tuple(),
-                # )
-                attn_output = torch.cat(attn_outputs, dim=1)
             else:
                 # Other implementations: Process each chunk separately
                 lengths = cu_seqlens[1:] - cu_seqlens[:-1]

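The new branch replaces the per-segment Python loop with a single attention call over a block-diagonal ("diagonal") boolean mask derived from `cu_seqlens`: two positions may attend to each other exactly when they fall between the same pair of cumulative boundaries. Below is a minimal, self-contained sketch of that equivalence; the `(batch, heads, seq, head_dim)` layout, the sizes, and the variable names are illustrative only, not the model's actual shapes.

import torch
import torch.nn.functional as F

cu_seqlens = torch.tensor([0, 10, 14, 27])  # cumulative segment boundaries
q = torch.randn(1, 4, 27, 8)                # (batch, heads, seq, head_dim)
k = torch.randn(1, 4, 27, 8)
v = torch.randn(1, 4, 27, 8)

# Block-diagonal mask, built the same way as in the patched forward:
# dot[j] counts how many boundaries are <= j, i.e. the segment id of position j,
# and positions i and j may attend to each other iff their segment ids are equal.
indices = torch.arange(cu_seqlens.max(), dtype=cu_seqlens.dtype)
dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(cu_seqlens.dtype).sum(dim=0)
bool_mask = ((dot.unsqueeze(1) - dot.unsqueeze(0)) == 0).unsqueeze(0).unsqueeze(0)

masked = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask, is_causal=False)

# Reference: the loop branch, one attention call per [start, end) segment.
chunks = [
    F.scaled_dot_product_attention(q[:, :, s:e], k[:, :, s:e], v[:, :, s:e])
    for s, e in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())
]
looped = torch.cat(chunks, dim=2)

assert torch.allclose(masked, looped, atol=1e-5)

Because a query in one segment only sees keys of its own segment under the mask, the softmax runs over exactly the same keys as in the per-chunk loop, which is why the two branches agree up to floating-point error while avoiding a data-dependent loop at export time.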