Commit 480a8a3

add agg
2 parents: 4f13496 + cae7a79

File tree: 12 files changed, +364 −48 lines


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ Change Logs
 +++++
 
 * :pr:`151`: adds command line ``agg``
+* :pr:`152`: add a function to compute fully dynamic shapes given any inputs
 
 0.7.0
 +++++

_doc/api/export/index.rst

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ onnx_diagnostic.export
     :caption: modules
 
     dynamic_shapes
+    shape_helper
     validate
 
 CoupleInputsDynamicShapes

_doc/api/export/shape_helper.rst

Lines changed: 7 additions & 0 deletions (new file)

onnx_diagnostic.export.shape_helper
===================================

.. automodule:: onnx_diagnostic.export.shape_helper
    :members:
    :no-undoc-members:
Lines changed: 78 additions & 0 deletions (new file)

"""
Builds dynamic shapes from any input
====================================

Getting dynamic shapes right for :func:`torch.export.export` is tricky when
the inputs include a custom class such as a
:class:`transformers.cache_utils.DynamicCache`. :func:`torch.export.export`
cannot take the dynamic shapes as a DynamicCache; it expects them in the
flattened-then-unflattened (serialized) form of the cache instead.

Standard inputs for a LLM with a dynamic cache
++++++++++++++++++++++++++++++++++++++++++++++
"""

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches

bsize, nheads, slen, dim = 2, 1, 30, 96

inputs = dict(
    input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
    position_ids=torch.arange(3, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)

print(string_type(inputs, with_shape=True))

# %%
# Function :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs`
# produces the corresponding dynamic shapes, assuming they are all dynamic.

ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)

# %%
# What about a StaticCache?
# +++++++++++++++++++++++++
#
# We use :func:`onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs`
# to get a consistent configuration with a static cache.

data = get_untrained_model_with_inputs(
    "arnir0/Tiny-LLM",
    model_kwargs=dict(cache_implementation="static"),
    inputs_kwargs=dict(cls_cache="StaticCache"),
)
inputs = data["inputs"]
print(string_type(inputs, with_shape=True))

# %%
# And the input shapes.
ds = all_dynamic_shape_from_inputs(inputs)
if ds["past_key_values"]:
    print("transformers implemented a serialization function for StaticCache.")
else:
    print("We need to use the serialization functions implemented in this package.")
    with torch_export_patches(patch_transformers=True):
        ds = all_dynamic_shape_from_inputs(inputs)

# %%
# That gives:
pprint.pprint(ds)

# %%
# We can compare them with the dynamic shapes returned by
# :func:`get_untrained_model_with_inputs`.
pprint.pprint(data["dynamic_shapes"])


# %%
doc.plot_legend("dynamic shapes\nfrom inputs", "dynamic shapes", "green")
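
A natural follow-up, not part of this commit, is to hand the generated shapes to
the exporter. A minimal sketch, assuming ``data["model"]`` holds the module
returned by ``get_untrained_model_with_inputs`` together with the inputs:

# Sketch only (not in the commit): export with the generated dynamic shapes.
# Assumption: data["model"] is the torch.nn.Module built with the inputs above.
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        data["model"],
        (),                 # no positional arguments
        kwargs=inputs,      # the StaticCache inputs above
        dynamic_shapes=ds,  # produced by all_dynamic_shape_from_inputs
        strict=False,       # non-strict mode handles custom caches better
    )
print(ep.graph)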
Lines changed: 46 additions & 0 deletions (new file)

import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers, requires_torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs


class TestShapeHelper(ExtTestCase):
    @requires_transformers("4.52")
    @requires_torch("2.7.99")
    def test_all_dynamic_shape_from_inputs(self):
        ds = all_dynamic_shape_from_inputs((torch.randn((5, 6)), torch.randn((1, 6))))
        self.assertEqual([{0: "d_0_0", 1: "d_0_1"}, {0: "d_1_0", 1: "d_1_1"}], ds)
        ds = all_dynamic_shape_from_inputs(
            (torch.randn((5, 6)), torch.randn((1, 6))), dim_prefix=torch.export.Dim.AUTO
        )
        self.assertEqual(
            [
                {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
                {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
            ],
            ds,
        )

    @requires_transformers("4.52")
    @requires_torch("2.7.99")
    def test_all_dynamic_shape_from_inputs_dynamic_cache(self):
        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
        print(self.string_type(data["inputs"], with_shape=True))
        ds = all_dynamic_shape_from_inputs(data["inputs"])
        self.assertEqual(
            {
                "input_ids": {0: "d_0_0", 1: "d_0_1"},
                "attention_mask": {0: "d_1_0", 1: "d_1_1"},
                "position_ids": {0: "d_2_0", 1: "d_2_1"},
                "past_key_values": {
                    "key_cache": [{0: "d_3_0", 1: "d_3_1", 2: "d_3_2", 3: "d_3_3"}],
                    "value_cache": [{0: "d_4_0", 1: "d_4_1", 2: "d_4_2", 3: "d_4_3"}],
                },
            },
            ds,
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)
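
The tests cover the string prefix and ``torch.export.Dim.AUTO`` only; judging
from the implementation further down, any non-string ``dim_prefix`` is replicated
verbatim on every axis, so ``torch.export.Dim.DYNAMIC`` should behave the same
way. A small sketch under that assumption:

import torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

# Every axis of every tensor receives the same non-string prefix.
ds = all_dynamic_shape_from_inputs(
    (torch.randn((5, 6)),), dim_prefix=torch.export.Dim.DYNAMIC
)
assert ds == [{0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}]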
_unittests/ut_helpers/data/data-agg.zip

39.5 KB binary file (not shown)

_unittests/ut_helpers/test_log_helper.py

Lines changed: 26 additions & 0 deletions

@@ -7,6 +7,7 @@
 from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
 from onnx_diagnostic.helpers.log_helper import (
     CubeLogs,
+    CubeLogsPerformance,
     CubeViewDef,
     enumerate_csv_files,
     open_dataframe,
@@ -166,6 +167,7 @@ def test_cube_logs_excel(self):
         )
         self.assertExists(output)
 
+    @hide_stdout()
     def test_enumerate_csv_files(self):
         df = self.df1()
         filename = self.get_dump_file("test_enumerate_csv_files.csv")
@@ -186,6 +188,30 @@ def test_enumerate_csv_files(self):
         self.assertEqual((3, 11), cube.shape)
         self.assertIn("RAWFILENAME", cube.data.columns)
 
+    def test_cube_logs_performance(self):
+        output = self.get_dump_file("test_cube_logs_performance.xlsx")
+        filename = os.path.join(os.path.dirname(__file__), "data", "data-agg.zip")
+        assert list(enumerate_csv_files(filename))
+        dfs = [open_dataframe(df) for df in enumerate_csv_files(filename)]
+        assert dfs, f"{filename!r} empty"
+        cube = CubeLogsPerformance(dfs)
+        cube.load()
+        cube.to_excel(
+            output,
+            views=[
+                "agg-suite",
+                "disc",
+                "speedup",
+                "time",
+                "time_export",
+                "err",
+                "cmd",
+                "bucket-speedup",
+                "raw-short",
+            ],
+        )
+        self.assertExists(output)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
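
Outside the test harness the same aggregation takes only a few lines. A minimal
sketch based on the calls above; the archive name is a placeholder for any zip
or folder of benchmark CSV logs:

from onnx_diagnostic.helpers.log_helper import (
    CubeLogsPerformance,
    enumerate_csv_files,
    open_dataframe,
)

# Load every CSV found in an archive of benchmark logs (hypothetical path).
dfs = [open_dataframe(f) for f in enumerate_csv_files("benchmark-logs.zip")]
cube = CubeLogsPerformance(dfs)
cube.load()
# Dump a subset of the views exercised by the test.
cube.to_excel("report.xlsx", views=["agg-suite", "speedup", "raw-short"])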

onnx_diagnostic/_command_lines_parser.py

Lines changed: 13 additions & 3 deletions

@@ -641,7 +641,7 @@ def get_parser_agg() -> ArgumentParser:
     parser.add_argument(
         "-k",
         "--keys",
-        default="^version_.*,^model_.*,providers,opt_patterns,suite,memory_peak,machine,exporter,dynamic,rtopt,dtype,device,architecture",
+        default="^version_.*,^model_.*,device,opt_patterns,suite,memory_peak,machine,exporter,dynamic,rtopt,dtype,device,architecture",
         help="List of columns to consider as keys, "
         "multiple values are separated by `,`\n"
         "regular expressions are allowed",
@@ -665,9 +665,14 @@ def get_parser_agg() -> ArgumentParser:
     )
     parser.add_argument(
         "--views",
-        default="summary-suite,disc,speedup,time,time_export,err,cmd,bucket-speedup",
+        default="agg-suite,disc,speedup,time,time_export,err,cmd,bucket-speedup,raw-short",
         help="Views to add to the output files.",
     )
+    parser.add_argument(
+        "--csv",
+        default="raw-short",
+        help="Views to dump as csv files.",
+    )
     parser.add_argument("-v", "--verbose", type=int, default=0, help="verbosity")
     return parser

@@ -709,7 +714,12 @@ def _cmd_agg(argv: List[Any]):
     cube.load(verbose=max(args.verbose - 1, 0))
     if args.verbose:
         print(f"Dumps final file into {args.output!r}")
-    cube.to_excel(args.output, {k: k for k in args.views.split(",")}, verbose=args.verbose)
+    cube.to_excel(
+        args.output,
+        {k: k for k in args.views.split(",")},
+        verbose=args.verbose,
+        csv=args.csv.split(","),
+    )
     if args.verbose:
         print(f"Wrote {args.output!r}")
onnx_diagnostic/export/shape_helper.py

Lines changed: 49 additions & 0 deletions (new file)

from typing import Any, Set
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes


def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
    """
    Returns the dynamic shapes for the given inputs.
    All dimensions are considered as dynamic.
    ``dim_prefix`` can be a string (the function uses it as a prefix),
    or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [(torch.randn(bsize, nheads, slen, dim),
                  torch.randn(bsize, nheads, slen, dim))]
            ),
        )
        ds = all_dynamic_shape_from_inputs(inputs)
        pprint.pprint(ds)
    """
    if isinstance(dim_prefix, str):
        prefixes: Set[str] = set()

        def tensor_to_shape(tensor):
            n = len(prefixes)
            p = f"{dim_prefix}_{n}"
            prefixes.add(p)
            return {i: f"{p}_{i}" for i in range(tensor.ndim)}

    else:

        def tensor_to_shape(tensor):
            return {i: dim_prefix for i in range(tensor.ndim)}  # noqa: C420

    return flatten_unflatten_for_dynamic_shapes(
        inputs, change_function=tensor_to_shape, use_dict=True
    )
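
To make the naming scheme concrete: each tensor gets its own prefix counter and
each axis its own suffix, so names never collide across inputs. A quick sketch
of the expected result, derived from the code above:

import torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

ds = all_dynamic_shape_from_inputs(dict(x=torch.randn(2, 3), y=torch.randn(4)))
# First tensor -> prefix d_0, second -> d_1, one name per axis.
print(ds)  # {'x': {0: 'd_0_0', 1: 'd_0_1'}, 'y': {0: 'd_1_0'}}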

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 12 additions & 4 deletions

@@ -1,11 +1,15 @@
-from typing import Any, List, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 import packaging.version as pv
 import torch
 import transformers
 import transformers.cache_utils
 
 
-def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
+def flatten_unflatten_for_dynamic_shapes(
+    obj: Any,
+    use_dict: bool = False,
+    change_function: Optional[Callable[[torch.Tensor], Any]] = None,
+) -> Any:
     """
     Returns the object in a different structure similar to what
     the definition of the dynamic shapes should use.
@@ -16,18 +20,22 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
     the context gives the dictionary keys but it is not expressed
     in the dynamic shapes, these specifications seems to be different
     for the strict and non strict mode.
+    :param change_function: modifies the tensors inside the structure,
+        for instance to replace each of them by its shape
     :return: the serialized object
     """
     if isinstance(obj, torch.Tensor):
-        return obj
+        return change_function(obj) if change_function else obj
     flat, spec = torch.utils._pytree.tree_flatten(obj)
     start = 0
     end = 0
     subtrees = []
     for subspec in spec.children_specs:
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
-        value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
+        value = flatten_unflatten_for_dynamic_shapes(
+            value, use_dict=use_dict, change_function=change_function
+        )
         subtrees.append(value)
         start = end
     if use_dict and (spec.type is dict or spec.context):
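
The new ``change_function`` hook is what ``all_dynamic_shape_from_inputs`` builds
on, and it can substitute anything for the tensors. A small sketch, replacing
each tensor by its shape:

import torch
from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes

obj = dict(a=torch.randn(2, 3), b=[torch.randn(4), torch.randn(5, 6, 7)])
shapes = flatten_unflatten_for_dynamic_shapes(
    obj, use_dict=True, change_function=lambda t: tuple(t.shape)
)
print(shapes)  # {'a': (2, 3), 'b': [(4,), (5, 6, 7)]}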
