Commit d167935

fix patch
1 parent aea7b1a commit d167935

2 files changed: +35 -5 lines changed

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 import torch
 import transformers
-import transformers.modeling_attn_mask_utils
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache, DynamicCache
 from ...helpers.torch_test_helper import is_torchdynamo_exporting

@@ -54,7 +54,7 @@ class kkpatched_AttentionMaskConverter:
     """

     _PATCHES_ = ["_make_causal_mask"]
-    _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter
+    _PATCHED_CLASS_ = AttentionMaskConverter

     @staticmethod
     def _make_causal_mask(

@@ -79,7 +79,7 @@ class kkpatched_AttentionMaskConverter:
     """

     _PATCHES_ = ["_make_causal_mask"]
-    _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter
+    _PATCHED_CLASS_ = AttentionMaskConverter

     @staticmethod
     def _make_causal_mask(

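The change above only swaps a module-path reference for a direct import of AttentionMaskConverter; the _PATCHES_ / _PATCHED_CLASS_ pair is what the patching machinery consumes. As a rough illustration, here is a minimal sketch assuming the loader simply copies the listed attributes from the patch class onto the target class and keeps the originals so the patch can be undone; apply_patches and restore_patches are hypothetical names, not the actual API of torch_export_patches:

def apply_patches(patch_cls):
    # Hypothetical sketch only; the real loader lives elsewhere in
    # onnx_diagnostic.torch_export_patches.
    target = patch_cls._PATCHED_CLASS_              # e.g. AttentionMaskConverter
    originals = {}
    for name in patch_cls._PATCHES_:                # e.g. ["_make_causal_mask"]
        originals[name] = target.__dict__[name]     # raw descriptor (staticmethod, ...)
        setattr(target, name, patch_cls.__dict__[name])
    return originals


def restore_patches(patch_cls, originals):
    # Put the saved attributes back on the patched class.
    for name, attr in originals.items():
        setattr(patch_cls._PATCHED_CLASS_, name, attr)
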
onnx_diagnostic/torch_models/test_helper.py

Lines changed: 32 additions & 2 deletions
@@ -109,6 +109,34 @@ def filter_inputs(
     return new_inputs, dyn


+def _make_folder_name(
+    model_id: str,
+    exporter: str,
+    optimization: Optional[str] = None,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    device: Optional[Union[str, torch.device]] = None,
+) -> str:
+    "Creates a filename unique based on the given options."
+    els = [model_id.replace("/", "_"), exporter]
+    if optimization:
+        els.append(optimization)
+    if dtype is not None and dtype:
+        stype = dtype if isinstance(dtype, str) else str(dtype)
+        stype = stype.replace("float", "f").replace("uint", "u").replace("int", "i")
+        els.append(stype)
+    if device is not None and device:
+        sdev = device if isinstance(device, str) else str(device)
+        sdev = sdev.lower()
+        if "cpu" in sdev:
+            sdev = "cpu"
+        elif "cuda" in sdev:
+            sdev = "cuda"
+        else:
+            raise AssertionError(f"unexpected value for device={device}, sdev={sdev!r}")
+        els.append(sdev)
+    return "-".join(els)
+
+
 def validate_model(
     model_id: str,
     task: Optional[str] = None,
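
As a quick check of the naming scheme introduced above, this is what _make_folder_name would return for a few hypothetical inputs (the model id, exporter, and options below are made up for illustration):

_make_folder_name("openai/whisper-tiny", "onnx-dynamo", dtype="float16", device="cuda:0")
# -> 'openai_whisper-tiny-onnx-dynamo-f16-cuda'

_make_folder_name("openai/whisper-tiny", "export", optimization="ir", dtype="int8", device="cpu")
# -> 'openai_whisper-tiny-export-ir-i8-cpu'

# A torch.dtype goes through str(), so torch.float16 becomes 'torch.f16'
# rather than 'f16'; passing the dtype as a string gives the shorter form.
_make_folder_name("openai/whisper-tiny", "export", dtype=torch.float16, device="cpu")
# -> 'openai_whisper-tiny-export-torch.f16-cpu'
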
@@ -152,7 +180,9 @@ def validate_model(
     assert not trained, f"trained={trained} not supported yet"
     summary: Dict[str, Union[int, float, str]] = {}
     if dump_folder:
-        folder_name = f"{model_id.replace('/','-')}-{exporter}-{optimization or ''}"
+        folder_name = _make_folder_name(
+            model_id, exporter, optimization, dtype=dtype, device=device
+        )
         dump_folder = os.path.join(dump_folder, folder_name)
         if not os.path.exists(dump_folder):
             os.makedirs(dump_folder)

@@ -353,7 +383,7 @@ def validate_model(
         if verbose:
             print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
         onnx_file_name = os.path.join(dump_folder, f"{folder_name}.onnx")
-        epo.save(onnx_file_name)
+        epo.save(onnx_file_name, external_data=True)
         if verbose:
             print("[validate_model] done (dump onnx)")
         if verbose:

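Note on the external_data=True change in the last hunk: assuming epo exposes the same save signature as a torch.onnx exported program, this writes the initializers (weights) to a separate file next to the .onnx graph instead of embedding them in the protobuf, which is what lets models above the 2 GB protobuf limit be serialized. Downstream consumers need no special handling; a minimal sketch with a placeholder path:

import onnx
import onnxruntime as ort

onnx_path = "dump/openai_whisper-tiny-onnx-dynamo-f16-cuda.onnx"  # placeholder path

# onnx.load resolves external tensor files stored next to the model by default.
model = onnx.load(onnx_path)

# onnxruntime likewise picks up the external data from the model's directory.
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
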