1 change: 1 addition & 0 deletions _doc/conf.py
@@ -136,6 +136,7 @@ def linkcode_resolve(domain, info):
("py:class", "transformers.cache_utils.EncoderDecoderCache"),
("py:class", "transformers.cache_utils.MambaCache"),
("py:class", "transformers.cache_utils.SlidingWindowCache"),
("py:class", "transformers.cache_utils.StaticCache"),
("py:class", "transformers.configuration_utils.PretrainedConfig"),
("py:class", "transformers.modeling_outputs.BaseModelOutput"),
("py:class", "transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding"),
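Note: the surrounding list is presumably Sphinx's nitpick_ignore (the (role, target) tuples and the py:class role match that convention); each entry silences a "reference target not found" warning for a type documented only in an external package. A minimal sketch under that assumption:

    # _doc/conf.py (sketch; assumes the list is Sphinx's nitpick_ignore)
    nitpick_ignore = [
        ("py:class", "transformers.cache_utils.SlidingWindowCache"),
        ("py:class", "transformers.cache_utils.StaticCache"),  # new in this PR
    ]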
13 changes: 11 additions & 2 deletions _doc/examples/plot_export_tiny_llm.py
@@ -33,6 +33,7 @@
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


MODEL_NAME = "arnir0/Tiny-LLM"
@@ -131,7 +132,11 @@ def _forward_(*args, _f=None, **kwargs):

try:
ep = torch.export.export(
untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
untrained_model,
(),
kwargs=cloned_inputs,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False,
)
print("It worked:")
print(ep)
@@ -166,7 +171,11 @@ def _forward_(*args, _f=None, **kwargs):

try:
ep = torch.export.export(
model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
model,
(),
kwargs=cloned_inputs,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False,
)
print("It worked:")
print(ep)
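Aside: judging by its name and how it is called, use_dyn_not_str presumably rewrites string-named dynamic dimensions such as "batch" into torch.export.Dim objects before the shapes reach torch.export.export. A hypothetical re-implementation, offered as a sketch rather than the library's actual code:

    import torch

    def use_dyn_not_str_sketch(dynamic_shapes):
        # Replace string dims ("batch", "seq_length", ...) with
        # torch.export.Dim.DYNAMIC, recursing through the usual containers.
        if isinstance(dynamic_shapes, str):
            return torch.export.Dim.DYNAMIC
        if isinstance(dynamic_shapes, dict):
            return {k: use_dyn_not_str_sketch(v) for k, v in dynamic_shapes.items()}
        if isinstance(dynamic_shapes, (list, tuple)):
            return type(dynamic_shapes)(use_dyn_not_str_sketch(v) for v in dynamic_shapes)
        return dynamic_shapes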
5 changes: 3 additions & 2 deletions _doc/examples/plot_export_tiny_llm_patched.py
@@ -70,6 +70,7 @@
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.llms import get_tiny_llm


@@ -110,7 +111,7 @@
untrained_model,
(),
kwargs=modificator(cloned_inputs),
dynamic_shapes=dynamic_shapes,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False, # mandatory for torch==2.6
)
print("It worked:")
@@ -131,7 +132,7 @@
model,
(),
kwargs=modificator(cloned_inputs),
dynamic_shapes=dynamic_shapes,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False, # mandatory for torch==2.6
)
print("It worked:")
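For readers without the full example open: both hunks presumably sit inside the torch_export_patches(patch_transformers=True) context that gives this example its name. A condensed sketch of the whole pattern (values illustrative, APIs as imported above):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
    from onnx_diagnostic.torch_models.llms import get_tiny_llm

    data = get_tiny_llm()
    model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(
            model,
            (),
            kwargs=inputs,
            dynamic_shapes=use_dyn_not_str(ds),
            strict=False,  # mandatory for torch==2.6
        )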
44 changes: 43 additions & 1 deletion _unittests/ut_helpers/test_cache_helper.py
@@ -2,13 +2,14 @@
import torch
import transformers
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers import string_type, max_diff
from onnx_diagnostic.helpers.cache_helper import (
flatten_unflatten_for_dynamic_shapes,
make_dynamic_cache,
make_encoder_decoder_cache,
make_mamba_cache,
make_sliding_window_cache,
make_static_cache,
)
from onnx_diagnostic.export import CoupleInputsDynamicShapes
from onnx_diagnostic.torch_export_patches.patch_inputs import (
@@ -104,6 +105,7 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
]
),
)
self.assertEqual(0, max_diff(c2, c2)["abs"])
self.assertIsInstance(c2, transformers.cache_utils.EncoderDecoderCache)
flat, _spec = torch.utils._pytree.tree_flatten(c2)
self.assertIsInstance(flat, list)
@@ -149,6 +151,7 @@ def test_make_mamba_cache(self):
"ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])",
text,
)
self.assertEqual(0, max_diff(cache, cache)["abs"])

def test_make_sliding_window_cache(self):
cache = make_sliding_window_cache(
@@ -164,6 +167,45 @@ def test_make_sliding_window_cache(self):
"value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
text,
)
self.assertEqual(0, max_diff(cache, cache)["abs"])

def test_make_static_cache(self):
cache = make_static_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
)
text = self.string_type(cache, with_shape=True)
self.assertEqual(
"StaticCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
"value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
text,
)
self.assertEqual(0, max_diff(cache, cache)["abs"])

def test_unflatten_flatten_static_cache(self):
with torch_export_patches(patch_transformers=True):
c2 = make_static_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
)
self.assertEqual(0, max_diff(c2, c2)["abs"])
self.assertIsInstance(c2, transformers.cache_utils.StaticCache)
flat, _spec = torch.utils._pytree.tree_flatten(c2)
self.assertIsInstance(flat, list)
self.assertEqual(len(flat), 6)
unflat = flatten_unflatten_for_dynamic_shapes(c2)
self.assertIsInstance(unflat, list)
self.assertEqual(len(unflat), 2)
self.assertEqual(
"#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]",
self.string_type(unflat, with_shape=True),
)


if __name__ == "__main__":
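Taken together, the new tests pin down a small contract for make_static_cache: with the transformers patches active, a three-layer cache flattens to six tensors under pytree (key tensors first, then value tensors), flatten_unflatten_for_dynamic_shapes folds those back into two lists of three, and max_diff(cache, cache)["abs"] is 0, i.e. a cache is at zero distance from itself. A minimal sketch of the invariant, assuming only the APIs the test imports:

    import torch
    from onnx_diagnostic.helpers import max_diff
    from onnx_diagnostic.helpers.cache_helper import make_static_cache
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    pairs = [(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 6, 7)) for _ in range(3)]
    with torch_export_patches(patch_transformers=True):
        cache = make_static_cache(pairs)
        flat, _spec = torch.utils._pytree.tree_flatten(cache)
        assert len(flat) == 6  # 3 key tensors + 3 value tensors
        assert max_diff(cache, cache)["abs"] == 0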
29 changes: 29 additions & 0 deletions _unittests/ut_tasks/try_tasks.py
@@ -98,6 +98,35 @@ def test_text2text_generation(self):
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

@never_test()
def test_text2text_generation_static(self):
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k text2t

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
model = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")

text = "def greet(user): print(f'hello <extra_id_0>!')"
input_ids = tokenizer(text, return_tensors="pt").input_ids
mask = (
torch.tensor([1 for i in range(input_ids.shape[1])])
.to(torch.int64)
.reshape((1, -1))
)

# simply generate a single sequence
print()
with steal_forward(model):
generated_ids = model.generate(
input_ids=input_ids,
attention_mask=mask,
max_new_tokens=117,
cache_implementation="static",
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

@never_test()
def test_text_generation_phi4_mini(self):
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_mini
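cache_implementation="static" is a standard transformers generation option: generate() pre-allocates a fixed-size StaticCache instead of letting a DynamicCache grow with the sequence. An equivalent spelling through GenerationConfig, shown as a hypothetical variant rather than part of the PR:

    from transformers import GenerationConfig

    gen_cfg = GenerationConfig(max_new_tokens=117, cache_implementation="static")
    generated_ids = model.generate(
        input_ids=input_ids, attention_mask=mask, generation_config=gen_cfg
    )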
54 changes: 46 additions & 8 deletions _unittests/ut_torch_models/test_tiny_llms.py
@@ -1,32 +1,70 @@
import copy
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
ignore_warnings,
requires_transformers,
requires_torch,
)
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTinyLlm(ExtTestCase):
def test_get_tiny_llm(self):
def test_tiny_llm_run_dynamic(self):
data = get_tiny_llm()
model, inputs = data["model"], data["inputs"]
self.assertIn("DynamicCache", string_type(inputs))
model(**inputs)

@ignore_warnings(UserWarning)
@requires_transformers("4.53")
def test_export_tiny_llm_1(self):
@requires_torch("2.8")
def test_tiny_llm_export_dynamic(self):
data = get_tiny_llm()
model, inputs = data["model"], data["inputs"]
expected = model(**copy.deepcopy(inputs))
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
ep = torch.export.export(
model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=data["dynamic_shapes"]
with torch_export_patches(patch_transformers=True):
ep = torch.export.export(
model,
(),
kwargs=copy.deepcopy(inputs),
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
)
got = ep.module()(**inputs)
self.assertEqualArrayAny(expected, got)

@requires_transformers("4.52")
def test_tiny_llm_run_static(self):
data = get_tiny_llm(use_static_cache=True)
model, inputs = data["model"], data["inputs"]
self.assertIn("StaticCache", string_type(inputs))
model(**inputs)

@ignore_warnings(UserWarning)
@requires_transformers("4.52")
@requires_torch("2.8")
def test_tiny_llm_export_static(self):
data = get_tiny_llm(use_static_cache=True)
model, inputs = data["model"], data["inputs"]
expected = model(**copy.deepcopy(inputs))
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs)
)
got = ep.module()(**inputs)
self.assertEqualArrayAny(expected, got)
with torch_export_patches(patch_transformers=True, stop_if_static=1):
ep = torch.export.export(
model,
(),
kwargs=copy.deepcopy(inputs),
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
)
got = ep.module()(**inputs)
self.assertEqualArrayAny(expected, got)


if __name__ == "__main__":
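The renamed tests spell out a two-by-two matrix: run vs. export, dynamic vs. static cache. Two details worth noting: the static-cache inputs swap position_ids for cache_position, and stop_if_static=1 presumably makes the patches fail fast if a dimension meant to stay dynamic gets specialized to a constant. A short usage sketch of the static path, grounded in the keys the test asserts:

    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.torch_models.llms import get_tiny_llm

    data = get_tiny_llm(use_static_cache=True)
    model, inputs = data["model"], data["inputs"]
    assert "StaticCache" in string_type(inputs)
    # static-cache inputs carry cache_position instead of position_ids
    assert set(inputs) == {"attention_mask", "past_key_values", "input_ids", "cache_position"}
    model(**inputs)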
10 changes: 8 additions & 2 deletions _unittests/ut_torch_models/test_tiny_llms_bypassed.py
@@ -2,19 +2,21 @@
import unittest
import torch
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_models.llms import get_phi2
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patched_DynamicCache,
)


class TestTinyLlmBypassed(ExtTestCase):
@ignore_warnings(UserWarning)
@hide_stdout()
def test_export_tiny_llm_2_bypassed(self):
data = get_tiny_llm()
model, inputs = data["model"], data["inputs"]
@@ -50,7 +52,11 @@ def debug():
debug()

ep = torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], strict=False
model,
(),
kwargs=inputs,
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
strict=False,
)
got = ep.module()(**inputs)
self.assertEqualArrayAny(expected, got)