Skip to content

Commit 94d9508

Browse files
authored
Exposes register_flattening_functions (#82)
* Fix for missing intermediate_size * tiny improvement * mypy * better
1 parent 227b022 commit 94d9508

File tree

8 files changed

+83
-14
lines changed

8 files changed

+83
-14
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.4.4
55
+++++
66

7+
* :pr:`82`: exposes ``register_flattening_functions``, add option ``--subfolder``
8+
* :pr:`81`: fixes missing ``intermediate_size`` in configuration
79
* :pr:`79`: implements task ``object-detection``
810
* :pr:`78`: uses *onnx-weekly* instead of *onnx* to avoid conflicts with *onnxscript*
911

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def forward(self, cache, z):
112112
)
113113
print(ep)
114114

115+
# %%
115116
# Do we need to guess?
116117
# ++++++++++++++++++++
117118
#

onnx_diagnostic/_command_lines_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ def get_parser_validate() -> ArgumentParser:
336336
help="drops the following inputs names, it should be a list "
337337
"with comma separated values",
338338
)
339+
parser.add_argument(
340+
"--subfolder",
341+
help="subfolder where to find the model and the configuration",
342+
)
339343
parser.add_argument(
340344
"--ortfusiontype",
341345
required=False,
@@ -413,6 +417,7 @@ def _cmd_validate(argv: List[Any]):
413417
ortfusiontype=args.ortfusiontype,
414418
input_options=args.iop,
415419
model_options=args.mop,
420+
subfolder=args.subfolder,
416421
)
417422
print("")
418423
print("-- summary --")

onnx_diagnostic/torch_export_patches/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,15 @@
66

77
# bypass_export_some_errors is the first name given to the patches.
88
bypass_export_some_errors = torch_export_patches # type: ignore
9+
10+
11+
def register_flattening_functions(verbose: int = 0):
12+
"""
13+
Registers functions to serialize/deserialize cache or other classes
14+
implemented in :epkg:`transformers` and used as inputs.
15+
This is needed whenever a model must be exported through
16+
:func:`torch.export.export`.
17+
"""
18+
from .onnx_export_serialization import _register_cache_serialization
19+
20+
return _register_cache_serialization(verbose=verbose)

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import copy
22
import functools
3+
import json
34
import os
45
from typing import Any, Dict, List, Optional, Union
56
import transformers
6-
from huggingface_hub import HfApi, model_info
7+
from huggingface_hub import HfApi, model_info, hf_hub_download
78
from ...helpers.config_helper import update_config
89
from . import hub_data_cached_configs
910
from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
@@ -59,7 +60,11 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr
5960

6061

