Commit 752d014

Improve patches for transformers

1 parent f6ad410

File tree

11 files changed: +477 -249 lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        python: ['3.12']
+        python: ['3.11', '3.12']
         transformers: ['4.48', 'main']
 
     steps:

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 2 additions & 5 deletions

@@ -210,16 +210,13 @@ def forward(self, cache, z):
 # The export is simple if ``transformers>=4.50``, otherwise,
 # transformers needs to be patched.
 # :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
-# registers functions to serialize ``DynamicCache`` and another class
-# called ``patched_DynamicCache``. This one is modified to make
+# registers functions to serialize ``DynamicCache``. This one is modified to make
 # the shape inference implemented in :epkg:`torch` happy.
 
 if has_transformers("4.50"):
     ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
 else:
-    with bypass_export_some_errors(
-        patch_transformers=True, replace_dynamic_cache=True
-    ) as modificator:
+    with bypass_export_some_errors(patch_transformers=True) as modificator:
         ep = torch.export.export(
             model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
         )
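
For readers of the example above: the registration that ``bypass_export_some_errors`` performs for ``DynamicCache`` amounts to teaching :epkg:`torch`'s pytree machinery how to flatten and rebuild the cache. The following is a simplified, hypothetical sketch of that idea, not the library's actual code; it assumes ``DynamicCache`` exposes ``key_cache``/``value_cache`` lists and that ``torch.utils._pytree.register_pytree_node`` is available (PyTorch 2.2+). Registering an already-registered class raises an error, so a real implementation has to guard against that.

import torch
from transformers.cache_utils import DynamicCache


def _flatten_dynamic_cache(cache):
    # children first, then a context object (nothing extra to remember here)
    return [cache.key_cache, cache.value_cache], None


def _unflatten_dynamic_cache(children, context):
    cache = DynamicCache()
    cache.key_cache, cache.value_cache = children
    return cache


torch.utils._pytree.register_pytree_node(
    DynamicCache,
    _flatten_dynamic_cache,
    _unflatten_dynamic_cache,
    serialized_type_name="transformers.cache_utils.DynamicCache",
)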

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
         model = Model()
         model(x, cache)
 
-        with bypass_export_some_errors(replace_dynamic_cache=True, verbose=1):
+        with bypass_export_some_errors(verbose=1):
             cache = MambaCache(_config(), max_batch_size=1, device="cpu")
             torch.export.export(Model(), (x, cache))

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+
+
+class TestPatchBaseClass(ExtTestCase):
+    def test_check_that_trick_can_work_in_python(self):
+        class zero:
+            def ret(self, a):
+                return a - 100
+
+            def ok(self):
+                return self.ret(3)
+
+        class A(zero):
+            def ret(self, a):
+                return a + 1
+
+        class B:
+            def ret(self, a):
+                return a + 10
+
+        z = zero()
+        self.assertEqual(z.ret(4), -96)
+        self.assertEqual(z.ok(), -97)
+        a = A()
+        self.assertEqual(a.ret(4), 5)
+        self.assertEqual(a.ok(), 4)
+        b = B()
+        self.assertEqual(b.ret(4), 14)
+        self.assertFalse(hasattr(b, "ok"))
+        self.assertFalse(hasattr(B, "ok"))
+
+        self.assertEqual(A.__bases__, (zero,))
+        A.__bases__ = (zero, B)
+        self.assertEqual(a.ret(4), 5)
+        self.assertEqual(a.ok(), 4)
+        aa = A()
+        self.assertEqual(aa.ret(4), 5)
+        self.assertEqual(aa.ok(), 4)
+
+        A.__bases__ = (B, zero)
+        self.assertEqual(a.ret(4), 5)
+        self.assertEqual(a.ok(), 4)
+        aa = A()
+        self.assertEqual(aa.ret(4), 5)
+        self.assertEqual(aa.ok(), 4)
+
+        A.__bases__ = (zero,)
+        A.ret = B.ret
+        self.assertEqual(aa.ret(4), 14)
+        self.assertEqual(aa.ok(), 13)
+        self.assertEqual(a.ret(4), 14)
+        self.assertEqual(a.ok(), 13)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
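
The new test above checks, in plain Python, the trick the transformers patches rely on: appending a class to ``A.__bases__`` does not change a method ``A`` already defines (normal MRO lookup still finds ``A.ret`` first), whereas rebinding the method itself (``A.ret = B.ret``) changes the behaviour of instances that already exist. Below is a minimal, hypothetical sketch of that rebinding idea as a reversible patch, not the library's ``bypass_export_some_errors``:

import contextlib


@contextlib.contextmanager
def temporary_method(cls, name, replacement):
    # Rebind cls.<name> for the duration of the block, then restore the original.
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield cls
    finally:
        setattr(cls, name, original)


class Greeter:
    def greet(self):
        return "hello"


g = Greeter()
with temporary_method(Greeter, "greet", lambda self: "patched"):
    assert g.greet() == "patched"  # existing instances see the patch
assert g.greet() == "hello"  # original behaviour restored on exit

This is the kind of reversible, in-place modification a context manager can apply to a class for the duration of an export and then undo.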

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 1 addition & 3 deletions

@@ -29,9 +29,7 @@ def test_export_tiny_llm_2_bypassed(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
         self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
-        with bypass_export_some_errors(
-            patch_transformers=True, replace_dynamic_cache=True
-        ) as modificator:
+        with bypass_export_some_errors(patch_transformers=True) as modificator:
             inputs = modificator(inputs)
             ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 2 additions & 6 deletions

@@ -57,9 +57,7 @@ def test_bypass_onnx_export_tiny_llm_official(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
         self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
-        with bypass_export_some_errors(
-            patch_transformers=True, replace_dynamic_cache=True, verbose=1
-        ) as modificator:
+        with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
             new_inputs = modificator(inputs)
             ep = torch.onnx.export(
                 model,
@@ -80,9 +78,7 @@ def test_bypass_onnx_export_tiny_llm_xdbg(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
         self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
-        with bypass_export_some_errors(
-            patch_transformers=True, replace_dynamic_cache=True, verbose=1
-        ) as modificator:
+        with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator:
             new_inputs = modificator(inputs)
             onx = to_onnx(
                 model, (), kwargs=new_inputs, dynamic_shapes=data["dynamic_shapes"], verbose=1

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 1 addition & 1 deletion

@@ -311,7 +311,7 @@ def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None)
                shapes[i] = self.guess_dynamic_shape_object(*[o[i] for o in objs])
            return shapes
 
-        if obj.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
+        if obj.__class__.__name__ == "DynamicCache":
            kc = set(len(o.key_cache) for o in objs)
            assert (
                len(kc) == 1

onnx_diagnostic/helpers.py

Lines changed: 3 additions & 3 deletions

@@ -408,7 +408,7 @@ def string_type(
     if type(obj).__name__ == "ValueInfoProto":
         return f"OT{obj.type.tensor_type.elem_type}"
 
-    if obj.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
+    if obj.__class__.__name__ == "DynamicCache":
         kc = string_type(
             obj.key_cache,
             with_shape=with_shape,
@@ -1693,8 +1693,8 @@ def max_diff(
             flatten=flatten,
         )
 
-    if expected.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
-        if got.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
+    if expected.__class__.__name__ == "DynamicCache":
+        if got.__class__.__name__ == "DynamicCache":
             if verbose >= 6:
                 print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
             return max_diff(
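
Note that ``string_type``, ``max_diff`` and ``guess_dynamic_shape_object`` dispatch on ``obj.__class__.__name__`` rather than ``isinstance``, which keeps :epkg:`transformers` an optional import; with ``patched_DynamicCache`` removed by this commit, only the original class name needs to match. A toy illustration of that dispatch style (illustrative only, not the library's code):

def describe_cache(obj) -> str:
    # Dispatch on the class name so transformers never has to be imported here.
    if obj.__class__.__name__ == "DynamicCache":
        return f"DynamicCache(layers={len(obj.key_cache)})"
    return obj.__class__.__name__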
