Skip to content

Commit a0ef200

Browse files
authored
Investigate SDPA (#103)
* check sdpa is working * fix mypy * more infos * mypy * fix * mypy
1 parent c1d1004 commit a0ef200

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def get_untrained_model_with_inputs(
6969
subfolder=subfolder,
7070
**(model_kwargs or {}),
7171
)
72+
7273
if hasattr(config, "architecture") and config.architecture:
7374
archs = [config.architecture]
7475
if type(config) is dict:
@@ -116,6 +117,17 @@ def get_untrained_model_with_inputs(
116117
if mkwargs:
117118
update_config(config, mkwargs)
118119

120+
# SDPA
121+
if model_kwargs and "attn_implementation" in model_kwargs:
122+
if hasattr(config, "_attn_implementation_autoset"):
123+
config._attn_implementation_autoset = False
124+
config._attn_implementation = model_kwargs["attn_implementation"] # type: ignore[union-attr]
125+
if verbose:
126+
print(
127+
f"[get_untrained_model_with_inputs] config._attn_implementation="
128+
f"{config._attn_implementation!r}" # type: ignore[union-attr]
129+
)
130+
119131
# input kwargs
120132
kwargs, fct = random_input_kwargs(config, task)
121133
if verbose:

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import inspect
33
import os
4+
import sys
45
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
56
import time
67
import onnx
@@ -375,8 +376,11 @@ def validate_model(
375376
summary[f"model_{k.replace('_','')}"] = data[k]
376377
summary["model_inputs_opionts"] = str(input_options or "")
377378
summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
378-
summary["model_shapes"] = string_type(str(data["dynamic_shapes"]))
379+
summary["model_shapes"] = string_type(data["dynamic_shapes"])
379380
summary["model_class"] = data["model"].__class__.__name__
381+
summary["model_module"] = str(data["model"].__class__.__module__)
382+
if summary["model_module"] in sys.modules:
383+
summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) # type: ignore[index]
380384
summary["model_config_class"] = data["configuration"].__class__.__name__
381385
summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "")
382386
summary["model_id"] = model_id
@@ -482,6 +486,7 @@ def validate_model(
482486
verbose=verbose,
483487
optimization=optimization,
484488
do_run=do_run,
489+
dump_folder=dump_folder,
485490
)
486491
else:
487492
data["inputs_export"] = data["inputs"]
@@ -493,6 +498,7 @@ def validate_model(
493498
verbose=verbose,
494499
optimization=optimization,
495500
do_run=do_run,
501+
dump_folder=dump_folder,
496502
)
497503
summary.update(summary_export)
498504

@@ -618,6 +624,7 @@ def call_exporter(
618624
verbose: int = 0,
619625
optimization: Optional[str] = None,
620626
do_run: bool = False,
627+
dump_folder: Optional[str] = None,
621628
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
622629
"""
623630
Calls an exporter on a model;
@@ -629,6 +636,7 @@ def call_exporter(
629636
:param verbose: verbosity
630637
:param optimization: optimization to do
631638
:param do_run: runs and compute discrepancies
639+
:param dump_folder: to dump additional information
632640
:return: two dictionaries, one with some metrics,
633641
another one with whatever the function produces
634642
"""
@@ -661,6 +669,7 @@ def call_exporter(
661669
quiet=quiet,
662670
verbose=verbose,
663671
optimization=optimization,
672+
dump_folder=dump_folder,
664673
)
665674
return summary, data
666675
raise NotImplementedError(
@@ -1045,6 +1054,7 @@ def call_torch_export_custom(
10451054
quiet: bool = False,
10461055
verbose: int = 0,
10471056
optimization: Optional[str] = None,
1057+
dump_folder: Optional[str] = None,
10481058
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
10491059
"""
10501060
Exports a model into onnx.
@@ -1056,6 +1066,7 @@ def call_torch_export_custom(
10561066
:param quiet: catch exception or not
10571067
:param verbose: verbosity
10581068
:param optimization: optimization to do
1069+
:param dump_folder: to store additional information
10591070
:return: two dictionaries, one with some metrics,
10601071
another one with whatever the function produces
10611072
"""
@@ -1113,6 +1124,7 @@ def call_torch_export_custom(
11131124
decomposition_table=(
11141125
"default" if "-default" in exporter else ("all" if "-all" in exporter else None)
11151126
),
1127+
save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
11161128
)
11171129
options = OptimizationOptions(patterns=optimization) if optimization else None
11181130
model = data["model"]

0 commit comments

Comments
 (0)