Skip to content

Commit 2eab178

Browse files
committed
fix
1 parent 72c6999 commit 2eab178

File tree

5 files changed

+125
-29
lines changed

5 files changed

+125
-29
lines changed

_unittests/ut_export/test_api.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
import os
12
import unittest
23
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
4+
from onnx_diagnostic.ext_test_case import (
5+
ExtTestCase,
6+
hide_stdout,
7+
has_transformers,
8+
ignore_warnings,
9+
)
410
from onnx_diagnostic.helpers import max_diff
511
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
612
from onnx_diagnostic.helpers.rt_helper import make_feeds
@@ -36,6 +42,7 @@ def forward(self, x, y):
3642
)
3743

3844
@hide_stdout()
45+
@ignore_warnings(FutureWarning)
3946
def test_tiny_llm_to_onnx(self):
4047
import onnxruntime
4148

@@ -68,6 +75,8 @@ def test_tiny_llm_to_onnx(self):
6875
filename=filename,
6976
)
7077
for exporter, filename in filenames.items():
78+
if not os.path.exists(filename):
79+
continue
7180
with self.subTest(exporter=f"validate-{exporter}"):
7281
sess = onnxruntime.InferenceSession(
7382
filename, providers=["CPUExecutionProvider"]
@@ -90,6 +99,8 @@ def test_tiny_llm_to_onnx(self):
9099

91100
expected = model(**torch_deepcopy(problem))
92101
for exporter, filename in filenames.items():
102+
if not os.path.exists(filename):
103+
continue
93104
with self.subTest(exporter=f"full-mask-{exporter}"):
94105
sess = onnxruntime.InferenceSession(
95106
filename, providers=["CPUExecutionProvider"]

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import transformers
44
import transformers.integrations.sdpa_attention as sdpa_attention
5+
import onnx
56
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
67
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, ignore_warnings
78
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting
@@ -387,13 +388,54 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self):
387388
expected = instance.forward(**inputs)
388389
got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
389390
self.assertEqualArray(expected, got)
390-
if 1: # with torch_export_patches(patch_transformers=False, patch_torch=True):
391-
with fake_torchdynamo_exporting():
392-
assert (
393-
_is_torchdynamo_exporting()
394-
), f"exporting is not set to true? {torch.compiler.is_exporting_flag}"
395-
got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
396-
self.assertEqualArray(expected, got)
391+
with fake_torchdynamo_exporting():
392+
assert (
393+
_is_torchdynamo_exporting()
394+
), f"exporting is not set to true? {torch.compiler.is_exporting_flag}"
395+
got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
396+
self.assertEqualArray(expected, got)
397+
398+
@requires_transformers("4.55")
399+
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
400+
def test_qwen2_5_vl_vision_attention_iteration(self):
401+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
402+
patched_Qwen2_5_VLVisionAttentionOneIteration,
403+
)
404+
405+
model = patched_Qwen2_5_VLVisionAttentionOneIteration()
406+
inputs = (
407+
torch.tensor([736, 800], dtype=torch.int64),
408+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
409+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
410+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
411+
)
412+
ds = (
413+
{},
414+
{0: "batch", 1: "length", 2: "dim"},
415+
{0: "batch", 1: "length", 2: "dim"},
416+
{0: "batch", 1: "length", 2: "dim"},
417+
)
418+
for exporter in ("custom", "onnx-dynamo"):
419+
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset') (transformers>=5.0?)
420+
filename = self.get_dump_file(
421+
f"test_qwen2_5_vl_vision_attention_iteration.{exporter}.onnx"
422+
)
423+
to_onnx(
424+
model,
425+
inputs,
426+
dynamic_shapes=ds,
427+
exporter=exporter,
428+
filename=filename,
429+
exporter_kwargs={"report": True} if exporter == "onnx-dynamo" else {},
430+
)
431+
self.assert_onnx_disc(
432+
f"test_qwen2_5_vl_vision_attention_iteration-{exporter}",
433+
onnx.load(filename),
434+
model,
435+
inputs,
436+
atol=1e-3,
437+
rtol=1,
438+
)
397439

398440

399441
if __name__ == "__main__":

