Skip to content

Commit dbfd255

Browse files
authored
Add text2text_generation (#24)
* add text2text_generation * better * fix mypy * fix unit test * add function to convert dynamic_axes * doc * fix function * version * gh * update * copy
1 parent d5c802c commit dbfd255

File tree

19 files changed

+930
-106
lines changed

19 files changed

+930
-106
lines changed

CHANGELOGS.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Change Logs
44
0.3.0
55
+++++
66

7+
* :pr:`24`: dummy inputs for ``text2text-generation``, add new function
8+
``convert_dynamic_axes_into_dynamic_shapes`` to convert dynamic axes
9+
into dynamic shapes, add support for ``T5ForConditionalGeneration``
710
* :pr:`23`: dummy inputs for ``image-classification``
811
* :pr:`22`: api to create untrained model copying the architecture
912
of the trained models and dummy inputs for them,

_doc/api/torch_export_patches/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ onnx_diagnostic.torch_export_patches
66
:caption: submodules
77

88
patches/index
9+
patch_inputs
10+
911

1012
.. automodule:: onnx_diagnostic.torch_export_patches
1113
:members:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_export_patches.patch_inputs
3+
=================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.patch_inputs
6+
:members:
7+
:no-undoc-members:

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
("py:class", "transformers.LlamaConfig"),
122122
("py:class", "transformers.cache_utils.Cache"),
123123
("py:class", "transformers.cache_utils.DynamicCache"),
124+
("py:class", "transformers.cache_utils.EncoderDecoderCache"),
124125
("py:class", "transformers.cache_utils.MambaCache"),
125126
("py:func", "torch.export._draft_export.draft_export"),
126127
("py:func", "torch._export.tools.report_exportability"),

_doc/examples/plot_export_tiny_llm.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from onnx_diagnostic import doc
3333
from onnx_diagnostic.helpers import string_type
3434
from onnx_diagnostic.torch_models.llms import get_tiny_llm
35+
from onnx_diagnostic.torch_test_helper import steel_forward
3536

3637

3738
MODEL_NAME = "arnir0/Tiny-LLM"
@@ -49,7 +50,7 @@ def _forward_(*args, _f=None, **kwargs):
4950
print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
5051
res = _f(*args, **kwargs)
5152
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
52-
print("->", string_type((args, kwargs), with_shape=True, with_min_max=True))
53+
print("->", string_type(res, with_shape=True, with_min_max=True))
5354
return res
5455

5556

@@ -75,6 +76,12 @@ def _forward_(*args, _f=None, **kwargs):
7576
# Let's restore the forward as it was.
7677
model.forward = keep_model_forward
7778

79+
# %%
80+
# Another syntax with :func:`onnx_diagnostic.torch_test_helper.steel_forward`.
81+
82+
with steel_forward(model):
83+
model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
84+
7885
# %%
7986
# Untrained model
8087
# +++++++++++++++
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import unittest
2+
import torch
3+
import transformers
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
5+
from onnx_diagnostic.helpers import string_type
6+
from onnx_diagnostic.torch_export_patches.patch_inputs import (
7+
convert_dynamic_axes_into_dynamic_shapes,
8+
)
9+
10+
11+
class TestPatchInputs(ExtTestCase):
12+
@hide_stdout()
13+
def test_convert_dynamic_axes_into_dynamic_shapes_1(self):
14+
args = (
15+
torch.randint(0, 10, size=(2, 8)).to(torch.int64),
16+
torch.randint(0, 10, size=(2, 8)).to(torch.int64),
17+
torch.randint(0, 10, size=(2, 8)).to(torch.int64),
18+
[(torch.rand((2, 1, 3, 96)), torch.rand((2, 1, 3, 96)))],
19+
)
20+
dynamic_axes = {
21+
"attention_mask": {0: "batch_size", 1: "total_sequence_length"},
22+
"input_ids": {0: "batch_size", 1: "sequence_length"},
23+
"logits": {0: "batch_size", 1: "sequence_length"},
24+
"past_key_values.0.key": {0: "batch_size", 2: "past_sequence_length"},
25+
"past_key_values.0.value": {0: "batch_size", 2: "past_sequence_length"},
26+
"position_ids": {0: "batch_size", 1: "sequence_length"},
27+
"present.0.key": {0: "batch_size", 2: "total_sequence_length"},
28+
"present.0.value": {0: "batch_size", 2: "total_sequence_length"},
29+
}
30+
31+
model_cls = transformers.LlamaModel
32+
res = convert_dynamic_axes_into_dynamic_shapes(
33+
model_cls, args=args, dynamic_axes=dynamic_axes, verbose=1
34+
)
35+
self.assertEqual((), res[0])
36+
self.assertEqual(
37+
(
38+
"dict(input_ids:T7s2x8,attention_mask:T7s2x8,position_ids:T7s2x8,"
39+
"past_key_values:DynamicCache(key_cache=#1[T1s2x1x3x96], "
40+
"value_cache=#1[T1s2x1x3x96]))"
41+
),
42+
string_type(res[1], with_shape=True),
43+
)
44+
self.assertEqual(
45+
{
46+
"attention_mask": {0: "batch_size", 1: "total_sequence_length"},
47+
"input_ids": {0: "batch_size", 1: "sequence_length"},
48+
"past_key_values": [
49+
[{0: "batch_size", 2: "past_sequence_length"}],
50+
[{0: "batch_size", 2: "past_sequence_length"}],
51+
],
52+
"position_ids": {0: "batch_size", 1: "sequence_length"},
53+
},
54+
res[2],
55+
)
56+
57+
@hide_stdout()
58+
def test_convert_dynamic_axes_into_dynamic_shapes_2(self):
59+
args = (
60+
torch.randint(0, 10, size=(2, 8)).to(torch.int64),
61+
torch.randint(0, 10, size=(2, 8)).to(torch.int64),
62+
torch.randint(0, 10, size=(2, 8)).to(torch.int64),
63+
[(torch.rand((2, 1, 3, 96)), torch.rand((2, 1, 3, 96)))],
64+
)
65+
dynamic_axes = {
66+
"input_ids": {0: "batch_size", 1: "sequence_length"},
67+
"attention_mask": {0: "batch_size", 1: "sequence_length"},
68+
"position_ids": {0: "batch_size", 1: "sequence_length"},
69+
"logits": {0: "batch_size", 1: "sequence_length"},
70+
"present.0.key": {0: "batch_size", 2: "past_sequence_length"},
71+
"present.0.value": {0: "batch_size", 2: "past_sequence_length"},
72+
}
73+
74+
model_cls = transformers.LlamaModel
75+
res = convert_dynamic_axes_into_dynamic_shapes(
76+
model_cls,
77+
args=args,
78+
dynamic_axes=dynamic_axes,
79+
verbose=1,
80+
prefix_mapping={"present": "past_key_values"},
81+
)
82+
self.assertEqual((), res[0])
83+
self.assertEqual(
84+
{"attention_mask", "input_ids", "past_key_values", "position_ids"}, set(res[2])
85+
)
86+
self.assertEqual(
87+
[
88+
[{0: "batch_size", 2: "past_sequence_length"}],
89+
[{0: "batch_size", 2: "past_sequence_length"}],
90+
],
91+
res[2]["past_key_values"],
92+
)
93+
self.assertEqual(
94+
{
95+
"attention_mask": {0: "batch_size", 1: "sequence_length"},
96+
"input_ids": {0: "batch_size", 1: "sequence_length"},
97+
"past_key_values": [
98+
[{0: "batch_size", 2: "past_sequence_length"}],
99+
[{0: "batch_size", 2: "past_sequence_length"}],
100+
],
101+
"position_ids": {0: "batch_size", 1: "sequence_length"},
102+
},
103+
res[2],
104+
)
105+
self.assertEqual(
106+
(
107+
"dict(input_ids:T7s2x8,attention_mask:T7s2x8,position_ids:T7s2x8,"
108+
"past_key_values:DynamicCache(key_cache=#1[T1s2x1x3x96], "
109+
"value_cache=#1[T1s2x1x3x96]))"
110+
),
111+
string_type(res[1], with_shape=True),
112+
)
113+
114+
115+
if __name__ == "__main__":
116+
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from onnx_diagnostic.ext_test_case import (
55
ExtTestCase,
66
hide_stdout,
7-
long_test,
87
requires_torch,
98
requires_transformers,
109
)
@@ -79,7 +78,7 @@ def test_get_untrained_model_with_inputs_beit(self):
7978
model, inputs = data["model"], data["inputs"]
8079
model(**inputs)
8180
# different expected values for different versions of transformers
82-
self.assertIn((data["size"], data["n_weights"]), [(111448, 27862)])
81+
self.assertIn((data["size"], data["n_weights"]), [(111448, 27862), (56880, 14220)])
8382

8483
@hide_stdout()
8584
def test_get_untrained_model_with_inputs_codellama(self):
@@ -91,8 +90,18 @@ def test_get_untrained_model_with_inputs_codellama(self):
9190
self.assertIn((data["size"], data["n_weights"]), [(410532864, 102633216)])
9291

9392
@hide_stdout()
94-
@long_test()
93+
def test_get_untrained_model_with_inputs_text2text_generation(self):
94+
mid = "sshleifer/tiny-marian-en-de"
95+
# mid = "Salesforce/codet5-small"
96+
data = get_untrained_model_with_inputs(mid, verbose=1)
97+
self.assertIn((data["size"], data["n_weights"]), [(473928, 118482)])
98+
model, inputs = data["model"], data["inputs"]
99+
raise unittest.SkipTest(f"not working for {mid!r}")
100+
model(**inputs)
101+
102+
@hide_stdout()
95103
def test_get_untrained_model_Ltesting_models(self):
104+
# UNHIDE=1 python _unittests/ut_torch_models/test_hghub_model.py -k L -f
96105
def _diff(c1, c2):
97106
rows = [f"types {c1.__class__.__name__} <> {c2.__class__.__name__}"]
98107
for k, v in c1.__dict__.items():
@@ -102,11 +111,22 @@ def _diff(c1, c2):
102111
rows.append(f"{k} :: -- {v} ++ {getattr(c2, k, 'MISS')}")
103112
return "\n".join(rows)
104113

105-
# UNHIDE=1 LONGTEST=1 python _unittests/ut_torch_models/test_hghub_model.py -k L -f
106114
for mid in load_models_testing():
107115
with self.subTest(mid=mid):
116+
if mid in {
117+
"hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation",
118+
"hf-internal-testing/tiny-random-MoonshineForConditionalGeneration",
119+
"fxmarty/pix2struct-tiny-random",
120+
"hf-internal-testing/tiny-random-ViTMSNForImageClassification",
121+
"hf-internal-testing/tiny-random-YolosModel",
122+
}:
123+
print(f"-- not implemented yet for {mid!r}")
124+
continue
108125
data = get_untrained_model_with_inputs(mid, verbose=1)
109126
model, inputs = data["model"], data["inputs"]
127+
if mid in {"sshleifer/tiny-marian-en-de"}:
128+
print(f"-- not fully implemented yet for {mid!r}")
129+
continue
110130
try:
111131
model(**inputs)
112132
except Exception as e:

_unittests/ut_torch_models/try_models.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
3+
from onnx_diagnostic.helpers import string_type
4+
from onnx_diagnostic.torch_test_helper import steel_forward
5+
6+
7+
class TestHuggingFaceHubModel(ExtTestCase):
8+
@never_test()
9+
def test_image_classiciation(self):
10+
# clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k image_c
11+
12+
from transformers import ViTImageProcessor, ViTModel
13+
from PIL import Image
14+
import requests
15+
16+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
17+
image = Image.open(requests.get(url, stream=True).raw)
18+
19+
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
20+
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
21+
inputs = processor(images=image, return_tensors="pt")
22+
print()
23+
print("-- inputs", string_type(inputs, with_shape=True, with_min_max=True))
24+
25+
outputs = model(**inputs)
26+
print("-- outputs", string_type(outputs, with_shape=True, with_min_max=True))
27+
28+
@never_test()
29+
def test_text2text_generation(self):
30+
# clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k text2t
31+
32+
import torch
33+
from transformers import RobertaTokenizer, T5ForConditionalGeneration
34+
35+
tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-small")
36+
model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
37+
38+
text = "def greet(user): print(f'hello <extra_id_0>!')"
39+
input_ids = tokenizer(text, return_tensors="pt").input_ids
40+
mask = (
41+
torch.tensor([1 for i in range(input_ids.shape[1])])
42+
.to(torch.int64)
43+
.reshape((1, -1))
44+
)
45+
46+
# simply generate a single sequence
47+
print()
48+
with steel_forward(model):
49+
generated_ids = model.generate(
50+
decoder_input_ids=input_ids, attention_mask=mask, max_length=100
51+
)
52+
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_helpers.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
rename_dynamic_dimensions,
3737
rename_dynamic_expression,
3838
)
39-
from onnx_diagnostic.cache_helpers import make_dynamic_cache
39+
from onnx_diagnostic.cache_helpers import make_dynamic_cache, make_encoder_decoder_cache
4040

