
Commit f91cc32: fix script

Parent: 859338f

File tree: 4 files changed, +195 -19 lines


_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 67 additions & 0 deletions
@@ -679,6 +679,42 @@ def test_couple_input_ds_replace_string(self):
             ).replace_string_by(value="DYN"),
         )
 
+    def test_couple_input_ds_replace_by_string(self):
+        T3x1 = torch.rand((3, 1))
+        T3x4 = torch.rand((3, 4))
+        T5x1 = torch.rand((5, 1))
+        args = (T5x1,)
+        kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
+        ds_batch = {0: "batch"}
+        ds_batch_seq = {0: "batch", 1: "seq"}
+        ds = {"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)}
+        Cls = CoupleInputsDynamicShapes
+        res = Cls(
+            args,
+            kwargs,
+            ds,
+            args_names=["X"],
+        ).replace_by_string()
+        self.assertEqual(ds, res)
+
+        ds_batch = {0: torch.export.Dim("batch")}
+        ds_batch_seq = {0: torch.export.Dim("batch"), 1: torch.export.Dim.DYNAMIC}
+        ds = {"X": ds_batch, "A": ds_batch_seq, "B": (ds_batch_seq, ds_batch_seq)}
+        res = Cls(
+            args,
+            kwargs,
+            ds,
+            args_names=["X"],
+        ).replace_by_string()
+        self.assertEqual(
+            {
+                "X": {0: "batch"},
+                "A": {0: "batch", 1: "Dim1"},
+                "B": ({0: "batch", 1: "Dim1"}, {0: "batch", 1: "Dim1"}),
+            },
+            res,
+        )
+
     def test_couple_input_ds_change_dynamic_dimensions(self):
         T257 = torch.arange(2 * 5 * 7).reshape((2, 5, 7))
         T29 = torch.arange(2 * 9).reshape((2, 9))

@@ -703,6 +739,37 @@ def test_couple_input_ds_change_dynamic_dimensions_fixed(self):
         self.assertEqual((1, 5, 8), new_input["A"].shape)
         self.assertEqual((1, 50), new_input["B"].shape)
 
+    def test_dynamic_cache_replace_by_string(self):
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+        cache = make_dynamic_cache(
+            [
+                (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
+                for i in range(n_layers)
+            ]
+        )
+
+        DYN = torch.export.Dim.DYNAMIC
+        ds = {
+            "cache": [
+                [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
+                [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
+            ]
+        }
+        inst = CoupleInputsDynamicShapes((), dict(cache=cache), ds)
+        as_string = inst.replace_by_string()
+        self.assertEqual(
+            {
+                "cache": [
+                    {0: "Dim0", 1: "Dim1"},
+                    {0: "Dim2", 1: "Dim3"},
+                    {0: "Dim4", 1: "Dim5"},
+                    {0: "Dim6", 1: "Dim7"},
+                ]
+            },
+            as_string,
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
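
Note: a minimal sketch (assuming this commit is installed) of the naming rule the new tests
check: torch.export.Dim.DYNAMIC / AUTO axes receive fresh names "Dim0", "Dim1", ..., while
named Dim objects and plain strings keep their own names.

    import torch
    from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

    DYN = torch.export.Dim.DYNAMIC
    kwargs = {"A": torch.rand((3, 4))}
    ds = {"A": {0: torch.export.Dim("batch"), 1: DYN}}
    print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string())
    # expected: {'A': {0: 'batch', 1: 'Dim1'}}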

onnx_diagnostic/_command_lines_parser.py

Lines changed: 2 additions & 1 deletion
@@ -291,7 +291,8 @@ def get_parser_validate() -> ArgumentParser:
         "--ortfusiontype",
         required=False,
         help="applies onnxruntime fusion, this parameter should contain the "
-        "model type or multiple values separated by |",
+        "model type or multiple values separated by `|`. `ALL` can be used "
+        "to run them all",
     )
     parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
     parser.add_argument("--dtype", help="changes dtype if necessary")
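
Note: with this change, "ALL" fans out to every fusion type known to onnxruntime. A
hypothetical invocation (the validate subcommand and the -m flag are assumptions, only
--ortfusiontype and -v come from this file):

    python -m onnx_diagnostic validate -m <model-id> --ortfusiontype ALL -v 1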

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 80 additions & 2 deletions
@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch
 from ..helpers import string_type

@@ -8,6 +8,30 @@
 DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
 
 
+def flatten_dynamic_shapes(ds: Any) -> Any:
+    """Flattens the dynamic shapes."""
+    if isinstance(ds, list):
+        return _flat_list([flatten_dynamic_shapes(t) for t in ds])
+    if isinstance(ds, tuple):
+        return tuple(_flat_list([flatten_dynamic_shapes(t) for t in ds]))
+    if isinstance(ds, dict):
+        if all(isinstance(i, int) for i in ds):
+            # That's a dynamic shape
+            return ds
+        return _flat_list([flatten_dynamic_shapes(t) for t in ds.values()])
+    raise AssertionError(f"Not implemented for {type(ds)}: {ds}")
+
+
+def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
+    res = []
+    for t in li:
+        if isinstance(t, dict):
+            res.append(t)
+        else:
+            res.extend(t)
+    return res
+
+
 class CoupleInputsDynamicShapes:
     """
     Pair inputs / dynamic shapes.
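
Note: a short sketch (assuming this commit is installed) of what the new helper returns:
nested containers collapse into a flat list of per-tensor axis dictionaries, which
_generic_walker_step (see the last hunk of this file) lines up with pytree-flattened inputs.

    from onnx_diagnostic.export.dynamic_shapes import flatten_dynamic_shapes

    ds = {"A": {0: "batch"}, "B": ({0: "batch"}, {0: "batch", 1: "seq"})}
    print(flatten_dynamic_shapes(ds))
    # expected: [{0: 'batch'}, {0: 'batch'}, {0: 'batch', 1: 'seq'}]
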
@@ -76,7 +100,7 @@ def _replace_string_dim_tensor(cls, inputs, ds, value=None):
         assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
         assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
             f"Unexpected types, inputs is a Tensor but ds is {ds}, "
-            f"a dictionary is expected to specify a dimension dimension"
+            f"a dictionary is expected to specify a dimension"
         )
         if value is None:
             value = torch.export.Dim.DYNAMIC

@@ -86,6 +110,57 @@ def _replace_string_dim_tensor(cls, inputs, ds, value=None):
             new_ds[i] = value
         return new_ds
 
+    def replace_by_string(self):
+        """
+        Replaces dimensions by strings.
+
+        Example:
+
+        .. runpython::
+            :showcode:
+
+            import torch
+            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
+
+            Dim = torch.export.Dim
+            T3x1 = torch.rand((3, 1))
+            T3x4 = torch.rand((3, 4))
+            ds_batch = {0: Dim("batch")}
+            ds_batch_seq = {0: Dim("batch"), 1: Dim("seq")}
+            kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
+            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
+            print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string())
+        """
+        unique = set()
+        return self._generic_walker(
+            lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
+                inputs, ds, unique=unique
+            )
+        )
+
+    @classmethod
+    def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]):
+        assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
+        assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
+            f"Unexpected types, inputs is a Tensor but ds is {ds}, "
+            f"a dictionary is expected to specify a dimension"
+        )
+        new_ds = ds.copy()
+        for i, v in ds.items():
+            if isinstance(v, str):
+                assert v not in unique, f"Dimension {v!r} is already defined in {unique}"
+                unique.add(v)
+                new_ds[i] = v
+            elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO):
+                name = f"Dim{len(unique)}"
+                new_ds[i] = name
+                unique.add(name)
+            else:
+                name = v.__name__
+                unique.add(name)
+                new_ds[i] = name
+        return new_ds
+
     def invalid_dimensions_for_export(self):
         """
         Tells if the inputs are valid based on the dynamic shapes definition.

@@ -252,6 +327,9 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
                 f"map this class with the given dynamic shapes."
             )
             flat, _spec = torch.utils._pytree.tree_flatten(inputs)
+            if all(isinstance(t, torch.Tensor) for t in flat):
+                # We need to flatten dynamic shapes as well
+                ds = flatten_dynamic_shapes(ds)
             return cls._generic_walker_step(processor, flat, ds)
 
 class ChangeDimensionProcessor:

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 46 additions & 16 deletions
@@ -5,6 +5,7 @@
 import time
 import onnx
 import torch
+from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff
 from ..helpers.helper import flatten_object
 from ..helpers.rt_helper import make_feeds

@@ -506,7 +507,12 @@ def validate_model(
         ), f"Missing attribute num_attention_heads in configuration {config}"
         num_attention_heads = config.num_attention_heads
 
-        model_types = ortfusiontype.split("|")
+        if ortfusiontype == "ALL":
+            from onnxruntime.transformers.optimizer import MODEL_TYPES
+
+            model_types = sorted(MODEL_TYPES)
+        else:
+            model_types = ortfusiontype.split("|")
         for model_type in model_types:
            flavour = f"ort{model_type}"
            summary[f"version_{flavour}_hidden_size"] = hidden_size
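
Note: a sketch (not part of the diff, and the helper name is hypothetical) of how the new
"ALL" value is resolved, mirroring the change above: "ALL" expands to every model type known
to onnxruntime's fusion optimizer, anything else is split on "|".

    from onnxruntime.transformers.optimizer import MODEL_TYPES

    def resolve_fusion_types(ortfusiontype: str):
        if ortfusiontype == "ALL":
            return sorted(MODEL_TYPES)
        return ortfusiontype.split("|")

    print(resolve_fusion_types("bert|gpt2"))  # ['bert', 'gpt2']
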
@@ -517,13 +523,15 @@
             print(f"[validate_model] run onnxruntime fusion for {model_type!r}")
         input_filename = data["onnx_filename"]
         output_path = f"{os.path.splitext(input_filename)[0]}.ort.{model_type}.onnx"
-        run_ort_fusion(
+        ort_sum, ort_data = run_ort_fusion(
             input_filename,
             output_path,
             model_type=model_type,
             num_attention_heads=num_attention_heads,
             hidden_size=hidden_size,
         )
+        summary.update(ort_sum)
+        data.update(ort_data)
         data[f"onnx_filename_{flavour}"] = output_path
         duration = time.perf_counter() - begin
         summary[f"time_ortfusion_{flavour}"] = duration

@@ -590,7 +598,7 @@ def call_exporter(
             optimization=optimization,
         )
         return summary, data
-    if exporter.startswith("custom-"):
+    if exporter == "custom" or exporter.startswith("custom"):
         # torch export
         summary, data = call_torch_export_custom(
             exporter=exporter,

@@ -758,6 +766,9 @@ def _mk(key):
 
     if input_data_key in data:
         source = data[input_data_key]
+        if not os.path.exists(source):
+            summary[_mk("ERR_onnx_missing")] = f"FileNotFoundError({source!r})"
+            return summary, data
         summary[input_data_key] = source
         summary[_mk("onnx_size")] = os.stat(source).st_size
     else:

@@ -866,7 +877,7 @@ def call_torch_export_onnx(
     assert "model" in data, f"model is missing from data: {sorted(data)}"
     assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
     summary: Dict[str, Union[str, int, float]] = {}
-    dynamo = "nostrict" not in exporter
+    dynamo = "dynamo" in exporter
     args, kwargs = split_args_kwargs(data["inputs_export"])
     ds = data.get("dynamic_shapes", None)
     if verbose:

@@ -884,15 +895,23 @@
         summary["export_args"] = string_type(args, with_shape=True)
         summary["export_kwargs"] = string_type(kwargs, with_shape=True)
 
+    export_export_kwargs = (
+        dict(dynamo=True, dynamic_shapes=ds)
+        if dynamo
+        else dict(
+            dynamo=False,
+            dynamic_axes=CoupleInputsDynamicShapes(args, kwargs, ds).replace_by_string(),
+        )
+    )
+
     begin = time.perf_counter()
     if quiet:
         try:
             epo = torch.onnx.export(
                 data["model"],
                 args,
                 kwargs=kwargs,
-                dynamic_shapes=ds,
-                dynamo=dynamo,
+                **export_export_kwargs,
             )
         except Exception as e:
             summary["ERR_export_export"] = str(e)

@@ -904,8 +923,7 @@
             data["model"],
             args,
             kwargs=kwargs,
-            dynamic_shapes=ds,
-            dynamo=dynamo,
+            **export_export_kwargs,
         )
 
     summary["time_export_export"] = time.perf_counter() - begin

@@ -966,6 +984,7 @@ def call_torch_export_custom(
         None,
     }, f"unexpected value for optimization={optimization}"
     assert exporter in {
+        "custom",
         "custom-strict",
         "custom-strict-dec",
         "custom-strict-all",

@@ -1155,14 +1174,24 @@
             f"[run_ort_fusion] starts optimization for "
             f"model_type={model_type!r} with {n_nodes} nodes"
         )
-    new_onx = optimize_by_fusion(
-        onx,
-        model_type=model_type,
-        num_heads=num_attention_heads,
-        hidden_size=hidden_size,
-        optimization_options=opts,
-    )
-    duration = {time.perf_counter() - begin}
+    try:
+        new_onx = optimize_by_fusion(
+            onx,
+            model_type=model_type,
+            num_heads=num_attention_heads,
+            hidden_size=hidden_size,
+            optimization_options=opts,
+        )
+    except Exception as e:
+        duration = {time.perf_counter() - begin}
+        if verbose:
+            print(f"[run_ort_fusion] failed in {duration} for model_type={model_type!r}")
+        return {
+            f"ERR_opt_ort_{model_type}": str(e),
+            f"opt_ort_{model_type}_duration": duration,
+        }, {}
+
+    duration = time.perf_counter() - begin
     delta = len(new_onx.model.graph.node)
     if verbose:
         print(f"[run_ort_fusion] done in {duration} with {delta} nodes")

@@ -1175,6 +1204,7 @@
     return {
         f"opt_ort_{model_type}_n_nodes1": n_nodes,
         f"opt_ort_{model_type}_n_nodes2": delta,
+        f"opt_ort_{model_type}_delta_node": delta - n_nodes,
         f"opt_ort_{model_type}_duration": duration,
         f"opt_ort_{model_type}_duration_save": d,
     }, {f"opt_ort_{model_type}": output_path}