Skip to content

Commit 2e874f1

Browse files
committed
better tests
1 parent 28410cd commit 2e874f1

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.4.0
55
+++++
66

7+
* :pr:`61`: improves dynamic shapes for EncoderDecoderCache
78
* :pr:`58`: add function use_dyn_not_str to replace strings with ``torch.export.Dim.DYNAMIC``,
89
use string instead of ``torch.export.Dim.DYNAMIC`` when returning the dynamic shapes
910
for specific models; it is a valid definition for ``torch.onnx.export``

_unittests/ut_tasks/test_tasks.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ def test_text2text_generation(self):
1111
mid = "sshleifer/tiny-marian-en-de"
1212
data = get_untrained_model_with_inputs(mid, verbose=1)
1313
self.assertIn((data["size"], data["n_weights"]), [(473928, 118482)])
14-
model, inputs = data["model"], data["inputs"]
14+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
1515
raise unittest.SkipTest(f"not working for {mid!r}")
1616
model(**inputs)
17+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
18+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
1719

1820
@hide_stdout()
1921
def test_automatic_speech_recognition(self):
@@ -86,41 +88,50 @@ def test_imagetext2text_generation(self):
8688
mid = "HuggingFaceM4/tiny-random-idefics"
8789
data = get_untrained_model_with_inputs(mid, verbose=1)
8890
self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
89-
model, inputs = data["model"], data["inputs"]
91+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
9092
model(**inputs)
93+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
94+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
9195

9296
@hide_stdout()
9397
def test_fill_mask(self):
9498
mid = "google-bert/bert-base-multilingual-cased"
9599
data = get_untrained_model_with_inputs(mid, verbose=1)
96100
self.assertIn((data["size"], data["n_weights"]), [(428383212, 107095803)])
97-
model, inputs = data["model"], data["inputs"]
101+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
98102
model(**inputs)
103+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
104+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
99105

100106
@hide_stdout()
101107
def test_text_classification(self):
102108
mid = "Intel/bert-base-uncased-mrpc"
103109
data = get_untrained_model_with_inputs(mid, verbose=1)
104110
self.assertIn((data["size"], data["n_weights"]), [(154420232, 38605058)])
105-
model, inputs = data["model"], data["inputs"]
111+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
106112
model(**inputs)
113+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
114+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
107115

108116
@hide_stdout()
109117
def test_sentence_similary(self):
110118
mid = "sentence-transformers/all-MiniLM-L6-v1"
111119
data = get_untrained_model_with_inputs(mid, verbose=1)
112120
self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
113-
model, inputs = data["model"], data["inputs"]
121+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
114122
model(**inputs)
123+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
124+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
115125

116126
@hide_stdout()
117127
def test_falcon_mamba_dev(self):
118128
mid = "tiiuae/falcon-mamba-tiny-dev"
119129
data = get_untrained_model_with_inputs(mid, verbose=1)
120-
model, inputs = data["model"], data["inputs"]
121-
print(self.string_type(inputs, with_shape=True))
130+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
122131
model(**inputs)
123132
self.assertIn((data["size"], data["n_weights"]), [(138640384, 34660096)])
133+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
134+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
124135

125136

126137
if __name__ == "__main__":

0 commit comments

Comments
 (0)