Skip to content

Commit 51b8db5

Browse files
committed
implements version plugs
1 parent cd01099 commit 51b8db5

File tree

4 files changed

+179
-74
lines changed

4 files changed

+179
-74
lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ class TestTryExportHuggingFaceHubModel(ExtTestCase):
1717
@ignore_warnings(UserWarning)
1818
def test_qwen25_vli_visual(self):
1919
"""
20+
unittest::
21+
22+
UNITTEST_GOING=1 NEVERTEST=1 TESTDTYPE=float16 TESTDEVICE=cpu python \\
23+
_unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual
24+
2025
# task: imagetext2text
2126
clear&&NEVERTEST=1 python _unittests/ut_tasks/try_export.py -k qwen_2_5
2227
@@ -180,7 +185,7 @@ def _config_reduction(config, task):
180185
exporter=exporter,
181186
verbose=1,
182187
save_ep=None if self.unit_test_going() else (fileep, 2**35),
183-
target_opset=24 if attention == "LOOPMHA" else 22,
188+
target_opset=24 if attention == "LOOPA24" else 22,
184189
optimize=True,
185190
onnx_plugs=PLUGS,
186191
)

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -667,16 +667,16 @@ def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self):
667667
results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
668668
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
669669
self.assertEqual(len(results.eager_outputs), len(results.diffs))
670-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
671-
self.assertLess(results.diffs[0]["abs"], 1e-5)
670+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2)
671+
self.assertLess(results.diffs[0]["abs"], 1e-2)
672672

673673
results = qwen_sdpa_attention_loopa24_versatile.verify(
674674
*inputs, scaling=0.11180339887498948
675675
)
676676
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
677677
self.assertEqual(len(results.eager_outputs), len(results.diffs))
678-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
679-
self.assertLess(results.diffs[0]["abs"], 1e-5)
678+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
679+
self.assertLess(results.diffs[0]["abs"], 0.005)
680680

681681
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
682682
def test_plug_packed_multi_head_attention_qwen25_loopa24_float32(self):

onnx_diagnostic/export/onnx_plug.py

Lines changed: 119 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
44
import onnx
55
import torch
6-
from ..helpers import max_diff
6+
from ..helpers import max_diff, string_type
77
from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
88
from ..reference import OnnxruntimeEvaluator
99

@@ -50,6 +50,7 @@ class EagerDirectReplacementWithOnnx:
5050
only tensors must be counted
5151
:param name: the name of the custom op, the function name if not specified
5252
:param kwargs: constants parameters with their default values
53+
:param version_selector: selects the version based on the arguments
5354
:param verbose: verbose level
5455
5556
Here is an example:
@@ -139,21 +140,28 @@ def __init__(
139140
self,
140141
eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
141142
shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
142-
function_proto: onnx.FunctionProto,
143+
function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]],
143144
n_inputs: Optional[int] = None,
144145
n_outputs: Optional[int] = None,
145146
name: Optional[str] = None,
146147
kwargs: Optional[Dict[str, Union[int, float]]] = None,
147148
verbose: int = 0,
149+
version_selector: Optional[Callable[[Any], Any]] = None,
148150
):
149-
assert isinstance(
150-
function_proto, onnx.FunctionProto
151+
assert isinstance(function_proto, onnx.FunctionProto) or (
152+
isinstance(function_proto, dict)
153+
or all(isinstance(v, onnx.FunctionProto) for v in function_proto.values())
151154
), f"Unexpected type {type(function_proto)} for function_proto"
152155
assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
153-
assert isinstance(n_outputs, int), f"not implemented yet when n_inputs={n_outputs}"
156+
assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
154157
self.eager_fn = eager_fn
155158
self.shape_fn = shape_fn
156-
self.function_proto = function_proto
159+
self._function_proto = (
160+
function_proto if isinstance(function_proto, onnx.FunctionProto) else None
161+
)
162+
self._function_proto_versioned = (
163+
function_proto if isinstance(function_proto, dict) else {}
164+
)
157165
self.n_inputs = n_inputs
158166
self.n_outputs = n_outputs
159167
self.name = name or (
@@ -170,24 +178,72 @@ def __init__(
170178
)
171179
sig = inspect.signature(self.eager_fn)
172180
params = list(sig.parameters)
173-
assert (
174-
len(params) >= n_inputs
175-
), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
176-
assert n_inputs == len(function_proto.input), (
177-
f"Input mismatch n_inputs={n_inputs} but "
178-
f"function_proto.input={function_proto.input}"
179-
)
180-
assert n_outputs == len(function_proto.output), (
181-
f"Output mismatch n_outputs={n_outputs} but "
182-
f"function_proto.output={function_proto.output}"
183-
)
184-
assert (
185-
function_proto.domain == self.domain
186-
), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
187181
self.args_name = [p for p in params if p not in self.kwargs]
188182
self.kwargs_name = [p for p in params if p in self.kwargs]
189183
self.verbose = verbose
190184
self.custom_op = self._register()
185+
self.version_selector = version_selector
186+
self._check_protos(params)
187+
188+
def _check_protos(self, params):
189+
assert (
190+
len(params) >= self.n_inputs
191+
), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}"
192+
193+
# one proto
194+
assert self._function_proto is None or self.n_inputs == len(
195+
self._function_proto.input
196+
), (
197+
f"Input mismatch n_inputs={self.n_inputs} but "
198+
f"function_proto.input={self._function_proto.input}"
199+
)
200+
assert self._function_proto is None or self.n_outputs == len(
201+
self._function_proto.output
202+
), (
203+
f"Output mismatch n_outputs={self.n_outputs} but "
204+
f"function_proto.output={self._function_proto.output}"
205+
)
206+
assert self._function_proto is None or (
207+
self._function_proto.domain == self.domain
208+
), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}"
209+
210+
# multiple protos
211+
assert all(
212+
self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
213+
), f"Input mismatch n_inputs={self.n_inputs} but one version is wrong"
214+
assert all(
215+
self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
216+
), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
217+
assert all(
218+
v.domain == self.domain for v in self._function_proto_versioned.values()
219+
), f"Function domain must be {self.domain!r} but it is different in one version"
220+
assert (
221+
not self._function_proto_versioned or self.version_selector
222+
), "version_selector is needed when multiple protos are given."
223+
224+
def get_function_proto(self, *args) -> onnx.FunctionProto:
225+
"""Returns the correct version based on the inputs."""
226+
if self._function_proto:
227+
return self._function_proto
228+
if (
229+
len(args) == 1
230+
and isinstance(args[0], (int, str))
231+
and args[0] in self._function_proto_versioned
232+
):
233+
return self._function_proto_versioned[args[0]]
234+
try:
235+
key = self.version_selector(*args)
236+
except (ValueError, AttributeError) as e:
237+
raise AssertionError(
238+
f"Unable to select a version, fails to get a key, available="
239+
f"{set(self._function_proto_versioned)}, "
240+
f"args={string_type(args,with_shape=True)}"
241+
) from e
242+
assert key in self._function_proto_versioned, (
243+
f"Unable to select a version, key={key}, available="
244+
f"{set(self._function_proto_versioned)}, args={string_type(args,with_shape=True)}"
245+
)
246+
return self._function_proto_versioned[key]
191247

192248
@property
193249
def domain(self) -> str:
@@ -291,7 +347,7 @@ def verify(
291347
assert engine is None, f"Not implemented yet with engine={engine!r}"
292348
ags, kws = self._make_args_kwargs(*args, **kwargs)
293349
sess = OnnxruntimeEvaluator(
294-
self.function_proto,
350+
self.get_function_proto(*args),
295351
whole=True,
296352
dump_onnx_model=dump_onnx_model,
297353
function_kwargs=kws,
@@ -324,16 +380,15 @@ def converter(
324380
*args,
325381
**kwargs,
326382
) -> Any:
327-
if not g.has_local_function(
328-
self.function_proto.name, domain=self.function_proto.domain
329-
):
330-
g.add_function(self.function_proto)
383+
function_proto = self.get_function_proto(g.get_type(args[0]))
384+
if not g.has_local_function(function_proto.name, domain=function_proto.domain):
385+
g.add_function(function_proto)
331386
ags, kws = self._make_args_kwargs(*args, **kwargs)
332387
res = g.make_node(
333-
self.function_proto.name,
388+
function_proto.name,
334389
ags,
335390
outputs,
336-
domain=self.function_proto.domain,
391+
domain=function_proto.domain,
337392
name=self.target_name,
338393
**kws,
339394
)
@@ -356,41 +411,46 @@ def onnx_dynamo_converter(self) -> Callable:
356411
"""
357412
import onnxscript
358413

