Skip to content

Commit 02b212d

Browse files
committed
lint
1 parent 58eccb2 commit 02b212d

File tree

6 files changed

+64
-9
lines changed

6 files changed

+64
-9
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.7.16
55
++++++
66

7+
* :pr:`272`: makes patches woth with FakeTensor
78
* :pr:`270`: add export sample code to export a specific model id with the appropriate inputs
89
* :pr:`269`: adds one unit test to track a patch fixing broadcast output shape
910
* :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)

_unittests/ut_helpers/test_fake_tensor_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
44
from onnx_diagnostic.helpers import flatten_object
55
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
66
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
77

88

99
class TestMakeTensorHelper(ExtTestCase):
10+
@requires_transformers("4.55")
1011
def test_fake_inputs(self):
1112
inputs, _ = make_fake(
1213
dict(

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ def main(argv: Optional[List[Any]] = None):
11281128
raise ValueError(
11291129
f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}."
11301130
)
1131-
parser = parsers[cmd]()
1131+
parser = parsers[cmd]() # type: ignore[operator]
11321132
parser.parse_args(argv[1:])
11331133
raise RuntimeError("The programme should have exited before.")
11341134

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
33
import numpy as np
44
import torch
5-
from ..helpers import string_type
5+
from ..helpers import string_type, flatten_object
66
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
7+
from ..helpers.fake_tensor_helper import make_fake
78

89
DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
910

@@ -32,6 +33,51 @@ def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
3233
return res
3334

3435

36+
def make_fake_with_dynamic_dimensions(
37+
x: Optional[Any],
38+
dynamic_shapes: Any,
39+
fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
40+
) -> Optional[Tuple["FakeTensor", "FaleTensorMode"]]: # noqa: F821
41+
"""
42+
Replaces all tensors by fake tensor respecting the same
43+
constraints as the following dynamic shapes.
44+
This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
45+
46+
.. runpython::
47+
:showcode:
48+
49+
from onnx_diagnostic.export.dynamic_shapes import make_fake_with_dynamic_dimensions
50+
51+
inputs, _ = make_fake_with_dynamic_dimensions(
52+
dict(
53+
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
54+
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
55+
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
56+
past_key_values=make_dynamic_cache(
57+
[
58+
(
59+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
60+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
61+
),
62+
(
63+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
64+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
65+
),
66+
]
67+
),
68+
)
69+
)
70+
print(inputs)
71+
"""
72+
fake_inputs = make_fake(x, fake_mode=fake_mode)
73+
flat_inputs = flatten_object(fake_inputs, drop_keys=True)
74+
flat_ds = flatten_dynamic_shapes(dynamic_shapes)
75+
assert len(flat_inputs) == len(flat_ds), (
76+
f"Mismatch between the number of input tensor {len(flat_inputs)} "
77+
f"and the number of dynamic_shapes {len(flat_ds)}"
78+
)
79+
80+
3581
class CoupleInputsDynamicShapes:
3682
"""
3783
Pair inputs / dynamic shapes.

onnx_diagnostic/helpers/fake_tensor_helper.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33

44
def make_fake(
5-
x: Any, fake_mode: Optional["FakeTensorMode"] = None # noqa: F821
6-
) -> Tuple["FakeTensor", "FaleTensorMode"]: # noqa: F821
5+
x: Optional[Any], fake_mode: Optional["FakeTensorMode"] = None # noqa: F821
6+
) -> Optional[Tuple["FakeTensor", "FaleTensorMode"]]: # noqa: F821
77
"""
88
Replaces all tensors by fake tensors.
99
This modification happens inplace for caches.
10+
This function is only implemented for cache with
11+
``transformers>=4.55``.
1012
1113
.. runpython::
1214
:showcode:
@@ -49,12 +51,13 @@ def make_fake(
4951
return {k: make_fake(v, fake_mode=fake_mode)[0] for k, v in x.items()}, fake_mode
5052

5153
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
52-
assert hasattr(
53-
x, "layers"
54-
), f"Une more recent version of transformers, 'layers' not found in class {type(x)}"
54+
assert hasattr(x, "layers"), (
55+
f"Une more recent version of transformers (>=4.55), "
56+
f"'layers' not found in class {type(x)}"
57+
)
5558
for layer in x.layers:
5659
assert hasattr(layer, "keys") and hasattr(layer, "values"), (
57-
f"Une more recent version of transformers, 'layers' "
60+
f"Une more recent version of transformers (>=4.55), 'layers' "
5861
f"not found in class {type(layer)} ({dir(layer)})"
5962
)
6063
layer.keys = make_fake(layer.keys, fake_mode=fake_mode)[0]

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ disable_error_code = ["name-defined"]
5858
module = ["onnx_diagnostic.helpers.helper"]
5959
disable_error_code = ["arg-type", "assignment", "attr-defined", "call-overload", "misc", "name-defined", "union-attr"]
6060

61+
[[tool.mypy.overrides]]
62+
module = ["onnx_diagnostic.helpers.fake_tensor_helper"]
63+
disable_error_code = ["name-defined"]
64+
6165
[[tool.mypy.overrides]]
6266
module = ["onnx_diagnostic.helpers.model_builder_helper"]
6367
disable_error_code = ["attr-defined", "import-untyped", "name-defined", "union-attr", "var-annotated"]

0 commit comments

Comments
 (0)