6162
def get_pretrained_config(
62-
model_id: str, trust_remote_code: bool = True, use_preinstalled: bool = True, **kwargs
63+
model_id: str,
64+
trust_remote_code: bool = True,
65+
use_preinstalled: bool = True,
66+
subfolder: Optional[str] = None,
67+
**kwargs,
6368
) -> Any:
6469
"""
6570
Returns the config for a model_id.
@@ -71,13 +76,32 @@ def get_pretrained_config(
7176
accessing the network, if available, it is returned by
7277
:func:`get_cached_configuration`, the cached list is mostly for
7378
unit tests
79+
:param subfolder: subfolder for the given model id
7480
:param kwargs: additional kwargs
7581
:return: a configuration
7682
"""
7783
if use_preinstalled:
78-
conf = get_cached_configuration(model_id, **kwargs)
84+
conf = get_cached_configuration(model_id, subfolder=subfolder, **kwargs)
7985
if conf is not None:
8086
return conf
87+
if subfolder:
88+
try:
89+
return transformers.AutoConfig.from_pretrained(
90+
model_id, trust_remote_code=trust_remote_code, subfolder=subfolder, **kwargs
91+
)
92+
except ValueError:
93+
# Then we try to download it.
94+
config = hf_hub_download(
95+
model_id, filename="config.json", subfolder=subfolder, **kwargs
96+
)
97+
try:
98+
return transformers.AutoConfig.from_pretrained(
99+
config, trust_remote_code=trust_remote_code, **kwargs
100+
)
101+
except ValueError:
102+
# Diffusers uses a dictionary.
103+
with open(config, "r") as f:
104+
return json.load(f)
81105
return transformers.AutoConfig.from_pretrained(
82106
model_id, trust_remote_code=trust_remote_code, **kwargs
83107
)

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
T5ForConditionalGeneration,text2text-generation
126126
TableTransformerModel,image-feature-extraction
127127
TableTransformerForObjectDetection,object-detection
128+
UNet2DConditionModel,text-to-image
128129
UniSpeechForSequenceClassification,audio-classification
129130
ViTForImageClassification,image-classification
130131
ViTMAEModel,image-feature-extraction
@@ -163,6 +164,7 @@
163164
"sentence-similarity",
164165
"text-classification",
165166
"text-generation",
167+
"text-to-image",
166168
"text-to-audio",
167169
"text2text-generation",
168170
"zero-shot-image-classification",

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def get_untrained_model_with_inputs(
1818
same_as_pretrained: bool = False,
1919
use_preinstalled: bool = True,
2020
add_second_input: bool = False,
21+
subfolder: Optional[str] = None,
2122
) -> Dict[str, Any]:
2223
"""
2324
Gets a non initialized model similar to the original model
@@ -37,6 +38,7 @@ def get_untrained_model_with_inputs(
3738
:param use_preinstalled: use preinstalled configurations
3839
:param add_second_input: provides a second inputs to check a model
3940
supports different shapes
41+
:param subfolder: subfolder to use for this model id
4042
:return: dictionary with a model, inputs, dynamic shapes, and the configuration
4143
4244
Example:
@@ -62,11 +64,18 @@ def get_untrained_model_with_inputs(
6264
print(f"[get_untrained_model_with_inputs] use preinstalled {model_id!r}")
6365
if config is None:
6466
config = get_pretrained_config(
65-
model_id, use_preinstalled=use_preinstalled, **(model_kwargs or {})
67+
model_id,
68+
use_preinstalled=use_preinstalled,
69+
subfolder=subfolder,
70+
**(model_kwargs or {}),
6671
)
6772
if hasattr(config, "architecture") and config.architecture:
6873
archs = [config.architecture]
69-
archs = config.architectures # type: ignore
74+
if type(config) is dict:
75+
assert "_class_name" in config, f"Unable to get the architecture from config={config}"
76+
archs = [config["_class_name"]]
77+
else:
78+
archs = config.architectures # type: ignore
7079
task = None
7180
if archs is None:
7281
task = task_from_id(model_id)
@@ -84,6 +93,10 @@ def get_untrained_model_with_inputs(
8493

8594
# model kwargs
8695
if dynamic_rope is not None:
96+
assert (
97+
type(config) is not dict
98+
), f"Unable to set dynamic_rope if the configuration is a dictionary\n{config}"
99+
assert hasattr(config, "rope_scaling"), f"Missing 'rope_scaling' in\n{config}"
87100
config.rope_scaling = (
88101
{"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None
89102
)

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,12 @@ def _make_folder_name(
109109
optimization: Optional[str] = None,
110110
dtype: Optional[Union[str, torch.dtype]] = None,
111111
device: Optional[Union[str, torch.device]] = None,
112+
subfolder: Optional[str] = None,
112113
) -> str:
113114
"Creates a filename unique based on the given options."
114115
els = [model_id.replace("/", "_")]
116+
if subfolder:
117+
els.append(subfolder.replace("/", "_"))
115118
if exporter:
116119
els.append(exporter)
117120
if optimization:
@@ -224,6 +227,7 @@ def validate_model(
224227
ortfusiontype: Optional[str] = None,
225228
input_options: Optional[Dict[str, Any]] = None,
226229
model_options: Optional[Dict[str, Any]] = None,
230+
subfolder: Optional[str] = None,
227231
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
228232
"""
229233
Validates a model.
@@ -256,11 +260,11 @@ def validate_model(
256260
used to export
257261
:param model_options: additional options when creating the model such as
258262
``num_hidden_layers`` or ``attn_implementation``
263+
:param subfolder: version or subfolders to use when retrieving a model id
259264
:return: two dictionaries, one with some metrics,
260265
another one with whatever the function produces
261266
"""
262267
summary = version_summary()
263-
264268
summary.update(
265269
dict(
266270
version_model_id=model_id,
@@ -282,7 +286,7 @@ def validate_model(
282286
folder_name = None
283287
if dump_folder:
284288
folder_name = _make_folder_name(
285-
model_id, exporter, optimization, dtype=dtype, device=device
289+
model_id, exporter, optimization, dtype=dtype, device=device, subfolder=subfolder
286290
)
287291
dump_folder = os.path.join(dump_folder, folder_name)
288292
if not os.path.exists(dump_folder):
@@ -293,11 +297,15 @@ def validate_model(
293297
print(f"[validate_model] dump into {folder_name!r}")
294298

295299
if verbose:
296-
print(f"[validate_model] validate model id {model_id!r}")
300+
if subfolder:
301+
print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
302+
else:
303+
print(f"[validate_model] validate model id {model_id!r}")
297304
if model_options:
298305
print(f"[validate_model] model_options={model_options!r}")
299306
print(f"[validate_model] get dummy inputs with input_options={input_options}...")
300307
summary["model_id"] = model_id
308+
summary["model_subfolder"] = subfolder or ""
301309

302310
iop = input_options or {}
303311
mop = model_options or {}
@@ -307,14 +315,15 @@ def validate_model(
307315
summary,
308316
None,
309317
(
310-
lambda mid=model_id, v=verbose, task=task, tr=trained, iop=iop: (
318+
lambda mid=model_id, v=verbose, task=task, tr=trained, iop=iop, sub=subfolder: (
311319
get_untrained_model_with_inputs(
312320
mid,
313321
verbose=v,
314322
task=task,
315323
same_as_pretrained=tr,
316324
inputs_kwargs=iop,
317325
model_kwargs=mop,
326+
subfolder=sub,
318327
)
319328
)
320329
),
@@ -1060,15 +1069,16 @@ def call_torch_export_custom(
10601069
assert (
10611070
optimization in available
10621071
), f"unexpected value for optimization={optimization}, available={available}"
1063-
assert exporter in {
1072+
available = {
10641073
"custom",
10651074
"custom-strict",
1066-
"custom-strict-dec",
1075+
"custom-strict-default",
10671076
"custom-strict-all",
10681077
"custom-nostrict",
1069-
"custom-nostrict-dec",
1078+
"custom-nostrict-default",
10701079
"custom-nostrict-all",
1071-
}, f"Unexpected value for exporter={exporter!r}"
1080+
}
1081+
assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
10721082
assert "model" in data, f"model is missing from data: {sorted(data)}"
10731083
assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
10741084
summary: Dict[str, Union[str, int, float]] = {}
@@ -1100,7 +1110,7 @@ def call_torch_export_custom(
11001110
export_options = ExportOptions(
11011111
strict=strict,
11021112
decomposition_table=(
1103-
"dec" if "-dec" in exporter else ("all" if "-all" in exporter else None)
1113+
"default" if "-default" in exporter else ("all" if "-all" in exporter else None)
11041114
),
11051115
)
11061116
options = OptimizationOptions(patterns=optimization) if optimization else None

0 commit comments

Comments
 (0)