Skip to content

Commit ed9c7b7

Browse files
committed
fix import issues
1 parent 99d5cd9 commit ed9c7b7

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,9 +1265,10 @@ def _size(name):
12651265
print(f"-- load ep {args.ep!r}")
12661266
begin = time.perf_counter()
12671267
# We need to load the plugs.
1268-
from .torch_export_patches.patches.patch_transformers import PLUGS_Qwen25
1268+
from .torch_export_patches.patches.patch_transformers import get_transformers_plugs
12691269

1270-
assert len(PLUGS_Qwen25) == 1, "Missing PLUGS for Qwen2.5"
1270+
plugs = get_transformers_plugs()
1271+
assert plugs, "Missing PLUGS for Qwen2.5"
12711272
ep = torch.export.load(args.ep)
12721273
print(f"-- done in {time.perf_counter() - begin:1.1f}s")
12731274

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# transformers
2+
from typing import List
23
from .patch_helper import _has_transformers
34

45
from ._patch_transformers_attention import (
@@ -86,3 +87,11 @@
8687

8788

8889
from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
90+
91+
92+
def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]: # noqa: F821
93+
"""Returns the necessary plugs to rewrite models."""
94+
plugs = []
95+
if patch_qwen2_5:
96+
plugs.extend(PLUGS_Qwen25)
97+
return plugs

0 commit comments

Comments
 (0)