
Commit b5aa71b

fix
1 parent 2d27c1a commit b5aa71b

3 files changed: +131 −26 lines

_unittests/ut_export/test_onnx_plug.py

Lines changed: 92 additions & 3 deletions
@@ -1,4 +1,5 @@
 import unittest
+import onnx
 import onnx.helper as oh
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch, hide_stdout, ignore_warnings
@@ -38,7 +39,7 @@ def make_function_proto():
 
     @hide_stdout()
     @ignore_warnings(FutureWarning)
-    def test_onnx_plug_export(self):
+    def test_onnx_plug_export_nokwargs(self):
         def _test_customsub(x, y):
             return x - y
 
@@ -85,7 +86,95 @@ def forward(self, x):
                 onnx_plugs=replacements,
                 target_opset=22,
             )
-            self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,))
+            self.assert_onnx_disc(
+                "test_onnx_plug_export_nokwargs_custom", onx.model_proto, model, (x,)
+            )
+
+        if not has_torch("2.9"):
+            raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
+        with self.subTest(exporter="onnx-dynamo"):
+            onx = to_onnx(
+                model,
+                (x,),
+                dynamic_shapes=ds,
+                exporter="onnx-dynamo",
+                onnx_plugs=replacements,
+                target_opset=22,
+            )
+            self.assert_onnx_disc(
+                "test_onnx_plug_export_nokwargs_onnx_dynamo", onx.model_proto, model, (x,)
+            )
+
+    @unittest.skip("not ready yet")
+    @hide_stdout()
+    @ignore_warnings(FutureWarning)
+    def test_onnx_plug_export_kwargs(self):
+        def _test_customdiv(x, y, epsilon: float = 1e-5):
+            return x / (y + epsilon)
+
+        def _test_customdiv_shape(x, y, *args, **kwargs):
+            return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
+
+        def make_function_proto():
+            f = oh.make_function(
+                "onnx_plug",
+                "_test_customdiv",
+                ["x", "y"],
+                ["z"],
+                [
+                    oh.make_node("Constant", [], ["eps"]),
+                    oh.make_node("Add", ["y", "eps"], ["yeps"]),
+                    oh.make_node("Div", ["x", "yeps"], ["z"]),
+                ],
+                opset_imports=[oh.make_opsetid("", 22)],
+                attributes=["epsilon"],
+            )
+            att = onnx.AttributeProto()
+            att.name = "value_float"
+            att.ref_attr_name = "epsilon"
+            att.type = onnx.AttributeProto.FLOAT
+            f.node[0].attribute.append(att)
+            return f
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                y = x.sum(axis=1, keepdim=True)
+                d = torch.ops.onnx_plug._test_customdiv(x, y, epsilon=3.5)
+                return torch.abs(d)
+
+        replacements = [
+            EagerDirectReplacementWithOnnx(
+                _test_customdiv,
+                _test_customdiv_shape,
+                make_function_proto(),
+                2,
+                1,
+                kwargs=dict(epsilon=1e-5),
+                verbose=1,
+            )
+        ]
+
+        x = torch.randn((3, 4), dtype=torch.float32)
+        model = Model()
+        expected = model(x)
+        ds = ({0: "d1", 1: "d2"},)
+        ep = torch.export.export(model, (x,), dynamic_shapes=self.use_dyn_not_str(ds))
+        self.assertIn("torch.ops.onnx_plug._test_customdiv.default", str(ep))
+        got = ep.module()(x)
+        self.assertEqualArray(expected, got)
+
+        with self.subTest(exporter="custom"):
+            onx = to_onnx(
+                model,
+                (x,),
+                dynamic_shapes=ds,
+                exporter="custom",
+                onnx_plugs=replacements,
+                target_opset=22,
+            )
+            self.assert_onnx_disc(
+                "test_onnx_plug_export_kwargs_custom", onx.model_proto, model, (x,)
+            )
 
         if not has_torch("2.9"):
             raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8")
@@ -99,7 +188,7 @@ def forward(self, x):
                 target_opset=22,
             )
             self.assert_onnx_disc(
-                "test_onnx_plug_export_onnx_dynamo", onx.model_proto, model, (x,)
+                "test_onnx_plug_export_kwargs_onnx_dynamo", onx.model_proto, model, (x,)
             )

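Note: the new (still skipped) kwargs test relies on ONNX attribute references: the function's Constant node carries no literal value, only ref_attr_name = "epsilon", so each calling node supplies epsilon as an attribute. The following standalone sketch of that mechanism is not part of the commit; the names toy_domain and customdiv are hypothetical, and it assumes the onnx reference evaluator resolves attribute references inside local functions.

import numpy as np
import onnx
import onnx.helper as oh
from onnx.reference import ReferenceEvaluator

# Toy function (hypothetical names): the Constant node has no value of its own,
# only a reference to the function attribute "epsilon".
f = oh.make_function(
    "toy_domain",
    "customdiv",
    ["x", "y"],
    ["z"],
    [
        oh.make_node("Constant", [], ["eps"]),
        oh.make_node("Add", ["y", "eps"], ["yeps"]),
        oh.make_node("Div", ["x", "yeps"], ["z"]),
    ],
    opset_imports=[oh.make_opsetid("", 22)],
    attributes=["epsilon"],
)
att = onnx.AttributeProto()
att.name = "value_float"
att.ref_attr_name = "epsilon"  # resolved from the calling node's epsilon attribute
att.type = onnx.AttributeProto.FLOAT
f.node[0].attribute.append(att)

# A model calling the function with epsilon=3.5, as the test does in eager mode.
graph = oh.make_graph(
    [oh.make_node("customdiv", ["x", "y"], ["z"], domain="toy_domain", epsilon=3.5)],
    "g",
    [
        oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [None]),
        oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [None]),
    ],
    [oh.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [None])],
)
model = oh.make_model(
    graph,
    opset_imports=[oh.make_opsetid("", 22), oh.make_opsetid("toy_domain", 1)],
    functions=[f],
)

x = np.random.rand(4).astype(np.float32)
y = np.random.rand(4).astype(np.float32)
got = ReferenceEvaluator(model).run(None, {"x": x, "y": y})[0]
assert np.allclose(got, x / (y + 3.5))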
onnx_diagnostic/export/onnx_plug.py

Lines changed: 22 additions & 6 deletions
@@ -49,7 +49,7 @@ class EagerDirectReplacementWithOnnx:
     :param n_outputs: same for the number of outputs,
         only tensors must be counted
     :param name: the name of the custom op, the function name if not specified
-    :param kwargs: constants
+    :param kwargs: constant parameters with their default values
     :param verbose: verbose level
 
     Here is an example:
@@ -163,8 +163,8 @@ def __init__(
             .replace("<lambda>", "l")
             .replace(".", "_")
         )
-        self.kwargs = kwargs
-        assert kwargs is None or all(isinstance(v, (int, float)) for v in kwargs.values()), (
+        self.kwargs = kwargs or {}
+        assert all(isinstance(v, (int, float)) for v in self.kwargs.values()), (
             f"Only int or floats are allowed for kwargs={kwargs}, one of them "
             f"does not respect that constraint."
         )
@@ -184,7 +184,8 @@ def __init__(
         assert (
             function_proto.domain == self.domain
         ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
-        self.arg_names = params
+        self.args_name = [p for p in params if p not in self.kwargs]
+        self.kwargs_name = [p for p in params if p in self.kwargs]
         self.verbose = verbose
         self.custom_op = self._register()
 
@@ -211,7 +212,19 @@ def __call__(self, *args):
 
     def _register(self):
        """Registers the custom op."""
-        inputs = ", ".join([f"Tensor {p}" for p in self.arg_names])
+        input_args = [f"Tensor {p}" for p in self.args_name]
+        for p in self.kwargs_name:
+            val = self.kwargs[p]
+            if isinstance(val, int):
+                input_args.append(f"int {p}={val}")
+            elif isinstance(val, float):
+                input_args.append(f"float {p}={val}")
+            else:
+                raise NotImplementedError(
+                    f"kwargs {p!r} has a default value of unsupported type {type(val)}"
+                )
+
+        inputs = ", ".join(input_args)
         schema = f"({inputs}) -> Tensor"
         if self.n_outputs > 1:
             schema += "[]"
@@ -292,12 +305,15 @@ def converter(
             self.function_proto.name, domain=self.function_proto.domain
         ):
             g.add_function(self.function_proto)
+        ags = args[: len(self.args_name)]
+        kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
         res = g.make_node(
             self.function_proto.name,
-            args,
+            ags,
             outputs,
             domain=self.function_proto.domain,
             name=self.target_name,
+            **kws,
         )
         if not sts:
             new_shapes = self.shape_fn(*args)

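For reference, `_register` now emits a torch schema such as `(Tensor x, Tensor y, float epsilon=1e-05) -> Tensor`, with defaults taken from `kwargs`. Below is a minimal sketch of the kind of registration such a schema enables; the namespace `toy_plug` and the op `customdiv` are hypothetical, and the actual code inside `EagerDirectReplacementWithOnnx` may differ.

import torch

# Hypothetical namespace; the schema string mirrors what _register builds from
# args_name (tensors) and kwargs_name (int/float with defaults).
lib = torch.library.Library("toy_plug", "DEF")
lib.define("customdiv(Tensor x, Tensor y, float epsilon=1e-05) -> Tensor")

@torch.library.impl(lib, "customdiv", "CPU")
def customdiv(x, y, epsilon=1e-5):
    # eager kernel
    return x / (y + epsilon)

@torch.library.register_fake("toy_plug::customdiv")
def customdiv_fake(x, y, epsilon=1e-5):
    # shape/dtype propagation used by torch.export
    return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)

x, y = torch.randn(3, 4), torch.randn(3, 1)
print(torch.ops.toy_plug.customdiv(x, y))               # uses the default epsilon=1e-05
print(torch.ops.toy_plug.customdiv(x, y, epsilon=3.5))  # overrides it, as in the new test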
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 17 additions & 17 deletions
@@ -27,7 +27,12 @@
 
 @onnxscript.script(opset=onnx_plugs_op)
 def LoopMHAAttention(
-    query_states, key_states, value_states, cu_seqlens, scale: float, num_heads: int
+    query_states,
+    key_states,
+    value_states,
+    cu_seqlens,
+    scaling: float = 0.11180339887498948,
+    num_heads: int = 16,
 ):
     to_3d_shape = op.Constant(value_ints=[0, 0, -1])
     query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -52,7 +57,7 @@ def LoopMHAAttention(
             key_i,
             value_i,
             num_heads=num_heads,
-            scale=scale,
+            scale=scaling,
         )
         attn_output = op.Concat(attn_output, mha_output, axis=1)
     attn_output_4d = op.Reshape(attn_output, output_shape)
@@ -64,7 +69,7 @@ def PackedAttention(
     key,
     value,
     cu_seqlens,
-    scale: float = 0.11180339887498948,
+    scaling: float = 0.11180339887498948,
     num_heads: int = 16,
 ):
     num_patches = op.Cast(op.Size(cu_seqlens), to=onnx.TensorProto.INT32) - 1
@@ -102,7 +107,7 @@ def PackedAttention(
         None,
         op.Cast(token_offset, to=onnx.TensorProto.INT32),
         op.Cast(cu_seqlens, to=onnx.TensorProto.INT32),
-        scale=scale,
+        scale=scaling,
         num_heads=num_heads,
     )
     packed_attn_output_3d = op.Reshape(packed_attn_output_2d, shape_3d)
@@ -139,10 +144,8 @@ def qwen_sdpa_attention(
 
 # not ideal
 qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
-    lambda qs, ks, vs, cuseq: qwen_sdpa_attention(
-        qs, ks, vs, cuseq, scaling=0.11180339887498948
-    ),
-    lambda qs, *args: torch.empty(
+    qwen_sdpa_attention,
+    lambda qs, *args, **kwargs: torch.empty(
         (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
         dtype=qs.dtype,
         device=qs.device,
@@ -489,16 +492,13 @@ def forward(
             is transformers.integrations.sdpa_attention.sdpa_attention_forward
             or attention_interface is patched_sdpa_attention_forward
         ) and strategy_for_attention_in_qwen_2_5 == "PACKED":
-            torch._check(
-                qwen_sdpa_attention_versatile.kwargs["scaling"] == self.scaling,
-                lambda: f"Not implemented for scaling={self.scaling}",
-            )
-            torch._check(
-                qwen_sdpa_attention_versatile.kwargs["num_heads"] == self.num_heads,
-                lambda: f"Not implemented for num_heads={self.num_heads}",
-            )
             attn_output = qwen_sdpa_attention_versatile(
-                query_states, key_states, value_states, cu_seqlens
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens,
+                scaling=self.scaling,
+                num_heads=self.num_heads,
             )
         elif _is_torchdynamo_exporting():
             if self.config._attn_implementation == "flash_attention_2":

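As a side note, onnxscript turns annotated non-tensor parameters such as `scaling: float = 0.11180339887498948` into ONNX function attributes with default values, which is what lets the signatures above drop the hard-coded lambda. A toy sketch of that behavior follows; `ScaledSoftmax` is hypothetical, and the exact proto fields holding the defaults depend on the onnx/onnxscript versions.

import onnxscript
from onnxscript import opset18 as op

# Hypothetical toy function: axis and scaling become function attributes with defaults,
# mirroring scaling/num_heads in LoopMHAAttention and PackedAttention above.
@onnxscript.script()
def ScaledSoftmax(x, axis: int = -1, scaling: float = 1.0):
    scaled = op.Mul(x, op.Constant(value_float=scaling))
    return op.Softmax(scaled, axis=axis)

proto = ScaledSoftmax.to_function_proto()
# Depending on the IR version, the attribute names show up in proto.attribute or,
# together with their default values, in proto.attribute_proto.
print(proto.name, list(proto.attribute), list(proto.attribute_proto))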