Skip to content

Commit 96dfa22

Browse files
committed
add
1 parent 373bcd3 commit 96dfa22

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,13 +432,13 @@ def forward(self, x):
432432
],
433433
sorted(df.columns),
434434
)
435-
self.assertEqual(len(results), 12)
435+
self.assertEqual(len(results), 8)
436436
self.assertEqual(
437+
[None, None, None, None, None, 0, 0, 0],
437438
[r.err_dev for r in results],
438-
[None, None, None, None, None, None, None, None, None, 0, 0, 0],
439439
)
440440
self.assertEqual(
441-
[-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
441+
[-10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
442442
df["onnx_id_node"].fillna(-10).tolist(),
443443
)
444444
self.clean_dump()
@@ -466,7 +466,7 @@ def forward(self, x):
466466
use_tensor=True,
467467
),
468468
)
469-
self.assertEqual(len(results), 5)
469+
self.assertEqual(len(results), 2)
470470

471471

472472
if __name__ == "__main__":

onnx_diagnostic/_command_lines_parser.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,10 @@ def _size(name):
12501250
import transformers.modeling_outputs # noqa: F401
12511251
print(f"-- load ep {args.ep!r}")
12521252
begin = time.perf_counter()
1253+
# We need to load the plugs.
1254+
from .torch_export_patches.patches.patch_transformers import PLUGS_Qwen25
1255+
1256+
assert len(PLUGS_Qwen25) == 1, "Missing PLUGS for Qwen2.5"
12531257
ep = torch.export.load(args.ep)
12541258
print(f"-- done in {time.perf_counter() - begin:1.1f}s")
12551259

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def _loop_cmp(
498498
torch_results: Dict[str, Any] = {}
499499
last_position = 0
500500
torch_output_names = None
501-
torch_input_names = []
501+
torch_input_names: List[str] = []
502502
name_to_ep_node = {}
503503
torch_names_to_onnx_names = {}
504504
for i, node in enumerate(ep_graph_nodes):

0 commit comments

Comments
 (0)