Skip to content

Commit 164ddb8

Browse files
committed
doc
1 parent bed8fe5 commit 164ddb8

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

onnx_diagnostic/export/shape_helper.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,31 @@
44

55
def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
66
"""
7-
Returns the dyanmic shapes for the given inputs.
7+
Returns the dynamic shapes for the given inputs.
88
All dimensions are considered as dynamic.
99
``dim_prefix`` can be a string (the function uses it as a prefix),
1010
or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
11+
12+
.. runpython::
13+
:showcode:
14+
15+
import pprint
16+
import torch
17+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
18+
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
19+
20+
bsize, nheads, slen, dim = 2, 1, 30, 96
21+
inputs = dict(
22+
input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
23+
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
24+
position_ids=torch.arange(3, dtype=torch.int64),
25+
past_key_values=make_dynamic_cache(
26+
[(torch.randn(bsize, nheads, slen, dim),
27+
torch.randn(bsize, nheads, slen, dim))]
28+
),
29+
)
30+
ds = all_dynamic_shape_from_inputs(inputs)
31+
pprint.pprint(ds)
1132
"""
1233
if isinstance(dim_prefix, str):
1334
prefixes: Set[str] = set()

0 commit comments

Comments
 (0)