4141
TFLOAT = onnx.TensorProto.FLOAT
4242

@@ -164,6 +164,8 @@ def test_flatten(self):
164164
},
165165
],
166166
)
167+
diff = max_diff(inputs, inputs, flatten=True, verbose=10)
168+
self.assertEqual(diff["abs"], 0)
167169
flat = flatten_object(inputs, drop_keys=True)
168170
diff = max_diff(inputs, flat, flatten=True, verbose=10)
169171
self.assertEqual(diff["abs"], 0)
@@ -442,6 +444,32 @@ def test_from_tensor(self):
442444
convert_endian(proto)
443445
dtype_to_tensor_dtype(dt)
444446

447+
@hide_stdout()
448+
def test_flatten_encoder_decoder_cache(self):
449+
inputs = (
450+
torch.rand((3, 4), dtype=torch.float16),
451+
[
452+
torch.rand((5, 6), dtype=torch.float16),
453+
torch.rand((5, 6, 7), dtype=torch.float16),
454+
{
455+
"a": torch.rand((2,), dtype=torch.float16),
456+
"cache": make_encoder_decoder_cache(
457+
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
458+
make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
459+
),
460+
},
461+
],
462+
)
463+
diff = max_diff(inputs, inputs, flatten=True, verbose=10)
464+
self.assertEqual(diff["abs"], 0)
465+
flat = flatten_object(inputs, drop_keys=True)
466+
diff = max_diff(inputs, flat, flatten=True, verbose=10)
467+
self.assertEqual(diff["abs"], 0)
468+
d = string_diff(diff)
469+
self.assertIsInstance(d, str)
470+
s = string_type(inputs)
471+
self.assertIn("EncoderDecoderCache", s)
472+
445473

446474
if __name__ == "__main__":
447475
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)