
Commit d787152

check llms separately

1 parent b8a1b98 commit d787152

File tree

5 files changed (+40, -7 lines)

.github/workflows/ci.yml

Lines changed: 6 additions & 0 deletions

@@ -52,6 +52,12 @@ jobs:
       - name: pip freeze
         run: python -m pip freeze

+      - name: tiny-llm torch.export.export
+        run: python _unittests/ut_torch_models/test_llms.py
+
+      - name: tiny-llm onnx
+        run: python _unittests/ut_torch_models/test_llms_onnx.py
+
       - name: run tests
         run: |
           pip install pytest

.github/workflows/documentation.yml

Lines changed: 7 additions & 2 deletions

@@ -39,8 +39,7 @@ jobs:
         run: python -m pip install -r requirements.txt

       - name: Install requirements dev
-        run: |
-          python -m pip install -r requirements-dev.txt
+        run: python -m pip install -r requirements-dev.txt

       - name: Cache pip
         uses: actions/cache@v4
@@ -54,6 +53,12 @@ jobs:
       - name: pip freeze
         run: python -m pip freeze

+      - name: tiny-llm torch.export.export
+        run: python _unittests/ut_torch_models/test_llms.py
+
+      - name: tiny-llm onnx
+        run: python _unittests/ut_torch_models/test_llms_onnx.py
+
       - name: Generate coverage report
         run: |
           pip install pytest

_unittests/ut_torch_models/test_llms.py

Lines changed: 2 additions & 1 deletion

@@ -17,16 +17,17 @@ def test_get_tiny_llm(self):
     def test_export_tiny_llm_1(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
+        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
         ep = torch.export.export(
             model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
         )
         assert ep
-        print(ep)

     @ignore_warnings(UserWarning)
     def test_export_tiny_llm_2_bypassed(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
+        self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs))
         with bypass_export_some_errors():
             ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"]
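Both tests follow the same export pattern: the tiny-llm model is exported with keyword inputs only plus a dynamic-shape specification. Below is a minimal, self-contained sketch of that calling convention on a toy module; the module, tensor shapes, and dimension names are illustrative assumptions, not the actual tiny-llm setup.

import torch


class ToyModel(torch.nn.Module):
    # Stand-in for the tiny-llm model; it only mixes the two inputs.
    def forward(self, input_ids, attention_mask):
        return input_ids.float() * attention_mask.float()


inputs = {
    "input_ids": torch.randint(0, 10, (2, 8)),
    "attention_mask": torch.ones(2, 8),
}
batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")
dynamic_shapes = {
    "input_ids": {0: batch, 1: seq},
    "attention_mask": {0: batch, 1: seq},
}

# Same convention as the tests: no positional args, kwargs plus dynamic_shapes.
ep = torch.export.export(ToyModel(), (), kwargs=inputs, dynamic_shapes=dynamic_shapes)
assert ep is not None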

_unittests/ut_xrun_doc/test_torch_test_helper.py

Lines changed: 9 additions & 1 deletion

@@ -6,13 +6,21 @@
 import onnx.numpy_helper as onh
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase
-from onnx_diagnostic.torch_test_helper import dummy_llm, check_model_ort, to_numpy
+from onnx_diagnostic.torch_test_helper import (
+    dummy_llm,
+    check_model_ort,
+    to_numpy,
+    is_torchdynamo_exporting,
+)

 TFLOAT = onnx.TensorProto.FLOAT


 class TestOrtSession(ExtTestCase):

+    def test_is_torchdynamo_exporting(self):
+        self.assertFalse(is_torchdynamo_exporting())
+
     def test_dummy_llm(self):
         for cls_name in ["AttentionBlock", "MultiAttentionBlock", "DecoderLayer", "LLM"]:
             model, inputs = dummy_llm(cls_name)

onnx_diagnostic/torch_test_helper.py

Lines changed: 16 additions & 3 deletions

@@ -9,10 +9,23 @@
 from .helpers import pretty_onnx


+def is_torchdynamo_exporting() -> bool:
+    """Tells if torch is exporting a model."""
+    import torch
+
+    try:
+        return torch.compiler.is_exporting()
+    except Exception:
+        try:
+            import torch._dynamo as dynamo
+
+            return dynamo.is_exporting()
+        except Exception:
+            return False
+
+
 def to_numpy(tensor: "torch.Tensor"):  # noqa: F821
-    """
-    Converts a torch tensor to numy.
-    """
+    """Converts a torch tensor to numy."""
     try:
         return tensor.numpy()
     except TypeError:
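One possible use of the new helper, sketched on a hypothetical module that is not part of this repository: guarding eager-only side effects so they stay out of traced or exported graphs. Whether the flag actually reports True inside a given exporter depends on the torch version, which is presumably why the helper itself falls back defensively.

import torch

from onnx_diagnostic.torch_test_helper import is_torchdynamo_exporting


class Guarded(torch.nn.Module):
    # Hypothetical module: skip debug printing when an exporter is tracing.
    def forward(self, x):
        if not is_torchdynamo_exporting():
            print("eager call, input shape:", tuple(x.shape))
        return x * 2


# In a plain eager call the helper returns False (exactly what the new unit
# test asserts), so the debug branch runs here.
print(Guarded()(torch.ones(3)))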
