|
4 | 4 |
|
5 | 5 | def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any: |
6 | 6 | """ |
7 | | - Returns the dyanmic shapes for the given inputs. |
| 7 | + Returns the dynamic shapes for the given inputs. |
8 | 8 | All dimensions are considered as dynamic. |
9 | 9 | ``dim_prefix`` can be a string (the function uses it as a prefix), |
10 | 10 | or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``. |
| 11 | +
|
| 12 | + .. runpython:: |
| 13 | + :showcode: |
| 14 | +
|
| 15 | + import pprint |
| 16 | + import torch |
| 17 | + from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache |
| 18 | + from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs |
| 19 | +
|
| 20 | + bsize, nheads, slen, dim = 2, 1, 30, 96 |
| 21 | + inputs = dict( |
| 22 | + input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64), |
| 23 | + attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64), |
| 24 | + position_ids=torch.arange(3, dtype=torch.int64), |
| 25 | + past_key_values=make_dynamic_cache( |
| 26 | + [(torch.randn(bsize, nheads, slen, dim), |
| 27 | + torch.randn(bsize, nheads, slen, dim))] |
| 28 | + ), |
| 29 | + ) |
| 30 | + ds = all_dynamic_shape_from_inputs(inputs) |
| 31 | + pprint.pprint(ds) |
11 | 32 | """ |
12 | 33 | if isinstance(dim_prefix, str): |
13 | 34 | prefixes: Set[str] = set() |
|
0 commit comments