Commit 38fe3d6

fix modelbuilder
1 parent e52ec0b commit 38fe3d6

7 files changed: +70 -62 lines changed


_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 54 additions & 0 deletions

@@ -862,6 +862,60 @@ def test_unbatch_inputs(self):
             s,
         )

+    def test_guess_dynamic_cache_without_patches(self):
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+        cache = make_dynamic_cache(
+            [
+                (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
+                for i in range(n_layers)
+            ]
+        )
+        z = torch.randn((1, 1, 1, 7))
+        cache2 = make_dynamic_cache(
+            [
+                (
+                    torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
+                    torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
+                )
+                for i in range(n_layers)
+            ]
+        )
+        inputs = [
+            (cache, z),
+            (cache2, torch.randn((1, 1, 1, 8))),
+        ]
+
+        class Model(torch.nn.Module):
+            def forward(self, cache, z):
+                cache = CacheKeyValue(cache)
+                return (
+                    z
+                    + cache.key_cache[0]
+                    + cache.key_cache[1]
+                    + cache.value_cache[0]
+                    + cache.value_cache[1]
+                )
+
+        mi = ModelInputs(Model(), inputs)
+        ds = mi.guess_dynamic_shapes()
+        DYN = torch.export.Dim.DYNAMIC
+        self.assertEqual(
+            (
+                (
+                    [
+                        {0: DYN, 2: DYN, 3: DYN},
+                        {0: DYN, 2: DYN, 3: DYN},
+                        {0: DYN, 2: DYN, 3: DYN},
+                        {0: DYN, 2: DYN, 3: DYN},
+                    ],
+                    {3: DYN},
+                ),
+                {},
+            ),
+            ds,
+        )
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_inputs.py

Lines changed: 2 additions & 2 deletions

@@ -88,8 +88,8 @@ def test_convert_dynamic_axes_into_dynamic_shapes_2(self):
         )
         self.assertEqual(
             [
-                [{0: "batch_size", 2: "past_sequence_length"}],
-                [{0: "batch_size", 2: "past_sequence_length"}],
+                {0: "batch_size", 2: "past_sequence_length"},
+                {0: "batch_size", 2: "past_sequence_length"},
             ],
             res[2]["past_key_values"],
         )

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
 import inspect
+import itertools
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch
@@ -934,7 +935,7 @@ def guess_dynamic_shape_object(
                     auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
                 )
             )
-        return [key_cache, value_cache]
+        return list(itertools.chain.from_iterable(zip(key_cache, value_cache)))

    raise NotImplementedError(
        f"Unable to build dynamic shapes for type {set_types.pop()}: "

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 0 additions & 45 deletions

@@ -96,54 +96,9 @@ def make_feeds(
        elif isinstance(i, float):
            i = np.array(i, dtype=np.float32)
        new_flat.append(i)
-
-    # NOTE: model builder has a different order for past_key_values
-    # we need to reorder them to match the expected order
-    if is_modelbuilder:
-        # We assume that if "past_key_values" is in the names when it's
-        # modelbuilder
-        non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
-        past_kv_names = [n for n in names if "past_key_values" in n]
-        reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
-        names = non_past_kv_input_names + reorder_past_kv_names
    return dict(zip(names, new_flat))


-def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
-    """
-    Reorders the past_kvs for ModelBuilder to match the expected order
-    by PyTorch exported models.
-
-    .. note::
-        This function can take either the names or the actual tensors
-        as long as they are in a list.
-
-    Conceptually,
-
-    From::
-
-        [past_key_values.0.key, past_key_values.0.value,
-         past_key_values.1.key, past_key_values.1.value, ...]
-
-    To::
-
-        [past_key_values.0.key, past_key_values.1.key,
-         ..., past_key_values.0.value, past_key_values.1.value, ...]
-
-    :param past_kv: list of flattened inputs
-    :return: reordered list of flattened inputs
-    """
-    total_len = len(past_kv)
-    if total_len % 2 != 0:
-        raise ValueError("The length of past_key_values should be even.")
-    keys = []
-    values = []
-    for i in range(0, total_len, 2):
-        keys.append(past_kv[i])
-        values.append(past_kv[i + 1])
-    return keys + values
-
-
 def _get_dim(i: int, s: Union[str, int], batch: int = 1) -> int:
     if isinstance(s, int):
         return s
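
For reference, a self-contained sketch of the reordering the removed reorder_modelbuilder_cache_to_torch performed, following its docstring (reorder_pairs_to_grouped below is a re-implementation for illustration only, not a library API):

from typing import Any, List

def reorder_pairs_to_grouped(past_kv: List[Any]) -> List[Any]:
    # [0.key, 0.value, 1.key, 1.value] -> [0.key, 1.key, 0.value, 1.value]
    if len(past_kv) % 2 != 0:
        raise ValueError("The length of past_key_values should be even.")
    return past_kv[0::2] + past_kv[1::2]

names = [
    "past_key_values.0.key", "past_key_values.0.value",
    "past_key_values.1.key", "past_key_values.1.value",
]
print(reorder_pairs_to_grouped(names))
# ['past_key_values.0.key', 'past_key_values.1.key',
#  'past_key_values.0.value', 'past_key_values.1.value']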

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 9 additions & 4 deletions

@@ -1,3 +1,4 @@
+import itertools
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
@@ -269,10 +270,14 @@ def get_inputs_default(
        "token_type_ids": {0: batch, 1: seq_length},
        "attention_mask": {0: batch, 1: "cache+seq"},
        "position_ids": {0: batch, 1: seq_length},
-        "past_key_values": [
-            [{0: batch} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-        ],
+        "past_key_values": list(
+            itertools.chain.from_iterable(
+                zip(
+                    [{0: batch} for _ in range(num_hidden_layers)],
+                    [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                )
+            )
+        ),
        "pixel_values": (
            {0: batch, 1: images}
            if model.__class__.__name__ == "IdeficsForVisionText2Text"
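
As a quick sanity check, a sketch (assuming num_hidden_layers = 2 and string placeholders for the batch and cache_length dimensions) of the flat, interleaved list this expression now builds for "past_key_values":

import itertools

num_hidden_layers = 2
batch, cache_length = "batch", "cache_length"  # placeholders for the dynamic dims

past_key_values = list(
    itertools.chain.from_iterable(
        zip(
            [{0: batch} for _ in range(num_hidden_layers)],
            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
        )
    )
)
print(past_key_values)
# [{0: 'batch'}, {0: 'batch', 2: 'cache_length'},
#  {0: 'batch'}, {0: 'batch', 2: 'cache_length'}]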

onnx_diagnostic/torch_export_patches/patch_inputs.py

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
        for v in subset.values():
            axes = v
            break
-        new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]]
+        new_shape = [axes for i in range(cache_length * 2)]
        return new_shape
    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        raise NotImplementedError(

onnx_diagnostic/torch_models/validate.py

Lines changed: 2 additions & 9 deletions

@@ -12,7 +12,7 @@
 from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff
 from ..helpers.helper import flatten_object
-from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch
+from ..helpers.rt_helper import make_feeds
 from ..helpers.torch_helper import to_any, torch_deepcopy
 from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
 from ..tasks import random_input_kwargs
@@ -1478,7 +1478,7 @@ def _mk(key, flavour=flavour):
                data[k_input],
                use_numpy=True,
                check_flatten=False,
-                is_modelbuilder=data["exporter"] == "modelbuilder",
+                is_modelbuilder=data["exporter"] == "modelbuilder",  # to remove position_ids
            )
            if verbose:
                print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
@@ -1501,13 +1501,6 @@ def _mk(key, flavour=flavour):
                repeat=repeat,
                warmup=warmup,
            )
-            # NOTE: modelbuilder has different order on past_kv outputs
-            if data["exporter"] == "modelbuilder":
-                logits = got[:1]
-                past_key_values = got[1:]
-                reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values)
-                got = logits + reorder_past_key_values
-
            if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
                return summary, data
