Skip to content

Commit 604527b

Browse files
committed
Fix minor bugs in onnx_plug
1 parent 60ee6ce commit 604527b

File tree

4 files changed

+201
-54
lines changed

4 files changed

+201
-54
lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def test_imagetext2text_qwen_2_5_vl_instruct_visual(self):
4545
exporter = os.environ.get("EXPORTER", "custom")
4646

4747
from transformers import AutoModel, AutoProcessor
48+
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
49+
PLUGS,
50+
)
4851

4952
model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
5053
# model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -82,11 +85,17 @@ def _config_reduction(config, task):
8285
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
8386
print(f"-- processor={type(processor)}")
8487

88+
big_inputs = dict(
89+
hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
90+
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
91+
)
92+
print("-- save inputs")
8593
inputs = dict(
8694
hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
8795
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
8896
)
8997
print("-- save inputs")
98+
torch.save(big_inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.big.pt"))
9099
torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt"))
91100

92101
print(f"-- inputs: {self.string_type(inputs, with_shape=True)}")
@@ -115,15 +124,6 @@ def _config_reduction(config, task):
115124
verbose=1,
116125
stop_if_static=2,
117126
):
118-
if exporter == "onnx-dynamo":
119-
# The exported program in ONNXProgram cannot be restored.
120-
ep2 = torch.export.export(
121-
model.visual,
122-
(),
123-
kwargs=export_inputs,
124-
dynamic_shapes=self.use_dyn_not_str(dynamic_shapes),
125-
)
126-
torch.export.save(ep2, f"{fileep}.backup.pt2")
127127
to_onnx(
128128
model.visual,
129129
kwargs=export_inputs,
@@ -134,6 +134,7 @@ def _config_reduction(config, task):
134134
save_ep=(fileep, 2**35),
135135
target_opset=22,
136136
optimize=True,
137+
onnx_plugs=PLUGS,
137138
)
138139

139140
pt2_files = [f"{fileep}.backup.pt2", f"{fileep}.ep.pt2", f"{fileep}.pt2"]

onnx_diagnostic/export/control_flow.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,11 @@ def make_custom_loop_for(
134134
assert body_outputs is not None, "body_outputs cannot be None"
135135
srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
136136
sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
137-
full_name = body_fn.__qualname__.replace("<locals>", "L").replace(".", "_")
137+
full_name = (
138+
body_fn.__qualname__.replace("<locals>", "L")
139+
.replace("<lambda>", "l")
140+
.replace(".", "_")
141+
)
138142
name = f"loop_for_{full_name}_{srank}_{sred}"
139143
if name in _REGISTERED_SCHEMA:
140144
return name, _REGISTERED_SCHEMA[name][0]

onnx_diagnostic/export/onnx_plug.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
from dataclasses import dataclass
3-
from typing import Any, Callable, Dict, List, Optional, Tuple
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
44
import onnx
55
import torch
66
from ..helpers import max_diff
@@ -49,6 +49,7 @@ class EagerDirectReplacementWithOnnx:
4949
:param n_outputs: same for the number of outputs,
5050
only tensors must be counted
5151
:param name: the name of the custom op, the function name if not specified
52+
:param kwargs: constants
5253
5354
Here is an example:
5455
@@ -141,6 +142,7 @@ def __init__(
141142
n_inputs: Optional[int] = None,
142143
n_outputs: Optional[int] = None,
143144
name: Optional[str] = None,
145+
kwargs: Optional[Dict[str, Union[int, float]]] = None,
144146
):
145147
assert isinstance(
146148
function_proto, onnx.FunctionProto
@@ -152,7 +154,14 @@ def __init__(
152154
self.function_proto = function_proto
153155
self.n_inputs = n_inputs
154156
self.n_outputs = n_outputs
155-
self.name = name or eager_fn.__name__
157+
self.name = name or eager_fn.__qualname__.replace("<local>", "L").replace(
158+
"<lambda>", "l"
159+
).replace(".", "_")
160+
self.kwargs = kwargs
161+
assert kwargs is None or all(isinstance(v, (int, float)) for v in kwargs.values()), (
162+
f"Only int or floats are allowed for kwargs={kwargs}, one of them "
163+
f"does not respect that constraint."
164+
)
156165
sig = inspect.signature(self.eager_fn)
157166
params = list(sig.parameters)
158167
assert (
@@ -190,7 +199,7 @@ def torch_op(self) -> Callable:
190199
def __call__(self, *args):
191200
"""Calls eager_fn or shape_fn if the model is being exported."""
192201
if is_exporting():
193-
return self.shape_fn(*args)
202+
return self.torch_op(*args)
194203
return self.eager_fn(*args)
195204

196205
def _registers(self):
@@ -266,10 +275,16 @@ def converter(
266275
outputs: List[str],
267276
*args,
268277
) -> Any:
269-
if not g.has_local_function(self.name, self.domain):
278+
if not g.has_local_function(
279+
self.function_proto.name, domain=self.function_proto.domain
280+
):
270281
g.add_function(self.function_proto)
271282
res = g.make_node(
272-
self.name, args, outputs, domain=self.domain, name=self.target_name
283+
self.function_proto.name,
284+
args,
285+
outputs,
286+
domain=self.function_proto.domain,
287+
name=self.target_name,
273288
)
274289
if not sts:
275290
new_shapes = self.shape_fn(*args)
@@ -290,8 +305,8 @@ def onnx_dynamo_converter(self) -> Callable:
290305
"""
291306
import onnxscript
292307

293-
onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
294-
schema = onnx_plug_op[self.name]
308+
onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
309+
schema = onnx_plug_op[self.function_proto.name]
295310
if schema is None:
296311
all_types = [
297312
"tensor(float)",
@@ -307,8 +322,8 @@ def onnx_dynamo_converter(self) -> Callable:
307322
for i in range(self.n_outputs):
308323
type_constraints.append((f"U{i}", all_types, ""))
309324
schema = onnx.defs.OpSchema(
310-
self.name,
311-
self.domain,
325+
self.function_proto.name,
326+
self.function_proto.domain,
312327
1,
313328
inputs=[
314329
onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
@@ -321,7 +336,7 @@ def onnx_dynamo_converter(self) -> Callable:
321336
type_constraints=type_constraints,
322337
)
323338
onnx.defs.register_schema(schema)
324-
op = onnxscript.values.Op(onnx_plug_op, self.name, schema)
339+
op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
325340

326341
def converter(*cargs):
327342
return op(*cargs, n_outputs=self.n_outputs)

0 commit comments

Comments
 (0)