
Commit fa18368

Support for kwargs in onnx plugs (#312)
* Fix minor bugs in onnx_plug
* Fix unittest and other tests
* Update script
* Fix documentation
1 parent 99de3c7 commit fa18368

File tree: 8 files changed, +420 −102 lines

_doc/technical/plot_gemm_or_matmul_add.py

Lines changed: 69 additions & 27 deletions
@@ -53,6 +53,14 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
             oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
             oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
             oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
+            oh.make_node("Cast", ["A"], ["Afloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Cast", ["B"], ["Bfloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Cast", ["X"], ["Xfloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Gemm", ["Afloat", "Xfloat"], ["gmmfloat"]),
+            oh.make_node("Add", ["gmmfloat", "Bfloat"], ["gemmaddfloat"]),
+            oh.make_node("Cast", ["gemmaddfloat"], ["CastGemmAddCast"], to=itype),
+            oh.make_node("Gemm", ["Afloat", "Xfloat", "Bfloat"], ["GemmOnlyfloat"]),
+            oh.make_node("Cast", ["GemmOnlyfloat"], ["CastGemmOnlyCast"], to=itype),
         ],
         "test",
         [
@@ -65,6 +73,8 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
             oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
             oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
             oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
+            oh.make_tensor_value_info("CastGemmAddCast", itype, ["a", "c"]),
+            oh.make_tensor_value_info("CastGemmOnlyCast", itype, ["a", "c"]),
         ],
     ),
     opset_imports=[oh.make_opsetid("", 22)],
@@ -85,7 +95,7 @@ def matrix_diff(tensors):
 dtype = np.float16
 model = make_model_gemm(itype)
 
-A = np.random.randn(512, 256).astype(dtype)
+A = np.random.randn(1280, 256).astype(dtype)
 X = np.random.randn(256, 256).astype(dtype)
 B = np.random.randn(256).astype(dtype)
 feeds = dict(A=A, X=X, B=B)
@@ -112,9 +122,9 @@ def matrix_diff(tensors):
 # %%
 # Let's try with CUDA and float32 if it is available.
 
-A = torch.randn((512, 512), dtype=torch.float32)
-X = torch.randn((512, 512), dtype=torch.float32)
-B = torch.randn((512), dtype=torch.float32)
+A = torch.randn((1280, 1280), dtype=torch.float32)
+X = torch.randn((1280, 1280), dtype=torch.float32)
+B = torch.randn((1280), dtype=torch.float32)
 
 for itype, dtype, device in [
     (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
@@ -144,8 +154,10 @@ def matrix_diff(tensors):
 # are similar to the others coefficients. What if we make them
 # a lot higher.
 
-B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
-labels = ["linear", *[o.name for o in model.graph.output], "a @ x + b"]
+A = A / A.max()
+X = X / X.max()
+B = (torch.arange(1280, dtype=torch.float32) + 1) / 1280 * 16
+labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"]
 all_results = {}
 
 for itype, dtype, device in [
@@ -187,28 +199,58 @@ def matrix_diff(tensors):
 # bias value vs discrepancies
 # ===========================
 #
-# Let's compare GemmOnly (so bias is included) and Gemm+Add.
-
-i, j = 1, -1
-labs = labels[i], labels[j]
-
-fig, ax = plt.subplots(len(all_results), 2, figsize=(8, 2.5 * len(results)))
-for pos, ((device, dtype), results) in enumerate(all_results.items()):
-    m1, m2 = results[i], results[j]
-    diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
-    print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
-    expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
-    ax[pos, 0].plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), ".")
-    ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}")
-
-    corr = matrix_diff(results)
-    ax[pos, 1].imshow(corr, cmap="Blues", vmin=0, vmax=corr.max())
-    # ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}')
-    ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45)
-    ax[pos, 1].set_yticks(range(len(labels)), labels)
-    ax[pos, 1].set_title(f"max={diff.max()}")
+# Let's compare torch linear with GemmOnly.
+
+
+def make_figure_axis(all_results, i, j):
+    labs = labels[i], labels[j]
+    fig, ax = plt.subplots(len(all_results), 2, figsize=(12, 4 * len(all_results)))
+    for pos, ((device, dtype), results) in enumerate(all_results.items()):
+        m1, m2 = results[i], results[j]
+        diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
+        print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
+        expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
+        ax[pos, 0].plot(
+            B.tolist(), (diff.detach().cpu() + torch.rand(1280) * expand).tolist(), "."
+        )
+        ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10)
+
+        corr = matrix_diff(results)
+        ax[pos, 1].imshow(corr, cmap="Wistia", vmin=0, vmax=corr.max())
+        # ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}')
+        ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45, ha="right", fontsize=10)
+        ax[pos, 1].set_yticks(range(len(labels)), labels, fontsize=10)
+        ax[pos, 1].set_title(f"max={diff.max():1.2g}", fontsize=10)
+        for _i in range(corr.shape[0]):
+            for _j in range(corr.shape[1]):
+                ax[pos, 1].text(
+                    _j,
+                    _i,
+                    f"{corr[_i, _j]:1.1g}",
+                    ha="center",
+                    va="center",
+                    color="black",
+                    fontsize=8,
+                )
+    fig.suptitle(
+        f"Left column: discrepancies {labs[0]} VS {labs[1]}\n"
+        f"Right column: max absolute error, across all configuration\n"
+        f"white is good, orange is not"
+    )
+    return fig, ax
+
+
+fig, ax = make_figure_axis(all_results, 0, 1)
 fig.tight_layout()
-fig.savefig("plot_gemm_or_matmul_add.png")
+fig.savefig("plot_gemm_or_matmul_add1.png")
+
+# %%
+# Let's compare with ``A @ X + B``.
+
+fig, ax = make_figure_axis(all_results, -1, 1)
+fig.tight_layout()
+fig.savefig("plot_gemm_or_matmul_add2.png")
+
 
 # %%
 # Discrepancies do not happen all the time but it is very likely to happen.
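
The added Cast/Gemm/Cast outputs give the script a float32 reference path inside the same graph, so a single inference run produces every variant to compare. A reduced, self-contained sketch of the same measurement (my own minimal model, not the script's make_model_gemm; assumes numpy, onnx, and an onnxruntime build recent enough to accept opset 22):

import numpy as np
import onnx
import onnx.helper as oh
import onnxruntime as ort

# Two numeric paths for the same computation A @ X + B in float16:
# Gemm with fused bias vs MatMul followed by Add.
itype = onnx.TensorProto.FLOAT16
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Gemm", ["A", "X", "B"], ["GemmOnly"]),
            oh.make_node("MatMul", ["A", "X"], ["mm"]),
            oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
        ],
        "cmp",
        [
            oh.make_tensor_value_info("A", itype, ["a", "b"]),
            oh.make_tensor_value_info("X", itype, ["b", "c"]),
            oh.make_tensor_value_info("B", itype, ["c"]),
        ],
        [
            oh.make_tensor_value_info("GemmOnly", itype, ["a", "c"]),
            oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 22)],
)
sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
A = np.random.randn(1280, 256).astype(np.float16)
X = np.random.randn(256, 256).astype(np.float16)
B = np.random.randn(256).astype(np.float16)
gemm, mmadd = sess.run(None, dict(A=A, X=X, B=B))
# Both paths compute A @ X + B; any nonzero gap comes from different
# accumulation orders and precisions inside the float16 kernels.
print(np.abs(gemm.astype(np.float32) - mmadd.astype(np.float32)).max())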

_unittests/ut_export/test_onnx_plug.py

Lines changed: 96 additions & 5 deletions
@@ -1,7 +1,8 @@
 import unittest
+import onnx
 import onnx.helper as oh
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch, hide_stdout, ignore_warnings
 from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
 from onnx_diagnostic.export.api import to_onnx
 
@@ -36,7 +37,9 @@ def make_function_proto():
         self.assertEqual(len(res.diffs), 1)
         self.assertEqual(res.diffs[0]["abs"], 0)
 
-    def test_onnx_plug_export(self):
+    @hide_stdout()
+    @ignore_warnings(FutureWarning)
+    def test_onnx_plug_export_nokwargs(self):
         def _test_customsub(x, y):
             return x - y
 
@@ -61,7 +64,7 @@ def forward(self, x):
 
         replacements = [
             EagerDirectReplacementWithOnnx(
-                _test_customsub, _test_customsub_shape, make_function_proto(), 2, 1
+                _test_customsub, _test_customsub_shape, make_function_proto(), 2, 1, verbose=1
             )
         ]
 
@@ -83,7 +86,95 @@ def forward(self, x):
                 onnx_plugs=replacements,
                 target_opset=22,
             )
-            self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,))
+            self.assert_onnx_disc(
+                "test_onnx_plug_export_nokwargs_custom", onx.model_proto, model, (x,)
+            )
+
+        if not has_torch("2.9"):
+            raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
+        with self.subTest(exporter="onnx-dynamo"):
+            onx = to_onnx(
+                model,
+                (x,),
+                dynamic_shapes=ds,
+                exporter="onnx-dynamo",
+                onnx_plugs=replacements,
+                target_opset=22,
+            )
+            self.assert_onnx_disc(
+                "test_onnx_plug_export_nokwargs_onnx_dynamo", onx.model_proto, model, (x,)
+            )
+
+    @unittest.skip("not ready yet")
+    @hide_stdout()
+    @ignore_warnings(FutureWarning)
+    def test_onnx_plug_export_kwargs(self):
+        def _test_customdiv(x, y, epsilon: float = 1e-5):
+            return x / (y + epsilon)
+
+        def _test_customdiv_shape(x, y, *args, **kwargs):
+            return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
+
+        def make_function_proto():
+            f = oh.make_function(
+                "onnx_plug",
+                "_test_customdiv",
+                ["x", "y"],
+                ["z"],
+                [
+                    oh.make_node("Constant", [], ["eps"]),
+                    oh.make_node("Add", ["y", "eps"], ["yeps"]),
+                    oh.make_node("Div", ["x", "yeps"], ["z"]),
+                ],
+                opset_imports=[oh.make_opsetid("", 22)],
+                attributes=["epsilon"],
+            )
+            att = onnx.AttributeProto()
+            att.name = "value_float"
+            att.ref_attr_name = "epsilon"
+            att.type = onnx.AttributeProto.FLOAT
+            f.node[0].attribute.append(att)
+            return f
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                y = x.sum(axis=1, keepdim=True)
+                d = torch.ops.onnx_plug._test_customdiv(x, y, epsilon=3.5)
+                return torch.abs(d)
+
+        replacements = [
+            EagerDirectReplacementWithOnnx(
+                _test_customdiv,
+                _test_customdiv_shape,
+                make_function_proto(),
+                2,
+                1,
+                kwargs=dict(epsilon=1e-5),
+                verbose=1,
+            )
+        ]
+
+        x = torch.randn((3, 4), dtype=torch.float32)
+        model = Model()
+        expected = model(x)
+        ds = ({0: "d1", 1: "d2"},)
+        ep = torch.export.export(model, (x,), dynamic_shapes=self.use_dyn_not_str(ds))
+        self.assertIn("torch.ops.onnx_plug._test_customdiv.default", str(ep))
+        got = ep.module()(x)
+        self.assertEqualArray(expected, got)
+
+        with self.subTest(exporter="custom"):
+            onx = to_onnx(
+                model,
+                (x,),
+                dynamic_shapes=ds,
+                exporter="custom",
+                onnx_plugs=replacements,
+                target_opset=22,
+            )
+            self.assert_onnx_disc(
+                "test_onnx_plug_export_kwargs_custom", onx.model_proto, model, (x,)
+            )
 
         if not has_torch("2.9"):
             raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
@@ -97,7 +188,7 @@ def forward(self, x):
                 target_opset=22,
             )
             self.assert_onnx_disc(
-                "test_onnx_plug_export_onnx_dynamo", onx.model_proto, model, (x,)
+                "test_onnx_plug_export_kwargs_onnx_dynamo", onx.model_proto, model, (x,)
            )
 
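
The skipped kwargs test above is the core of this change: the Python keyword epsilon becomes a FunctionProto attribute, and the Constant node inside the function body reads it through ref_attr_name. On the caller's side the value then travels as an ordinary node attribute; a minimal sketch of that call site (assumes only the onnx package; the function is the one defined in the test above):

import onnx.helper as oh

# A node invoking the onnx_plug._test_customdiv function: epsilon=3.5 is
# serialized as a plain FLOAT attribute on the call, and the Constant node
# inside the function resolves it via ref_attr_name="epsilon".
call = oh.make_node("_test_customdiv", ["x", "y"], ["z"], domain="onnx_plug", epsilon=3.5)
print(call)  # shows: attribute { name: "epsilon"  f: 3.5  type: FLOAT }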

_unittests/ut_tasks/try_export.py

Lines changed: 10 additions & 9 deletions
@@ -45,6 +45,9 @@ def test_imagetext2text_qwen_2_5_vl_instruct_visual(self):
         exporter = os.environ.get("EXPORTER", "custom")
 
         from transformers import AutoModel, AutoProcessor
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            PLUGS,
+        )
 
         model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
         # model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -82,11 +85,17 @@ def _config_reduction(config, task):
         processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
         print(f"-- processor={type(processor)}")
 
+        big_inputs = dict(
+            hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
+            grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
+        )
+        print("-- save inputs")
         inputs = dict(
            hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
            grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
        )
        print("-- save inputs")
+        torch.save(big_inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.big.pt"))
         torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt"))
 
         print(f"-- inputs: {self.string_type(inputs, with_shape=True)}")
@@ -115,15 +124,6 @@ def _config_reduction(config, task):
             verbose=1,
             stop_if_static=2,
         ):
-            if exporter == "onnx-dynamo":
-                # The exported program in ONNXProgram cannot be restored.
-                ep2 = torch.export.export(
-                    model.visual,
-                    (),
-                    kwargs=export_inputs,
-                    dynamic_shapes=self.use_dyn_not_str(dynamic_shapes),
-                )
-                torch.export.save(ep2, f"{fileep}.backup.pt2")
             to_onnx(
                 model.visual,
                 kwargs=export_inputs,
@@ -134,6 +134,7 @@ def _config_reduction(config, task):
                 save_ep=(fileep, 2**35),
                 target_opset=22,
                 optimize=True,
+                onnx_plugs=PLUGS,
             )
 
         pt2_files = [f"{fileep}.backup.pt2", f"{fileep}.ep.pt2", f"{fileep}.pt2"]
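
The shapes of the new big_inputs are constrained, not arbitrary: the visual encoder expects one hidden-state row per grid cell, i.e. t * h * w rows for grid_thw = [t, h, w] (1 * 98 * 146 = 14308, just as the smaller set uses 1 * 34 * 38 = 1292). A quick consistency check in plain torch:

import torch

# grid_thw stores [t, h, w] per image; the encoder consumes t*h*w tokens.
grid_thw = torch.tensor([[1, 98, 146]], dtype=torch.int64)
hidden_states = torch.rand((14308, 1176))
assert int(grid_thw.prod(dim=1).sum()) == hidden_states.shape[0]  # 14308 tokens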

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 14 additions & 9 deletions
@@ -384,6 +384,7 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self):
         )
         from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
             patched_Qwen2_5_VLVisionAttention,
+            PLUGS_Qwen25,
         )
 
         config = get_cached_configuration("Qwen/Qwen2.5-VL-7B-Instruct")
@@ -406,7 +407,7 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self):
             _is_torchdynamo_exporting()
         ), f"exporting is not set to true? {torch.compiler.is_exporting_flag}"
         got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
-        self.assertEqualArray(expected, got, atol=1e-5)
+        self.assertEqualArray(expected, got, atol=1e-2)
 
     class Model(patched_class):
         def forward(
@@ -456,16 +457,20 @@ def forward(
                 dynamic_shapes=ds,
                 exporter=exporter,
                 filename=filename,
+                onnx_plugs=PLUGS_Qwen25,
+                target_opset=22,
             )
             # exporter_kwargs={"report":True} if exporter != "custom" else {}
-            self.assert_onnx_disc(
-                f"test_patched_qwen2_5_vl_vision_attention_forward-{exporter}",
-                onnx.load(filename),
-                instance,
-                inputs,
-                atol=1e-3,
-                rtol=1,
-            )
+            if torch.cuda.is_available():
+                self.assert_onnx_disc(
+                    f"test_patched_qwen2_5_vl_vision_attention_forward-{exporter}",
+                    onnx.load(filename),
+                    instance,
+                    inputs,
+                    atol=1e-3,
+                    rtol=1,
+                    providers=["CUDAExecutionProvider"],
+                )
         self.clean_dump()
 
     @requires_transformers("4.99")
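
Gating the discrepancy check on torch.cuda.is_available() and pinning providers=["CUDAExecutionProvider"] presumably reflects that, with the relaxed atol above, the attention kernels are only expected to match on GPU. A hedged sketch of the equivalent provider selection with the public onnxruntime API (assert_onnx_disc's internals are assumed, not shown in this diff):

import onnxruntime as ort

# Prefer CUDA when the runtime exposes it; otherwise fall back to CPU
# rather than asserting tolerances the CPU kernels may not meet.
providers = (
    ["CUDAExecutionProvider"]
    if "CUDAExecutionProvider" in ort.get_available_providers()
    else ["CPUExecutionProvider"]
)
# sess = ort.InferenceSession("model.onnx", providers=providers)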

onnx_diagnostic/export/control_flow.py

Lines changed: 5 additions & 1 deletion
@@ -134,7 +134,11 @@ def make_custom_loop_for(
     assert body_outputs is not None, "body_outputs cannot be None"
     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
     sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
-    full_name = body_fn.__qualname__.replace("<locals>", "L").replace(".", "_")
+    full_name = (
+        body_fn.__qualname__.replace("<locals>", "L")
+        .replace("<lambda>", "l")
+        .replace(".", "_")
+    )
     name = f"loop_for_{full_name}_{srank}_{sred}"
     if name in _REGISTERED_SCHEMA:
         return name, _REGISTERED_SCHEMA[name][0]
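
This one-line change matters when the loop body is a lambda: its __qualname__ contains "<lambda>", which previously survived sanitization and leaked angle brackets into the registered operator name. A quick check of the new mapping in plain Python:

def outer():
    # A lambda used as a loop body; its qualified name embeds "<locals>"
    # and "<lambda>", neither of which is valid in a custom-op name.
    return lambda x: x

body_fn = outer()
full_name = (
    body_fn.__qualname__.replace("<locals>", "L")
    .replace("<lambda>", "l")
    .replace(".", "_")
)
print(body_fn.__qualname__)  # outer.<locals>.<lambda>
print(full_name)  # outer_L_l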
