
Commit 9ce79a1

add function to convert dynamic_axes
1 parent 5eb5217 commit 9ce79a1

File tree: 3 files changed, +221 −1 lines changed
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import unittest
import torch
import transformers
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches.patch_inputs import (
    convert_dynamic_axes_into_dynamic_shapes,
)


class TestPatchInputs(ExtTestCase):
    @hide_stdout()
    def test_convert_dynamic_axes_into_dynamic_shapes(self):
        args = (
            torch.randint(0, 10, size=(2, 8)).to(torch.int64),
            torch.randint(0, 10, size=(2, 8)).to(torch.int64),
            torch.randint(0, 10, size=(2, 8)).to(torch.int64),
            [(torch.rand((2, 1, 3, 96)), torch.rand((2, 1, 3, 96)))],
        )
        dynamic_axes = {
            "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "logits": {0: "batch_size", 1: "sequence_length"},
            "past_key_values.0.key": {0: "batch_size", 2: "past_sequence_length"},
            "past_key_values.0.value": {0: "batch_size", 2: "past_sequence_length"},
            "position_ids": {0: "batch_size", 1: "sequence_length"},
            "present.0.key": {0: "batch_size", 2: "total_sequence_length"},
            "present.0.value": {0: "batch_size", 2: "total_sequence_length"},
        }

        model_cls = transformers.LlamaModel
        res = convert_dynamic_axes_into_dynamic_shapes(
            model_cls, args=args, dynamic_axes=dynamic_axes, verbose=1
        )
        self.assertEqual((), res[0])
        self.assertEqual(
            (
                "dict(input_ids:T7s2x8,attention_mask:T7s2x8,position_ids:T7s2x8,"
                "past_key_values:DynamicCache(key_cache=#1[T1s2x1x3x96], "
                "value_cache=#1[T1s2x1x3x96]))"
            ),
            string_type(res[1], with_shape=True),
        )
        self.assertEqual(
            {
                "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "past_key_values": [
                    [{0: "batch_size", 2: "past_sequence_length"}],
                    [{0: "batch_size", 2: "past_sequence_length"}],
                ],
                "position_ids": {0: "batch_size", 1: "sequence_length"},
            },
            res[2],
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)
onnx_diagnostic/torch_export_patches/patch_inputs.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
import inspect
from typing import Any, Dict, Optional, Tuple
import torch
import transformers
from ..helpers import string_type
from ..cache_helpers import make_dynamic_cache


def _process_cache(k: str, v):
    """Converts one parameter value, turning a list of (key, value) tensor pairs into a cache."""
    assert k != "position_ids" or isinstance(
        v, torch.Tensor
    ), f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}"
    if (
        isinstance(v, list)
        and all(isinstance(i, tuple) for i in v)
        and set(len(t) for t in v) == {2}
    ):
        # a list of (key, value) pairs: a DynamicCache
        cache = make_dynamic_cache(v)
        return cache
    if isinstance(v, torch.Tensor):
        return v
    raise NotImplementedError(
        f"Unable to process parameter {k!r} with v={string_type(v, with_shape=True)}"
    )


def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
    """Builds the dynamic shapes for a cache object from the axes defined for its tensors."""
    if cls is transformers.cache_utils.DynamicCache:
        assert subset, "DynamicCache cannot be empty"
        values = set(map(str, subset.values()))
        assert len(values) == 1, (
            f"Inconsistencies in subset={subset}, found={values}, "
            f"it cannot be a {cls}, value={string_type(value)}"
        )
        cache_length = len(value.key_cache)
        axes = next(iter(subset.values()))
        # one list of axes for key_cache, one for value_cache
        new_shape = [[axes for _ in range(cache_length)], [axes for _ in range(cache_length)]]
        return new_shape
    raise NotImplementedError(
        f"_make_shape not implemented for cls={cls}, "
        f"subset={subset}, value={string_type(value)}"
    )


def convert_dynamic_axes_into_dynamic_shapes(
    model: torch.nn.Module,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
    verbose: int = 0,
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
    """
    Converts ``dynamic_axes``, the format used by :func:`torch.onnx.export`,
    into ``dynamic_shapes``, the format :func:`torch.export.export` can handle.

    :param model: model to convert (used to extract the signature)
    :param args: positional arguments
    :param kwargs: named arguments
    :param dynamic_axes: dynamic axes
    :param verbose: verbosity
    :return: (args, kwargs, dynamic shapes)
    """
    new_kwargs = {}
    if args:
        assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
        if verbose:
            print(
                f"[convert_dynamic_axes_into_dynamic_shapes] "
                f"mapping args to kwargs for model={model}"
            )
        # when a class is given instead of an instance, the signature includes 'self'
        plus = 0 if isinstance(model, torch.nn.Module) else 1
        pars = inspect.signature(model.forward).parameters
        assert len(pars) >= len(
            args
        ), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}"

        for i, p in enumerate(pars):
            if i < plus:
                continue
            if i - plus >= len(args):
                break
            if verbose:
                print(
                    f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i-plus}] "
                    f"to {p!r} ({string_type(args[i-plus])})"
                )
            new_kwargs[p] = args[i - plus]

    if kwargs:
        for k, v in kwargs.items():
            assert k not in new_kwargs, f"Argument {k!r} from kwargs already present in args."
            new_kwargs[k] = v

    # process the inputs: lists of tensor pairs become caches
    updated_kwargs = {}
    changes = {}
    for k, v in new_kwargs.items():
        if isinstance(v, torch.Tensor):
            updated_kwargs[k] = v
            continue
        if isinstance(v, list):
            # cache?
            updated_kwargs[k] = _process_cache(k, v)
            if type(updated_kwargs[k]) is not type(v):
                # A cache was introduced.
                if verbose:
                    print(
                        f"[convert_dynamic_axes_into_dynamic_shapes] parameter "
                        f"{k!r} was changed into {type(updated_kwargs[k])}"
                    )
                changes[k] = type(updated_kwargs[k])
            continue
        raise NotImplementedError(
            f"Unexpected type {type(v)} for parameter {k!r} "
            f"({string_type(v, with_shape=True)})"
        )

    # process dynamic axes
    dynamic_shapes: Dict[str, Any] = {}
    if changes:
        done = set()
        for k, v in dynamic_axes.items():
            if k not in changes and k in updated_kwargs and isinstance(v, dict):
                dynamic_shapes[k] = v
                continue
            if "." in k:
                # something like present.0.key
                prefix = k.split(".")[0]
                if prefix in done:
                    continue
                if prefix in updated_kwargs and prefix in changes:
                    # A cache.
                    cls = changes[prefix]
                    dynamic_shapes[prefix] = _make_shape(
                        {
                            name: axes
                            for name, axes in dynamic_axes.items()
                            if name.startswith(f"{prefix}.")
                        },
                        cls,
                        updated_kwargs[prefix],
                    )
                    done.add(prefix)
                    continue
            if k not in updated_kwargs:
                # dynamic axes not in the given inputs, should we raise an exception?
                if verbose:
                    print(
                        f"[convert_dynamic_axes_into_dynamic_shapes] dropping axes "
                        f"{k!r}-{v!r}, not found in {set(updated_kwargs)}"
                    )
                continue
            raise NotImplementedError(
                f"Unable to process dynamic axes {k!r}, axes={v}, "
                f"value={string_type(updated_kwargs[k], with_shape=True)}, "
                f"dynamic axes={dynamic_axes}, "
                f"updated_kwargs={string_type(updated_kwargs, with_shape=True)}"
            )

    return (), updated_kwargs, dynamic_shapes
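
For context, a minimal usage sketch (not part of this commit; the wrapper name export_with_legacy_axes is hypothetical): the returned triple plugs directly into torch.export.export, which accepts keyword inputs together with a dynamic_shapes dictionary keyed by parameter name. Here the model is assumed to be an instance rather than a class.

# Hypothetical sketch, assuming a torch.nn.Module instance and legacy dynamic_axes.
import torch
from onnx_diagnostic.torch_export_patches.patch_inputs import (
    convert_dynamic_axes_into_dynamic_shapes,
)


def export_with_legacy_axes(model, args, dynamic_axes):
    # convert torch.onnx.export-style dynamic_axes into dynamic_shapes
    _, kwargs, dynamic_shapes = convert_dynamic_axes_into_dynamic_shapes(
        model, args=args, dynamic_axes=dynamic_axes, verbose=0
    )
    # the converter moved every input into kwargs, hence the empty positional tuple
    return torch.export.export(model, (), kwargs=kwargs, dynamic_shapes=dynamic_shapes)
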

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 1 addition & 1 deletion
@@ -299,7 +299,7 @@ def get_untrained_model_with_inputs(
 
     sizes = compute_model_size(model)
     res["model"] = model
-    res["config"] = config
+    res["configuration"] = config
     res["size"] = sizes[0]
     res["n_weights"] = sizes[1]
 