Skip to content

Commit 0bc08eb

Browse files
committed
add 4.57.0 to ci
1 parent 38720ad commit 0bc08eb

File tree

4 files changed

+90
-10
lines changed

4 files changed

+90
-10
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [ubuntu-latest]
1919
python: ['3.10', '3.11', '3.12', '3.13']
20-
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', 'main']
20+
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57.0', 'main']
2121
torch: ['2.8', 'main']
2222
exclude:
2323
- python: '3.10'
@@ -30,6 +30,8 @@ jobs:
3030
transformers: '4.55.4'
3131
- python: '3.10'
3232
transformers: '4.56.2'
33+
- python: '3.10'
34+
transformers: '4.57.0'
3335
- python: '3.11'
3436
torch: 'main'
3537
- python: '3.11'
@@ -38,6 +40,8 @@ jobs:
3840
transformers: '4.55.4'
3941
- python: '3.11'
4042
transformers: '4.56.2'
43+
- python: '3.11'
44+
transformers: '4.57.0'
4145
- python: '3.13'
4246
torch: '2.8'
4347
- python: '3.13'

_unittests/ut_tasks/test_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def test_falcon_mamba_dev(self):
270270
model(**inputs)
271271
model(**data["inputs2"])
272272
self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
273-
if not has_transformers("4.57"):
273+
if not has_transformers("4.57.99"):
274274
raise unittest.SkipTest("The model has control flow.")
275275
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
276276
torch.export.export(

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
requires_transformers,
99
has_torch,
1010
)
11+
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache
1112
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
1213
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1314
from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -345,16 +346,92 @@ def forward(self, x, ind1, ind2):
345346

346347
@requires_torch("2.7.9999")
347348
@requires_transformers("4.49.9999")
348-
def test_export_tiny_llm_dim_meta(self):
349+
def test_export_with_patch_tiny_llm_dim_meta(self):
349350
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
350351
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
352+
order = ["input_ids", "attention_mask", "position_ids", "past_key_values"]
353+
self.assertEqual(list(inputs), order)
351354
expected = model(**torch_deepcopy(inputs))
352-
with torch_export_patches(patch_transformers=True):
353-
ep = torch.export.export(
354-
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
355-
)
356-
got = ep.module()(**inputs)
357-
self.assertEqualArrayAny(expected, got)
355+
with self.subTest(input="no01", backed_size_oblivious=False):
356+
with torch_export_patches(patch_transformers=True):
357+
ep = torch.export.export(
358+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
359+
)
360+
got = ep.module()(**torch_deepcopy(inputs))
361+
self.assertEqualArrayAny(expected, got)
362+
363+
with self.subTest(input="no01", backed_size_oblivious=True):
364+
with (
365+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
366+
torch_export_patches(patch_transformers=True),
367+
):
368+
ep = torch.export.export(
369+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
370+
)
371+
got = ep.module()(**torch_deepcopy(inputs))
372+
self.assertEqualArrayAny(expected, got)
373+
374+
def _batch1(t):
375+
if t.__class__.__name__ == "DynamicCache":
376+
kv = CacheKeyValue(t)
377+
keys = [t[:1] for t in kv.key_cache]
378+
values = [t[:1] for t in kv.value_cache]
379+
return make_dynamic_cache(tuple(zip(keys, values)))
380+
if t.ndim > 1:
381+
return t[:1]
382+
return t
383+
384+
export_inputs = {k: _batch1(v) for k, v in inputs.items()}
385+
386+
# with self.subTest(input="batch1", backed_size_oblivious=False):
387+
# with torch_export_patches(patch_transformers=True):
388+
# ep = torch.export.export(
389+
# model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
390+
# )
391+
# got = ep.module()(**torch_deepcopy(inputs))
392+
# self.assertEqualArrayAny(expected, got)
393+
394+
with self.subTest(input="batch1", backed_size_oblivious=True):
395+
with (
396+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
397+
torch_export_patches(patch_transformers=True),
398+
):
399+
ep = torch.export.export(
400+
model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
401+
)
402+
try:
403+
got = ep.module()(**torch_deepcopy(inputs))
404+
except AssertionError as e:
405+
got = None
406+
if "Guard failed: position_ids.size()[0] == 1" not in str(e):
407+
raise
408+
409+
if got is not None:
410+
self.assertEqualArrayAny(expected, got)
411+
412+
if "inputs_empty_cache" not in data:
413+
return
414+
415+
export_inputs = data["inputs_empty_cache"]
416+
417+
# with self.subTest(input="cache0", backed_size_oblivious=False):
418+
# with torch_export_patches(patch_transformers=True):
419+
# ep = torch.export.export(
420+
# model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
421+
# )
422+
# got = ep.module()(**torch_deepcopy(inputs))
423+
# self.assertEqualArrayAny(expected, got)
424+
425+
with self.subTest(input="cache0", backed_size_oblivious=True):
426+
with (
427+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
428+
torch_export_patches(patch_transformers=True),
429+
):
430+
ep = torch.export.export(
431+
model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
432+
)
433+
got = ep.module()(**torch_deepcopy(inputs))
434+
self.assertEqualArrayAny(expected, got)
358435

359436

360437
if __name__ == "__main__":

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,6 @@ def _greater_than_reduce(acc, x):
671671

672672
return x
673673

674-
print("****", broadcast_dimensions)
675674
reduce(_greater_than_reduce, broadcast_dimensions, -1)
676675

677676
# shape must be broadcastable to

0 commit comments

Comments (0)