96 changes: 69 additions & 27 deletions _doc/technical/plot_gemm_or_matmul_add.py
@@ -53,6 +53,14 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
oh.make_node("Cast", ["A"], ["Afloat"], to=onnx.TensorProto.FLOAT),
oh.make_node("Cast", ["B"], ["Bfloat"], to=onnx.TensorProto.FLOAT),
oh.make_node("Cast", ["X"], ["Xfloat"], to=onnx.TensorProto.FLOAT),
oh.make_node("Gemm", ["Afloat", "Xfloat"], ["gmmfloat"]),
oh.make_node("Add", ["gmmfloat", "Bfloat"], ["gemmaddfloat"]),
oh.make_node("Cast", ["gemmaddfloat"], ["CastGemmAddCast"], to=itype),
oh.make_node("Gemm", ["Afloat", "Xfloat", "Bfloat"], ["GemmOnlyfloat"]),
oh.make_node("Cast", ["GemmOnlyfloat"], ["CastGemmOnlyCast"], to=itype),
],
"test",
[
@@ -65,6 +73,8 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
oh.make_tensor_value_info("CastGemmAddCast", itype, ["a", "c"]),
oh.make_tensor_value_info("CastGemmOnlyCast", itype, ["a", "c"]),
],
),
opset_imports=[oh.make_opsetid("", 22)],
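
(Context, not part of the diff: the new Cast/Gemm/Cast outputs compute the same product in float32 and cast the result back, which isolates the error due to float16 accumulation. A minimal standalone sketch of that effect; how large the gap is depends on how the backend accumulates.)

import numpy as np

A = np.random.randn(1280, 256).astype(np.float16)
X = np.random.randn(256, 256).astype(np.float16)

r16 = (A @ X).astype(np.float32)  # product computed on float16 inputs
r32 = (
    A.astype(np.float32) @ X.astype(np.float32)
).astype(np.float16).astype(np.float32)  # float32 path, cast back to float16

print(np.abs(r16 - r32).max())  # usually non-zero: the two rounding paths differ
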
@@ -85,7 +95,7 @@ def matrix_diff(tensors):
dtype = np.float16
model = make_model_gemm(itype)

A = np.random.randn(512, 256).astype(dtype)
A = np.random.randn(1280, 256).astype(dtype)
X = np.random.randn(256, 256).astype(dtype)
B = np.random.randn(256).astype(dtype)
feeds = dict(A=A, X=X, B=B)
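
(Context, not part of the diff: the elided code presumably evaluates every output of the model with onnxruntime, roughly as in this sketch.)

import onnxruntime as ort

sess = ort.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
results = sess.run(None, feeds)  # one array per declared output
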
@@ -112,9 +122,9 @@ def matrix_diff(tensors):
# %%
# Let's try with CUDA and float32 when CUDA is available.

A = torch.randn((512, 512), dtype=torch.float32)
X = torch.randn((512, 512), dtype=torch.float32)
B = torch.randn((512), dtype=torch.float32)
A = torch.randn((1280, 1280), dtype=torch.float32)
X = torch.randn((1280, 1280), dtype=torch.float32)
B = torch.randn((1280), dtype=torch.float32)

for itype, dtype, device in [
(onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
@@ -144,8 +154,10 @@ def matrix_diff(tensors):
# are similar to the other coefficients. What if we make them
# a lot higher?

B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
labels = ["linear", *[o.name for o in model.graph.output], "a @ x + b"]
A = A / A.max()
X = X / X.max()
B = (torch.arange(1280, dtype=torch.float32) + 1) / 1280 * 16
labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"]
all_results = {}

for itype, dtype, device in [
@@ -187,28 +199,58 @@ def matrix_diff(tensors):
# bias value vs discrepancies
# ===========================
#
# Let's compare GemmOnly (so bias is included) and Gemm+Add.

i, j = 1, -1
labs = labels[i], labels[j]

fig, ax = plt.subplots(len(all_results), 2, figsize=(8, 2.5 * len(results)))
for pos, ((device, dtype), results) in enumerate(all_results.items()):
m1, m2 = results[i], results[j]
diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
ax[pos, 0].plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), ".")
ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}")

corr = matrix_diff(results)
ax[pos, 1].imshow(corr, cmap="Blues", vmin=0, vmax=corr.max())
# ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}')
ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45)
ax[pos, 1].set_yticks(range(len(labels)), labels)
ax[pos, 1].set_title(f"max={diff.max()}")
# Let's compare torch's F.linear with GemmOnly.


def make_figure_axis(all_results, i, j):
labs = labels[i], labels[j]
fig, ax = plt.subplots(len(all_results), 2, figsize=(12, 4 * len(all_results)))
for pos, ((device, dtype), results) in enumerate(all_results.items()):
m1, m2 = results[i], results[j]
diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
ax[pos, 0].plot(
B.tolist(), (diff.detach().cpu() + torch.rand(1280) * expand).tolist(), "."
)
ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10)

corr = matrix_diff(results)
ax[pos, 1].imshow(corr, cmap="Wistia", vmin=0, vmax=corr.max())
# ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}')
ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45, ha="right", fontsize=10)
ax[pos, 1].set_yticks(range(len(labels)), labels, fontsize=10)
ax[pos, 1].set_title(f"max={diff.max():1.2g}", fontsize=10)
for _i in range(corr.shape[0]):
for _j in range(corr.shape[1]):
ax[pos, 1].text(
_j,
_i,
f"{corr[_i, _j]:1.1g}",
ha="center",
va="center",
color="black",
fontsize=8,
)
fig.suptitle(
f"Left column: discrepancies {labs[0]} VS {labs[1]}\n"
f"Right column: max absolute error, across all configuration\n"
f"white is good, orange is not"
)
return fig, ax


fig, ax = make_figure_axis(all_results, 0, 1)
fig.tight_layout()
fig.savefig("plot_gemm_or_matmul_add.png")
fig.savefig("plot_gemm_or_matmul_add1.png")

# %%
# Let's compare with ``A @ X + B``.

fig, ax = make_figure_axis(all_results, -1, 1)
fig.tight_layout()
fig.savefig("plot_gemm_or_matmul_add2.png")


# %%
# Discrepancies do not show up every time, but they are very likely to occur.
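
(Context: the helper matrix_diff used by the plots is defined outside this diff. A plausible sketch, assuming it returns the pairwise maximum absolute difference between all computed results.)

import torch

def matrix_diff(tensors):
    # tensors: one result matrix per labelled computation
    n = len(tensors)
    corr = torch.zeros((n, n))
    for i in range(n):
        for j in range(n):
            d = tensors[i].to(torch.float32) - tensors[j].to(torch.float32)
            corr[i, j] = d.abs().max().item()
    return corr
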
101 changes: 96 additions & 5 deletions _unittests/ut_export/test_onnx_plug.py
@@ -1,7 +1,8 @@
import unittest
import onnx
import onnx.helper as oh
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch
from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch, hide_stdout, ignore_warnings
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
from onnx_diagnostic.export.api import to_onnx

@@ -36,7 +37,9 @@ def make_function_proto():
self.assertEqual(len(res.diffs), 1)
self.assertEqual(res.diffs[0]["abs"], 0)

def test_onnx_plug_export(self):
@hide_stdout()
@ignore_warnings(FutureWarning)
def test_onnx_plug_export_nokwargs(self):
def _test_customsub(x, y):
return x - y

@@ -61,7 +64,7 @@ def forward(self, x):

replacements = [
EagerDirectReplacementWithOnnx(
_test_customsub, _test_customsub_shape, make_function_proto(), 2, 1
_test_customsub, _test_customsub_shape, make_function_proto(), 2, 1, verbose=1
)
]

@@ -83,7 +86,95 @@ def forward(self, x):
onnx_plugs=replacements,
target_opset=22,
)
self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,))
self.assert_onnx_disc(
"test_onnx_plug_export_nokwargs_custom", onx.model_proto, model, (x,)
)

if not has_torch("2.9"):
raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
with self.subTest(exporter="onnx-dynamo"):
onx = to_onnx(
model,
(x,),
dynamic_shapes=ds,
exporter="onnx-dynamo",
onnx_plugs=replacements,
target_opset=22,
)
self.assert_onnx_disc(
"test_onnx_plug_export_nokwargs_onnx_dynamo", onx.model_proto, model, (x,)
)

@unittest.skip("not ready yet")
@hide_stdout()
@ignore_warnings(FutureWarning)
def test_onnx_plug_export_kwargs(self):
def _test_customdiv(x, y, epsilon: float = 1e-5):
return x / (y + epsilon)

def _test_customdiv_shape(x, y, *args, **kwargs):
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)

def make_function_proto():
f = oh.make_function(
"onnx_plug",
"_test_customdiv",
["x", "y"],
["z"],
[
oh.make_node("Constant", [], ["eps"]),
oh.make_node("Add", ["y", "eps"], ["yeps"]),
oh.make_node("Div", ["x", "yeps"], ["z"]),
],
opset_imports=[oh.make_opsetid("", 22)],
attributes=["epsilon"],
)
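# Note: the Constant node above is deliberately created without a value;
# the attribute appended below references the function attribute "epsilon"
# through ref_attr_name, so the value supplied by the caller is substituted
# when the function is instantiated.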
att = onnx.AttributeProto()
att.name = "value_float"
att.ref_attr_name = "epsilon"
att.type = onnx.AttributeProto.FLOAT
f.node[0].attribute.append(att)
return f

class Model(torch.nn.Module):
def forward(self, x):
y = x.sum(axis=1, keepdim=True)
d = torch.ops.onnx_plug._test_customdiv(x, y, epsilon=3.5)
return torch.abs(d)

replacements = [
EagerDirectReplacementWithOnnx(
_test_customdiv,
_test_customdiv_shape,
make_function_proto(),
2,
1,
kwargs=dict(epsilon=1e-5),
verbose=1,
)
]

x = torch.randn((3, 4), dtype=torch.float32)
model = Model()
expected = model(x)
ds = ({0: "d1", 1: "d2"},)
ep = torch.export.export(model, (x,), dynamic_shapes=self.use_dyn_not_str(ds))
self.assertIn("torch.ops.onnx_plug._test_customdiv.default", str(ep))
got = ep.module()(x)
self.assertEqualArray(expected, got)

with self.subTest(exporter="custom"):
onx = to_onnx(
model,
(x,),
dynamic_shapes=ds,
exporter="custom",
onnx_plugs=replacements,
target_opset=22,
)
self.assert_onnx_disc(
"test_onnx_plug_export_kwargs_custom", onx.model_proto, model, (x,)
)

if not has_torch("2.9"):
raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
@@ -97,7 +188,7 @@ def forward(self, x):
target_opset=22,
)
self.assert_onnx_disc(
"test_onnx_plug_export_onnx_dynamo", onx.model_proto, model, (x,)
"test_onnx_plug_export_kwargs_onnx_dynamo", onx.model_proto, model, (x,)
)


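For context: EagerDirectReplacementWithOnnx pairs an eager function, a shape function and an ONNX FunctionProto. A minimal sketch of the mechanism that presumably makes torch.ops.onnx_plug._test_customsub exist (an assumption about the internals, not the library's actual code):

import torch

@torch.library.custom_op("onnx_plug::_test_customsub", mutates_args=())
def _test_customsub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # eager implementation, used when the model runs outside export
    return x - y

@_test_customsub.register_fake
def _(x, y):
    # shape/dtype propagation used by torch.export, mirrors _test_customsub_shape
    return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
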
19 changes: 10 additions & 9 deletions _unittests/ut_tasks/try_export.py
@@ -45,6 +45,9 @@ def test_imagetext2text_qwen_2_5_vl_instruct_visual(self):
exporter = os.environ.get("EXPORTER", "custom")

from transformers import AutoModel, AutoProcessor
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
PLUGS,
)

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
# model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -82,11 +85,17 @@ def _config_reduction(config, task):
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
print(f"-- processor={type(processor)}")

big_inputs = dict(
hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
)
print("-- save inputs")
inputs = dict(
hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
)
print("-- save inputs")
torch.save(big_inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.big.pt"))
torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt"))

print(f"-- inputs: {self.string_type(inputs, with_shape=True)}")
@@ -115,15 +124,6 @@ def _config_reduction(config, task):
verbose=1,
stop_if_static=2,
):
if exporter == "onnx-dynamo":
# The exported program in ONNXProgram cannot be restored.
ep2 = torch.export.export(
model.visual,
(),
kwargs=export_inputs,
dynamic_shapes=self.use_dyn_not_str(dynamic_shapes),
)
torch.export.save(ep2, f"{fileep}.backup.pt2")
to_onnx(
model.visual,
kwargs=export_inputs,
@@ -134,6 +134,7 @@ def _config_reduction(config, task):
save_ep=(fileep, 2**35),
target_opset=22,
optimize=True,
onnx_plugs=PLUGS,
)

pt2_files = [f"{fileep}.backup.pt2", f"{fileep}.ep.pt2", f"{fileep}.pt2"]
23 changes: 14 additions & 9 deletions _unittests/ut_torch_export_patches/test_patch_transformers.py
@@ -384,6 +384,7 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self):
)
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patched_Qwen2_5_VLVisionAttention,
PLUGS_Qwen25,
)

config = get_cached_configuration("Qwen/Qwen2.5-VL-7B-Instruct")
@@ -406,7 +407,7 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self):
_is_torchdynamo_exporting()
), f"exporting is not set to true? {torch.compiler.is_exporting_flag}"
got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs)
self.assertEqualArray(expected, got, atol=1e-5)
self.assertEqualArray(expected, got, atol=1e-2)

class Model(patched_class):
def forward(
@@ -456,16 +457,20 @@ def forward(
dynamic_shapes=ds,
exporter=exporter,
filename=filename,
onnx_plugs=PLUGS_Qwen25,
target_opset=22,
)
# exporter_kwargs={"report":True} if exporter != "custom" else {}
self.assert_onnx_disc(
f"test_patched_qwen2_5_vl_vision_attention_forward-{exporter}",
onnx.load(filename),
instance,
inputs,
atol=1e-3,
rtol=1,
)
if torch.cuda.is_available():
self.assert_onnx_disc(
f"test_patched_qwen2_5_vl_vision_attention_forward-{exporter}",
onnx.load(filename),
instance,
inputs,
atol=1e-3,
rtol=1,
providers=["CUDAExecutionProvider"],
)
self.clean_dump()

@requires_transformers("4.99")
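(Context: the added branch reruns the discrepancy check on GPU. Presumably assert_onnx_disc creates the session along these lines; a sketch, not the helper's actual code.)

import onnxruntime as ort

sess = ort.InferenceSession(
    filename,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],  # CPU fallback
)
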
6 changes: 5 additions & 1 deletion onnx_diagnostic/export/control_flow.py
@@ -134,7 +134,11 @@ def make_custom_loop_for(
assert body_outputs is not None, "body_outputs cannot be None"
srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
full_name = body_fn.__qualname__.replace("<locals>", "L").replace(".", "_")
full_name = (
body_fn.__qualname__.replace("<locals>", "L")
.replace("<lambda>", "l")
.replace(".", "_")
)
name = f"loop_for_{full_name}_{srank}_{sred}"
if name in _REGISTERED_SCHEMA:
return name, _REGISTERED_SCHEMA[name][0]
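A quick illustration of why the extra replace is needed (the names below are only for demonstration):

def make():
    return lambda x: x + 1

fn = make()
print(fn.__qualname__)  # "make.<locals>.<lambda>"
full_name = (
    fn.__qualname__.replace("<locals>", "L").replace("<lambda>", "l").replace(".", "_")
)
print(full_name)  # "make_L_l" -- no "<" or ">" left, safe inside an op name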