Commit 8f444df ("patch")
1 parent 43aed2b
3 files changed (+43, -25 lines)

CHANGELOGS.rst: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ Change Logs
 0.8.3
 +++++
 
+* :pr:`323`: drops torch 2.8 on CI
 * :pr:`322`: support rerunning onnx kernels with torch intermediate results in side-by-side
 * :pr:`314`: fix modelbuilder download needed after this change https://github.com/microsoft/onnxruntime-genai/pull/1862
 * :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime

_unittests/ut_tasks/try_export.py: 28 additions & 7 deletions

@@ -1,4 +1,5 @@
 import os
+import time
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test, ignore_warnings
@@ -45,6 +46,7 @@ def test_imagetext2text_qwen_2_5_vl_instruct_visual(self):
         EXPORTER=custom \\
         python _unittests/ut_tasks/try_export.py -k qwen_2_5_vl_instruct_visual
         """
+        begin = time.perf_counter()
         device = os.environ.get("TESTDEVICE", "cpu")
         dtype = os.environ.get("TESTDTYPE", "float32")
         torch_dtype = {
@@ -87,13 +89,18 @@ def _config_reduction(config, task):
         )
         model = data["model"]
 
+        print(f"-- MODEL LOADED IN {time.perf_counter() - begin}")
+        begin = time.perf_counter()
         model = model.to(device).to(getattr(torch, dtype))
+        print(f"-- MODEL MOVED IN {time.perf_counter() - begin}")
 
         print(f"-- config._attn_implementation={model.config._attn_implementation}")
         print(f"-- model.dtype={model.dtype}")
         print(f"-- model.device={model.device}")
+        begin = time.perf_counter()
         processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
         print(f"-- processor={type(processor)}")
+        print(f"-- PROCESSOR LOADED IN {time.perf_counter() - begin}")
 
         big_inputs = dict(
             hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
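
The timing hunks above all follow one pattern: reset begin = time.perf_counter() before a stage, then print the elapsed time after it. A minimal sketch of the same measurement factored into a reusable helper; this helper is an illustration only, not part of the commit, and assumes nothing beyond the standard library:

    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(label: str):
        # Print wall-clock time for the enclosed block, mirroring the
        # "-- MODEL LOADED IN ..." messages added in this commit.
        begin = time.perf_counter()
        try:
            yield
        finally:
            print(f"-- {label} IN {time.perf_counter() - begin}")

    # usage:
    # with timed("MODEL LOADED"):
    #     model = load_model()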
@@ -104,14 +111,19 @@ def _config_reduction(config, task):
             hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
             grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
         )
-        print("-- save inputs")
-        torch.save(big_inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.big.pt"))
-        torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt"))
+        if not self.unit_test_going():
+            print("-- save inputs")
+            torch.save(
+                big_inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.big.pt")
+            )
+            torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt"))
 
         print(f"-- inputs: {self.string_type(inputs, with_shape=True)}")
         # this is too long
         model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
+        begin = time.perf_counter()
         expected = model_to_export(**inputs)
+        print(f"-- MODEL RUN IN {time.perf_counter() - begin}")
         print(f"-- expected: {self.string_type(expected, with_shape=True)}")
 
         filename = self.get_dump_file(
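
The new guard skips the torch.save dumps whenever self.unit_test_going() reports that the script runs inside a unit test; the method is defined in onnx_diagnostic/ext_test_case.py (third file of this commit) and delegates to a module-level function. A sketch of what such an environment-driven flag typically looks like; the variable name UNITTEST_GOING below is a hypothetical stand-in, not taken from the source:

    import os

    def unit_test_going() -> bool:
        # Hypothetical implementation: read a flag from the environment so CI
        # can shorten tests; the real function lives in onnx_diagnostic.
        return os.environ.get("UNITTEST_GOING", "0").lower() in ("1", "true")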
@@ -126,6 +138,7 @@ def _config_reduction(config, task):
         )
 
         # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
+        begin = time.perf_counter()
         export_inputs = inputs
         print()
         with torch_export_patches(
@@ -148,14 +161,21 @@ def _config_reduction(config, task):
                 onnx_plugs=PLUGS,
             )
 
+        print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}")
+
         pt2_files = [f"{fileep}.backup.pt2", f"{fileep}.ep.pt2", f"{fileep}.pt2"]
-        pt2_file = [f for f in pt2_files if os.path.exists(f)]
-        assert pt2_file, f"Unable to find an existing file among {pt2_files}"
-        pt2_file = pt2_file[0]
+        pt2_files = [f for f in pt2_files if os.path.exists(f)]
+        assert (
+            self.unit_test_going() or pt2_files
+        ), f"Unable to find an existing file among {pt2_files!r}"
+        pt2_file = (
+            (pt2_files[0] if pt2_files else None) if not self.unit_test_going() else None
+        )
         # self.assertExists(pt2_file)
         # ep = torch.export.load(pt2_file)
         # diff = self.max_diff(ep.module()(**export_inputs), model.visual(**export_inputs))
         # print("----------- diff", diff)
+        begin = time.perf_counter()
         self.assert_onnx_disc(
             f"test_imagetext2text_qwen_2_5_vl_instruct_visual.{device}.{dtype}.{exporter}",
             filename,
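
The rewritten pt2_file selection above no longer fails when no .pt2 file exists during a unit-test run: the assert is disarmed and pt2_file falls back to None. A standalone sketch of the same logic, with names of my own choosing rather than the commit's:

    import os
    from typing import List, Optional

    def pick_pt2_file(candidates: List[str], testing: bool) -> Optional[str]:
        # First existing candidate, or None while unit tests run; the assert
        # only fires outside a test run, as in the commit.
        existing = [f for f in candidates if os.path.exists(f)]
        assert testing or existing, f"Unable to find an existing file among {candidates!r}"
        return existing[0] if existing and not testing else None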
@@ -171,9 +191,10 @@ def _config_reduction(config, task):
             atol=0.02,
             rtol=10,
             ort_optimized_graph=False,
-            # ep=pt2_file,
+            ep=pt2_file,
             expected=expected,
         )
+        print(f"-- MODEL VERIFIED IN {time.perf_counter() - begin}")
         if self.unit_test_going():
             self.clean_dump()

onnx_diagnostic/ext_test_case.py: 14 additions & 18 deletions

@@ -743,15 +743,15 @@ class ExtTestCase(unittest.TestCase):
     _warns: List[Tuple[str, int, Warning]] = []
     _todos: List[Tuple[Callable, str]] = []
 
-    def unit_test_going(self):
+    def unit_test_going(self) -> bool:
         """
         Tells the script that it is running from a unit test.
         Avoids unit tests being very long.
         """
         return unit_test_going()
 
     @property
-    def verbose(self):
+    def verbose(self) -> int:
         "Returns the value of the environment variable ``VERBOSE``."
         return int(os.environ.get("VERBOSE", "0"))
 
@@ -776,13 +776,13 @@ def todo(cls, f: Callable, msg: str):
         cls._todos.append((f, msg))
 
     @classmethod
-    def ort(cls):
+    def ort(cls) -> unittest.__class__:
         import onnxruntime
 
         return onnxruntime
 
     @classmethod
-    def to_onnx(self, *args, **kwargs):
+    def to_onnx(self, *args, **kwargs) -> "ModelProto":  # noqa: F821
         from experimental_experiment.torch_interpreter import to_onnx
 
         return to_onnx(*args, **kwargs)
@@ -823,12 +823,7 @@ def clean_dump(self, folder: str = "dump_test"):
         elif os.path.isdir(item_path):
             shutil.rmtree(item_path)
 
-    def dump_onnx(
-        self,
-        name: str,
-        proto: Any,
-        folder: Optional[str] = None,
-    ) -> str:
+    def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
         """Dumps an onnx file."""
         fullname = self.get_dump_file(name, folder=folder)
         with open(fullname, "wb") as f:
@@ -1111,7 +1106,9 @@ def assertAlmostEqual(
             value = numpy.array(value).astype(expected.dtype)
         self.assertEqualArray(expected, value, atol=atol, rtol=rtol)
 
-    def check_ort(self, onx: "onnx.ModelProto") -> bool:  # noqa: F821
+    def check_ort(
+        self, onx: "onnx.ModelProto"  # noqa: F821
+    ) -> "onnxruntime.InferenceSession":  # noqa: F821
         from onnxruntime import InferenceSession
 
         return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
@@ -1154,7 +1151,7 @@ def assertEndsWith(self, suffix: str, full: str):
         if not full.endswith(suffix):
             raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")
 
-    def capture(self, fct: Callable):
+    def capture(self, fct: Callable) -> Tuple[Any, str, str]:
         """
         Runs a function and captures standard output and error.
@@ -1250,7 +1247,7 @@ def assert_onnx_disc(
             proto, onnx.ModelProto
         ), f"Unexpected type {type(proto)} for proto"
         name = self.dump_onnx(name, proto)
-        if verbose:
+        if verbose and not self.unit_test_going():
             print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
         if verbose:
             print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
@@ -1262,15 +1259,14 @@ def assert_onnx_disc(
         feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
         import onnxruntime
 
-        if verbose:
-            print(f"[{vname}] create onnxruntime.InferenceSession")
         options = onnxruntime.SessionOptions()
         if ort_optimized_graph:
             options.optimized_model_filepath = f"{name}.optort.onnx"
+        providers = kwargs.get("providers", ["CPUExecutionProvider"])
+        if verbose:
+            print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
         sess = onnxruntime.InferenceSession(
-            proto.SerializeToString(),
-            options,
-            providers=kwargs.get("providers", ["CPUExecutionProvider"]),
+            proto.SerializeToString(), options, providers=providers
         )
         if verbose:
             print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
