
Commit 731a746

fix test
1 parent d551e3a commit 731a746

6 files changed, +18 -9 lines changed


.github/workflows/ci.yml
Lines changed: 2 additions & 2 deletions

@@ -62,12 +62,12 @@ jobs:
       - name: tiny-llm torch.export.export
         run: |
           export PYTHONPATH=.
-          python _unittests/ut_torch_models/test_llms.py
+          python _unittests/ut_torch_models/test_tiny_llms.py

       - name: tiny-llm onnx
         run: |
           export PYTHONPATH=.
-          python _unittests/ut_torch_models/test_llms_onnx.py
+          python _unittests/ut_torch_models/test_tiny_llms_onnx.py

       - name: run tests
         run: |

.github/workflows/documentation.yml
Lines changed: 2 additions & 2 deletions

@@ -56,12 +56,12 @@ jobs:
       - name: tiny-llm torch.export.export
         run: |
           export PYTHONPATH=.
-          python _unittests/ut_torch_models/test_llms.py
+          python _unittests/ut_torch_models/test_tiny_llms.py

       - name: tiny-llm onnx
         run: |
           export PYTHONPATH=.
-          python _unittests/ut_torch_models/test_llms_onnx.py
+          python _unittests/ut_torch_models/test_tiny_llms_onnx.py

       - name: Generate coverage report
         run: |
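Both workflow changes point the existing steps at the renamed test modules. For a quick local check before pushing, the same two files can be run with the repository root on PYTHONPATH; the small Python driver below is only an illustrative sketch of what the workflow steps do, not part of the commit.

# Sketch: run the renamed tiny-llm tests locally, mirroring the workflow steps.
# Assumes it is launched from the repository root.
import os
import subprocess
import sys

env = dict(os.environ, PYTHONPATH=".")
for test_file in (
    "_unittests/ut_torch_models/test_tiny_llms.py",
    "_unittests/ut_torch_models/test_tiny_llms_onnx.py",
):
    subprocess.run([sys.executable, test_file], env=env, check=True)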

_unittests/ut_torch_export_patches/test_onnx_export_errors.py
Lines changed: 7 additions & 4 deletions

@@ -6,6 +6,7 @@
     skipif_ci_windows,
     ignore_warnings,
     hide_stdout,
+    has_transformers,
 )
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -68,10 +69,12 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
                 return x2

         cache = MambaCache(_config(), max_batch_size=1, device="cpu")
-        self.assertEqual(
-            string_type(cache),
-            "MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])",
-        )
+        if has_transformers("4.50"):
+            # MambaCache was updated in 4.50
+            self.assertEqual(
+                "MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])",
+                string_type(cache),
+            )
         x = torch.ones(2, 8, 16).to(torch.float16)
         model = Model()
         model(x, cache)
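The new guard only checks the exact cache representation when the installed transformers release is recent enough, since MambaCache changed in 4.50. A minimal, self-contained sketch of the same version-guard pattern follows; has_transformers_at_least is a hypothetical stand-in written for illustration, not the project's has_transformers helper.

# Illustrative sketch of a transformers version guard similar to the one used above.
# has_transformers_at_least is a hypothetical stand-in for the project's helper.
import unittest

import packaging.version as pv
import transformers


def has_transformers_at_least(version: str) -> bool:
    # True when the installed transformers package is at least `version`.
    return pv.Version(transformers.__version__) >= pv.Version(version)


class TestVersionGuard(unittest.TestCase):
    def test_guarded_assertion(self):
        if has_transformers_at_least("4.50"):
            # Only assert behaviour that exists in recent transformers releases;
            # older installations simply skip the check.
            self.assertIsInstance(transformers.__version__, str)


if __name__ == "__main__":
    unittest.main()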

onnx_diagnostic/helpers.py
Lines changed: 7 additions & 1 deletion

@@ -189,6 +189,7 @@ def string_type(
     """
     if obj is None:
         return "None"
+    # tuple
     if isinstance(obj, tuple):
         if len(obj) == 1:
             s = string_type(
@@ -225,6 +226,7 @@ def string_type(
             mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
             return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]"
         return f"#{len(obj)}({tt},...)"
+    # list
     if isinstance(obj, list):
         if len(obj) < limit:
             js = ",".join(
@@ -251,6 +253,7 @@ def string_type(
             mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
             return f"#{len(obj)}[{tt},...][{mini},{maxi}:{avg}]"
         return f"#{len(obj)}[{tt},...]"
+    # set
     if isinstance(obj, set):
         if len(obj) < 10:
             js = ",".join(
@@ -269,6 +272,7 @@ def string_type(
             mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
             return f"{{...}}#{len(obj)}[{mini},{maxi}:A{avg}]"
         return f"{{...}}#{len(obj)}" if with_shape else "{...}"
+    # dict
     if isinstance(obj, dict):
         if len(obj) == 0:
             return "{}"
@@ -281,6 +285,7 @@ def string_type(
         )
         s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items())
         return f"dict({s})"
+    # array
     if isinstance(obj, np.ndarray):
         if with_min_max:
             s = string_type(obj, with_shape=with_shape)
@@ -303,6 +308,7 @@ def string_type(

     import torch

+    # Dim, SymInt
     if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
         return "DerivedDim"
     if isinstance(obj, torch.export.dynamic_shapes._Dim):
@@ -311,13 +317,13 @@ def string_type(
         return "SymInt"
     if isinstance(obj, torch.SymFloat):
         return "SymFloat"
+    # Tensors
     if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
         i = torch_dtype_to_onnx_dtype(obj.dtype)
         prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
         if not with_shape:
             return f"{prefix}F{i}r{len(obj.shape)}"
         return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
-
     if isinstance(obj, torch.Tensor):
         if with_min_max:
             s = string_type(obj, with_shape=with_shape, with_device=with_device)
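The added comments label the dispatch branches of string_type: tuples, lists, sets, dicts, numpy arrays, torch symbolic dimensions, and tensors each get their own compact notation. A rough usage sketch follows; the expected strings in the comments are assumptions inferred from the formatting code above (ONNX dtype 10 is float16, r gives the rank, s the shape), not verified output.

# Rough sketch of string_type on a few of the labelled branches.
import torch

from onnx_diagnostic.helpers import string_type

x = torch.ones(2, 8, 16).to(torch.float16)

print(string_type(x))                   # tensor branch, e.g. "T10r3" (float16, rank 3)
print(string_type(x, with_shape=True))  # shape instead of rank, e.g. "T10s2x8x16"
print(string_type([x, x]))              # list branch, e.g. "#2[T10r3,T10r3]"
print(string_type((x,)))                # tuple branch
print(string_type({"states": x}))       # dict branch, e.g. "dict(states:T10r3)"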
