Skip to content

Commit 03d60cb

Browse files
committed
check sdpa is working
1 parent c1d1004 commit 03d60cb

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def get_untrained_model_with_inputs(
6969
subfolder=subfolder,
7070
**(model_kwargs or {}),
7171
)
72+
7273
if hasattr(config, "architecture") and config.architecture:
7374
archs = [config.architecture]
7475
if type(config) is dict:
@@ -116,6 +117,17 @@ def get_untrained_model_with_inputs(
116117
if mkwargs:
117118
update_config(config, mkwargs)
118119

120+
# SDPA
121+
if model_kwargs and "attn_implementation" in model_kwargs:
122+
if hasattr(config, "_attn_implementation_autoset"):
123+
config._attn_implementation_autoset = False
124+
config._attn_implementation = model_kwargs["attn_implementation"]
125+
if verbose:
126+
print(
127+
f"[get_untrained_model_with_inputs] config._attn_implementation="
128+
f"{config._attn_implementation!r}"
129+
)
130+
119131
# input kwargs
120132
kwargs, fct = random_input_kwargs(config, task)
121133
if verbose:

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ def validate_model(
482482
verbose=verbose,
483483
optimization=optimization,
484484
do_run=do_run,
485+
dump_folder=dump_folder,
485486
)
486487
else:
487488
data["inputs_export"] = data["inputs"]
@@ -493,6 +494,7 @@ def validate_model(
493494
verbose=verbose,
494495
optimization=optimization,
495496
do_run=do_run,
497+
dump_folder=dump_folder,
496498
)
497499
summary.update(summary_export)
498500

@@ -618,6 +620,7 @@ def call_exporter(
618620
verbose: int = 0,
619621
optimization: Optional[str] = None,
620622
do_run: bool = False,
623+
dump_folder: Optional[str] = None,
621624
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
622625
"""
623626
Calls an exporter on a model;
@@ -629,6 +632,7 @@ def call_exporter(
629632
:param verbose: verbosity
630633
:param optimization: optimization to do
631634
:param do_run: runs and compute discrepancies
635+
:param dump_folder: optional folder where additional information is dumped
632636
:return: two dictionaries, one with some metrics,
633637
another one with whatever the function produces
634638
"""
@@ -661,6 +665,7 @@ def call_exporter(
661665
quiet=quiet,
662666
verbose=verbose,
663667
optimization=optimization,
668+
dump_folder=dump_folder,
664669
)
665670
return summary, data
666671
raise NotImplementedError(
@@ -1045,6 +1050,7 @@ def call_torch_export_custom(
10451050
quiet: bool = False,
10461051
verbose: int = 0,
10471052
optimization: Optional[str] = None,
1053+
dump_folder: Optional[str] = None,
10481054
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
10491055
"""
10501056
Exports a model into onnx.
@@ -1056,6 +1062,7 @@ def call_torch_export_custom(
10561062
:param quiet: catch exception or not
10571063
:param verbose: verbosity
10581064
:param optimization: optimization to do
1065+
:param dump_folder: optional folder where additional information is stored, such as the exported program (``<exporter>.ep``)
10591066
:return: two dictionaries, one with some metrics,
10601067
another one with whatever the function produces
10611068
"""
@@ -1113,6 +1120,7 @@ def call_torch_export_custom(
11131120
decomposition_table=(
11141121
"default" if "-default" in exporter else ("all" if "-all" in exporter else None)
11151122
),
1123+
save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
11161124
)
11171125
options = OptimizationOptions(patterns=optimization) if optimization else None
11181126
model = data["model"]

0 commit comments

Comments
 (0)