onnx_diagnostic/ext_test_case.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,13 +1214,14 @@ def assert_onnx_disc(
12141214
from .helpers.ort_session import InferenceSessionForTorch
12151215

12161216
kws = dict(with_shape=True, with_min_max=verbose > 1)
1217-
if verbose:
1218-
vname = test_name or "assert_onnx_disc"
1217+
vname = test_name or "assert_onnx_disc"
12191218
if test_name:
12201219
name = f"{test_name}.onnx"
1221-
print(f"[{vname}] save the onnx model into {name!r}")
1220+
if verbose:
1221+
print(f"[{vname}] save the onnx model into {name!r}")
12221222
name = self.dump_onnx(name, proto)
1223-
print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
1223+
if verbose:
1224+
print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
12241225
if verbose:
12251226
print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
12261227
if use_ort:

onnx_diagnostic/helpers/log_helper.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -901,13 +901,19 @@ def view(
901901
else g.groupby([*key_index, *key_columns], dropna=False).sum()
902902
)
903903
not_unique = r[r["count"] > 1]
904+
if not_unique.shape[0] > 0 and os.environ.get("DUPLICATE", ""):
905+
filename = os.environ.get("DUPLICATE")
906+
subset = data.set_index([*key_index, *key_columns]).merge(
907+
not_unique.head(), left_index=True, right_index=True
908+
)
909+
subset.to_excel(filename)
904910
assert not_unique.shape[0] == 0, (
905911
f"view_def.name={view_def.name!r}, "
906912
f"unable to run the pivot with index={sorted(key_index)}, "
907913
f"key={sorted(key_columns)}, key_agg={key_agg}, values={sorted(values)}, "
908914
f"columns={sorted(data.columns)}, ignored={view_def.ignore_columns}, "
909-
f"not unique={set(data.columns) - unique}"
910-
f"\n--\n{not_unique.head(10)}"
915+
f"not unique={set(data.columns) - unique}, set DUPLICATE=<filename> "
916+
f"to store the duplicates in a excel file\n--\n{not_unique.head(10)}"
911917
)
912918

913919
# pivot
@@ -1000,8 +1006,12 @@ def _fix_aggregation_change(
10001006
keys = set(self.keys_time) - {columns_to_fix}
10011007
select = data[self.keys_time]
10021008
select_agg = select.groupby(list(keys)).count()
1009+
if select_agg.shape[0] == 0:
1010+
# nothing to fix
1011+
return data
10031012
assert select_agg[columns_to_fix].max() <= 1, (
1004-
f"Column {columns_to_fix!r} has two distinct values at least for one date\n"
1013+
f"Column {columns_to_fix!r} has two distinct values at least for one date, "
1014+
f"max={select_agg[columns_to_fix].max()}\n"
10051015
f"{select_agg[select_agg[columns_to_fix] > 1]}"
10061016
)
10071017

@@ -1038,6 +1048,16 @@ def _fix_aggregation_change(
10381048
f"data.columns.equals(res.columns)={data.columns.equals(res.columns)}, "
10391049
f"data.index.equals(res.columns)={data.index.equals(res.columns)}, "
10401050
)
1051+
select = res[self.keys_time]
1052+
select_agg = select.groupby(list(keys)).count()
1053+
if select_agg.shape[0] == 0:
1054+
# nothing to fix
1055+
return data
1056+
assert select_agg[columns_to_fix].max() <= 1, (
1057+
f"Column {columns_to_fix!r} has two distinct values at least for one date, "
1058+
f"max={select_agg[columns_to_fix].max()}\n"
1059+
f"{select_agg[select_agg[columns_to_fix] > 1]}"
1060+
)
10411061
return res
10421062

10431063
def _dropna(
@@ -1977,7 +1997,8 @@ def make_view_def(self, name: str) -> Optional[CubeViewDef]:
19771997
* **cmd:** command lines
19781998
* **raw-short:** raw data without all the unused columns
19791999
"""
1980-
fix_aggregation_change = ["model_speedup_input_set", "model_test_with"]
2000+
# This does not work.
2001+
fix_aggregation_change = [] # "model_speedup_input_set", "model_test_with"]
19812002
fs = ["suite", "model_suite", "task", "model_name", "model_task"]
19822003
index_cols = self._filter_column(fs, self.keys_time)
19832004
assert index_cols, (

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2265,6 +2265,34 @@ def forward(
22652265
hidden_states = hidden_states[reverse_indices, :]
22662266
return hidden_states
22672267

2268+
class patched_Qwen2_5_VLVisionAttentionOneIteration(torch.nn.Module):
2269+
def forward(
2270+
self,
2271+
start_end,
2272+
query_states,
2273+
key_states,
2274+
value_states,
2275+
scaling: float = 1.0,
2276+
dropout: float = 0.0,
2277+
**kwargs,
2278+
):
2279+
a = start_end[0].item()
2280+
b = start_end[1].item()
2281+
q = query_states[:, :, a:b, :]
2282+
k = key_states[:, :, a:b, :]
2283+
v = value_states[:, :, a:b, :]
2284+
return patched_sdpa_attention_forward(
2285+
self,
2286+
q,
2287+
k,
2288+
v,
2289+
attention_mask=None,
2290+
scaling=scaling,
2291+
dropout=dropout,
2292+
is_causal=False,
2293+
**kwargs,
2294+
)[0]
2295+
22682296
class patched_Qwen2_5_VLVisionAttention:
22692297
_PATCHES_ = ["forward"]
22702298
_PATCHED_CLASS_ = (
@@ -2361,22 +2389,15 @@ def forward(
23612389
attention_interface = patched_sdpa_attention_forward
23622390

23632391
def _iteration(start_end, query_states, key_states, value_states):
2364-
a = start_end[0]
2365-
b = start_end[1]
2366-
q = query_states[:, :, a:b, :]
2367-
k = key_states[:, :, a:b, :]
2368-
v = value_states[:, :, a:b, :]
2369-
return attention_interface(
2392+
return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
23702393
self,
2371-
q,
2372-
k,
2373-
v,
2374-
attention_mask=None,
2394+
start_end,
2395+
query_states,
2396+
key_states,
2397+
value_states,
23752398
scaling=self.scaling,
23762399
dropout=0.0 if not self.training else self.attention_dropout,
2377-
is_causal=False,
2378-
**kwargs,
2379-
)[0]
2400+
)
23802401

23812402
starts = cu_seqlens[:-1]
23822403
ends = cu_seqlens[1:]

0 commit comments

Comments
 (0)