|
2 | 2 | from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
3 | 3 | import numpy as np |
4 | 4 | import torch |
5 | | -from ..helpers import string_type |
| 5 | +from ..helpers import string_type, flatten_object |
6 | 6 | from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes |
| 7 | +from ..helpers.fake_tensor_helper import make_fake |
7 | 8 |
|
8 | 9 | DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]] |
9 | 10 |
|
@@ -32,6 +33,51 @@ def _flat_list(li: List[Any]) -> List[Dict[int, str]]: |
32 | 33 | return res |
33 | 34 |
|
34 | 35 |
|
| 36 | +def make_fake_with_dynamic_dimensions( |
| 37 | + x: Optional[Any], |
| 38 | + dynamic_shapes: Any, |
| 39 | + fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821 |
| 40 | +) -> Optional[Tuple["FakeTensor", "FaleTensorMode"]]: # noqa: F821 |
| 41 | + """ |
| 42 | + Replaces all tensors by fake tensor respecting the same |
| 43 | + constraints as the following dynamic shapes. |
| 44 | + This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`. |
| 45 | +
|
| 46 | + .. runpython:: |
| 47 | + :showcode: |
| 48 | +
|
| 49 | + from onnx_diagnostic.export.dynamic_shapes import make_fake_with_dynamic_dimensions |
| 50 | +
|
| 51 | + inputs, _ = make_fake_with_dynamic_dimensions( |
| 52 | + dict( |
| 53 | + input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64), |
| 54 | + attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64), |
| 55 | + position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64), |
| 56 | + past_key_values=make_dynamic_cache( |
| 57 | + [ |
| 58 | + ( |
| 59 | + torch.rand((2, 32, 30, 96), dtype=torch.float16), |
| 60 | + torch.rand((2, 32, 30, 96), dtype=torch.float16), |
| 61 | + ), |
| 62 | + ( |
| 63 | + torch.rand((2, 32, 30, 96), dtype=torch.float16), |
| 64 | + torch.rand((2, 32, 30, 96), dtype=torch.float16), |
| 65 | + ), |
| 66 | + ] |
| 67 | + ), |
| 68 | + ) |
| 69 | + ) |
| 70 | + print(inputs) |
| 71 | + """ |
| 72 | + fake_inputs = make_fake(x, fake_mode=fake_mode) |
| 73 | + flat_inputs = flatten_object(fake_inputs, drop_keys=True) |
| 74 | + flat_ds = flatten_dynamic_shapes(dynamic_shapes) |
| 75 | + assert len(flat_inputs) == len(flat_ds), ( |
| 76 | + f"Mismatch between the number of input tensor {len(flat_inputs)} " |
| 77 | + f"and the number of dynamic_shapes {len(flat_ds)}" |
| 78 | + ) |
| 79 | + |
| 80 | + |
35 | 81 | class CoupleInputsDynamicShapes: |
36 | 82 | """ |
37 | 83 | Pair inputs / dynamic shapes. |
|
0 commit comments