Skip to content

Commit 8e07a7a

Browse files
committed
split files
1 parent 20d83db commit 8e07a7a

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

_unittests/ut_torch_models/test_validate_whole_models.py renamed to _unittests/ut_torch_models/test_validate_whole_models1.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from onnx_diagnostic.tasks import supported_tasks
2323

2424

25-
class TestValidateWholeModels(ExtTestCase):
25+
class TestValidateWholeModels1(ExtTestCase):
2626
def test_a_get_inputs_for_task(self):
2727
fcts = supported_tasks()
2828
for task in self.subloop(sorted(fcts)):
@@ -260,30 +260,6 @@ def test_n_validate_phi35_mini_instruct(self):
260260
op_types = set(n.op_type for n in onx.graph.node)
261261
self.assertIn("If", op_types)
262262

263-
@requires_torch("2.9")
264-
@hide_stdout()
265-
@ignore_warnings(FutureWarning)
266-
@requires_transformers("4.55")
267-
def test_o_validate_phi35_4k_mini_instruct(self):
268-
mid = "microsoft/Phi-3-mini-4k-instruct"
269-
summary, data = validate_model(
270-
mid,
271-
do_run=True,
272-
verbose=10,
273-
exporter="custom",
274-
dump_folder="dump_test/validate_phi35_mini_instruct",
275-
inputs2=True,
276-
patch=True,
277-
rewrite=True,
278-
model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}},
279-
)
280-
self.assertIsInstance(summary, dict)
281-
self.assertIsInstance(data, dict)
282-
onnx_filename = data["onnx_filename"]
283-
onx = onnx.load(onnx_filename)
284-
op_types = set(n.op_type for n in onx.graph.node)
285-
self.assertIn("If", op_types)
286-
287263

288264
if __name__ == "__main__":
289265
unittest.main(verbosity=2)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import unittest
2+
import onnx
3+
from onnx_diagnostic.ext_test_case import (
4+
ExtTestCase,
5+
hide_stdout,
6+
ignore_warnings,
7+
requires_torch,
8+
requires_transformers,
9+
)
10+
from onnx_diagnostic.torch_models.validate import validate_model
11+
12+
13+
class TestValidateWholeModels2(ExtTestCase):
14+
@requires_torch("2.9")
15+
@hide_stdout()
16+
@ignore_warnings(FutureWarning)
17+
@requires_transformers("4.55")
18+
def test_o_validate_phi35_4k_mini_instruct(self):
19+
mid = "microsoft/Phi-3-mini-4k-instruct"
20+
summary, data = validate_model(
21+
mid,
22+
do_run=True,
23+
verbose=10,
24+
exporter="custom",
25+
dump_folder="dump_test/validate_phi35_mini_instruct",
26+
inputs2=True,
27+
patch=True,
28+
rewrite=True,
29+
model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}},
30+
)
31+
self.assertIsInstance(summary, dict)
32+
self.assertIsInstance(data, dict)
33+
onnx_filename = data["onnx_filename"]
34+
onx = onnx.load(onnx_filename)
35+
op_types = set(n.op_type for n in onx.graph.node)
36+
self.assertIn("If", op_types)
37+
38+
39+
if __name__ == "__main__":
40+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)