359-
onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
360-
schema = onnx_plug_op[self.function_proto.name]
361-
if schema is None:
362-
all_types = [
363-
"tensor(float)",
364-
"tensor(float16)",
365-
"tensor(bfloat16)",
366-
"tensor(double)",
367-
"tensor(int64)",
368-
"tensor(int32)",
369-
]
370-
type_constraints = []
371-
for i in range(self.n_inputs):
372-
type_constraints.append((f"T{i}", all_types, ""))
373-
for i in range(self.n_outputs):
374-
type_constraints.append((f"U{i}", all_types, ""))
375-
schema = onnx.defs.OpSchema(
376-
self.function_proto.name,
377-
self.function_proto.domain,
378-
1,
379-
inputs=[
380-
onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
381-
for i in range(self.n_inputs)
382-
],
383-
outputs=[
384-
onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
385-
for i in range(self.n_outputs)
386-
],
387-
type_constraints=type_constraints,
388-
)
389-
onnx.defs.register_schema(schema)
390-
op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
414+
onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
415+
416+
def get_proto(*args):
417+
function_proto = self.get_function_proto()
418+
schema = onnx_plug_op[function_proto.name]
419+
if schema is None:
420+
all_types = [
421+
"tensor(float)",
422+
"tensor(float16)",
423+
"tensor(bfloat16)",
424+
"tensor(double)",
425+
"tensor(int64)",
426+
"tensor(int32)",
427+
]
428+
type_constraints = []
429+
for i in range(self.n_inputs):
430+
type_constraints.append((f"T{i}", all_types, ""))
431+
for i in range(self.n_outputs):
432+
type_constraints.append((f"U{i}", all_types, ""))
433+
schema = onnx.defs.OpSchema(
434+
function_proto.name,
435+
function_proto.domain,
436+
1,
437+
inputs=[
438+
onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
439+
for i in range(self.n_inputs)
440+
],
441+
outputs=[
442+
onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
443+
for i in range(self.n_outputs)
444+
],
445+
type_constraints=type_constraints,
446+
)
447+
onnx.defs.register_schema(schema)
448+
op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema)
449+
return op
391450

392451
def converter(*cargs, **ckwargs):
393452
ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
453+
op = get_proto(*cargs)
394454
return op(*ags, n_outputs=self.n_outputs, **kws)
395455

396456
return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)

0 commit comments

Comments
 (0)