
Commit ad93117

Improves image-text-to-text (#192)
* Add a unit test about an issue
* fix
* fix mambacache import
* fix import issues
* improves text-image-to-text
* add use_cache
* add support for hybrid cache
* update
* fix
* fix two issues
* fix cache
* fix cachekeyvalue
* fixes
* mypy
* fix issues
* restore patch
* requires
* another quick fix
* one fix
* fix cache
* cache
* fix patches
* fix patch
1 parent 5d6ba01 commit ad93117

25 files changed: 1225 additions & 311 deletions

CHANGELOGS.rst

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@ Change Logs
 0.7.6
 +++++
 
+* :pr:`192`: add support for Gemma-3, add serialization for HybridCache,
+  changes to support ``transformers>=4.54``
+
 0.7.5
 +++++
 
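
In practice, the new ``HybridCache`` serialization means the cache behaves as a
regular pytree once the patches are applied. A minimal sketch distilled from the
tests added below (it assumes ``make_hybrid_cache`` lives in
``onnx_diagnostic.helpers.cache_helper`` next to ``make_dynamic_cache``; shapes
are illustrative)::

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
    from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
        torch_export_patches,
    )

    # One (key, value) pair per layer.
    cache = make_hybrid_cache(
        [
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
        ]
    )
    with torch_export_patches(patch_transformers=True):
        # Registered with pytree: two layers flatten into four tensors.
        flat, _spec = torch.utils._pytree.tree_flatten(cache)
        assert len(flat) == 4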

_doc/conf.py

Lines changed: 1 addition & 0 deletions
@@ -140,6 +140,7 @@ def linkcode_resolve(domain, info):
     ("py:class", "transformers.cache_utils.Cache"),
     ("py:class", "transformers.cache_utils.DynamicCache"),
     ("py:class", "transformers.cache_utils.EncoderDecoderCache"),
+    ("py:class", "transformers.cache_utils.HybridCache"),
     ("py:class", "transformers.cache_utils.MambaCache"),
     ("py:class", "transformers.cache_utils.SlidingWindowCache"),
     ("py:class", "transformers.cache_utils.StaticCache"),

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 45 additions & 0 deletions
@@ -7,10 +7,12 @@
     flatten_unflatten_for_dynamic_shapes,
     make_dynamic_cache,
     make_encoder_decoder_cache,
+    make_hybrid_cache,
     make_mamba_cache,
     make_sliding_window_cache,
     make_static_cache,
 )
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.export import CoupleInputsDynamicShapes
 from onnx_diagnostic.torch_export_patches.patch_inputs import (
     convert_dynamic_axes_into_dynamic_shapes,
@@ -48,6 +50,10 @@ def test_replace_by(self):
         past_key_values = make_dynamic_cache(
             [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
         )
+        self.assertEqual(
+            "DynamicCache(key_cache=#1[T1s2x4x3x7], value_cache=#1[T1s2x4x3x7])",
+            self.string_type(past_key_values, with_shape=True),
+        )
         kwargs = dict(
             input_ids=torch.zeros(2, 3),
             attention_mask=torch.zeros(2, 3),
@@ -209,6 +215,45 @@ def test_unflatten_flatten_static_cache(self):
             self.string_type(unflat, with_shape=True),
         )
 
+    def test_make_hybrid_cache(self):
+        cache = make_hybrid_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ],
+        )
+        text = self.string_type(cache, with_shape=True)
+        self.assertEqual(
+            "HybridCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
+            "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
+            text,
+        )
+        self.assertEqual(0, max_diff(cache, cache)["abs"])
+        self.assertEqual(0, max_diff(cache, torch_deepcopy(cache))["abs"])
+
+    def test_unflatten_flatten_hybrid_cache(self):
+        with torch_export_patches(patch_transformers=True):
+            c2 = make_hybrid_cache(
+                [
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                ],
+            )
+            self.assertEqual(0, max_diff(c2, c2)["abs"])
+            self.assertIsInstance(c2, transformers.cache_utils.HybridCache)
+            flat, _spec = torch.utils._pytree.tree_flatten(c2)
+            self.assertIsInstance(flat, list)
+            self.assertEqual(len(flat), 6)
+            unflat = flatten_unflatten_for_dynamic_shapes(c2)
+            self.assertIsInstance(unflat, list)
+            self.assertEqual(len(unflat), 2)
+            self.assertEqual(
+                "#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]",
+                self.string_type(unflat, with_shape=True),
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 17 additions & 12 deletions
@@ -4,7 +4,7 @@
 import onnx
 import torch
 import transformers
-from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_torch
 from onnx_diagnostic.helpers import max_diff, string_type
 from onnx_diagnostic.helpers.torch_helper import (
     dummy_llm,
@@ -23,6 +23,7 @@
     make_encoder_decoder_cache,
     make_mamba_cache,
     make_sliding_window_cache,
+    CacheKeyValue,
 )
 from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
 from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended
@@ -210,15 +211,17 @@ def forward(self, x, y):
         print(string_type(restored, with_shape=True))
         l1, l2 = 186, 195
         self.assertEqual(
-            [
-                (f"-Model-{l2}", 0, "I"),
-                (f"-Model-{l2}", 0, "O"),
-                (f"s1-SubModel-{l1}", 0, "I"),
-                (f"s1-SubModel-{l1}", 0, "O"),
-                (f"s2-SubModel-{l1}", 0, "I"),
-                (f"s2-SubModel-{l1}", 0, "O"),
-            ],
-            sorted(restored),
+            len(
+                [
+                    (f"-Model-{l2}", 0, "I"),
+                    (f"-Model-{l2}", 0, "O"),
+                    (f"s1-SubModel-{l1}", 0, "I"),
+                    (f"s1-SubModel-{l1}", 0, "O"),
+                    (f"s2-SubModel-{l1}", 0, "I"),
+                    (f"s2-SubModel-{l1}", 0, "O"),
+                ]
+            ),
+            len(sorted(restored)),
         )
 
     def test_replace_string_by_dynamic(self):
@@ -265,11 +268,13 @@ def test_torch_deepcopy_cache_dce(self):
         a = {"t": [(torch.tensor([1, 2]), c1, c2), {4, 5}]}
         at = torch_deepcopy(a)
         hash1 = string_type(at, with_shape=True, with_min_max=True)
-        c1.key_cache[0] += 1000
+        ccv = CacheKeyValue(c1)
+        ccv.key_cache[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
         self.assertGreater(torch_tensor_size(cc), 1)
 
+    @requires_torch("4.50")
     def test_torch_deepcopy_mamba_cache(self):
         cache = make_mamba_cache(
             [
@@ -312,7 +317,7 @@ def test_torch_deepcopy_sliding_windon_cache(self):
         self.assertEqual(type(cache), type(at))
         self.assertEqual(max_diff(cache, at)["abs"], 0)
         hash1 = string_type(at, with_shape=True, with_min_max=True)
-        cache.key_cache[0] += 1000
+        CacheKeyValue(cache).key_cache[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
         self.assertGreater(torch_tensor_size(cache), 1)
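
The recurring change above replaces direct ``cache.key_cache`` /
``cache.value_cache`` access with the new ``CacheKeyValue`` wrapper, which keeps
the same code working across transformers versions. A minimal sketch, assuming
the wrapper exposes the underlying tensors without copying them::

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue

    cache = make_dynamic_cache([(torch.rand(2, 4, 3, 7), torch.rand(2, 4, 3, 7))])
    cc = CacheKeyValue(cache)
    # key_cache / value_cache are plain lists of tensors, whatever the
    # installed transformers version stores internally.
    cc.key_cache[0] += 1000  # mutates the original cache in place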

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 20 additions & 2 deletions
@@ -16,11 +16,11 @@ class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
     @requires_transformers("4.53")
     @requires_torch("2.7.99")
-    def test_image_text_to_text(self):
+    def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
-        self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
+        self.assertIn((data["size"], data["n_weights"]), [(12628776, 3157194)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
@@ -29,6 +29,24 @@ def test_image_text_to_text(self):
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
         )
 
+    @hide_stdout()
+    @requires_transformers("4.53")
+    @requires_torch("2.7.99")
+    def test_image_text_to_text_gemma3(self):
+        # mid = "google/gemma-3-4b-it"
+        mid = "tiny-random/gemma-3"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "image-text-to-text")
+        # self.assertIn((data["size"], data["n_weights"]), [(17248576, 4312144)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        print("--", self.string_type(data["inputs"], with_shape=True))
+        model(**torch_deepcopy(inputs))
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_tasks/test_tasks_mask_generation.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def test_mask_generation(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
+        with torch_export_patches(patch_transformers=True, verbose=1):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )

_unittests/ut_tasks/try_tasks.py

Lines changed: 76 additions & 1 deletion
@@ -257,7 +257,7 @@ def test_text_generation_phi4_moe(self):
         print(f">>> Response\n{response}")
 
     @never_test()
-    def test_imagetext2text_generation(self):
+    def test_imagetext2text_generation_idefics(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k etext2t
         # https://huggingface.co/docs/transformers/main/en/tasks/idefics
 
@@ -287,6 +287,81 @@ def test_imagetext2text_generation(self):
 
         print(generated_text[0])
 
+    @never_test()
+    def test_imagetext2text_generation_gemma3(self):
+        """
+        ::
+
+            dict(input_ids:T7s1x281,
+                pixel_values:T16s1x3x896x896,
+                attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
+                position_ids:T7s1x281,
+                past_key_values:HybridCache(
+                    key_cache=#34[T1s1x4x380x256,...],
+                    value_cache=#34[T1s1x4x380x256,...]),
+                token_type_ids:T7s1x281,
+                cache_position:T7s281,
+                logits_to_keep:1)
+            dict(input_ids:T7s1x1,
+                pixel_values:None,
+                attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
+                position_ids:T7s1x1,
+                past_key_values:HybridCache(
+                    key_cache=#34[T1s1x4x380x256,...],
+                    value_cache=#34[T1s1x4x380x256,...]),
+                token_type_ids:T7s1x1,
+                cache_position:T7s1,
+                logits_to_keep:1)
+        """
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+        import torch
+
+        # model_id = "tiny-random/gemma-3"
+        model_id = "google/gemma-3-4b-it"
+
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            model_id, device_map="auto"
+        ).eval()
+
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                    },
+                    {"type": "text", "text": "Describe this image in detail."},
+                ],
+            },
+        ]
+
+        inputs = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device, dtype=torch.bfloat16)
+
+        input_len = inputs["input_ids"].shape[-1]
+
+        print()
+        print(f"-- input_len={input_len}")
+        # steal_forward creates a bug...
+        # with steal_forward(model), torch.inference_mode():
+        with torch.inference_mode():
+            generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+            generation = generation[0][input_len:]
+            decoded = processor.decode(generation, skip_special_tokens=True)
+            print(decoded)
+
     @never_test()
     def test_automatic_speech_recognition(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k automatic_speech

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 19 additions & 6 deletions
@@ -11,7 +11,7 @@
     has_transformers,
 )
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
     torch_export_patches,
 )
@@ -27,14 +27,26 @@ def test_export_dynamic_cache_update(self):
 
         class SubModelCache(torch.nn.Module):
             def forward(self, cache):
+                cc = CacheKeyValue(cache)
+                # If not patched...
+                # Fails with transformers>=4.54 because function ``parse_processor_args``
+                # relies on inspect and the exporter is not very fond of that.
+                # torch._dynamo.exc.Unsupported: id() with unsupported args
+                # Explanation: Dynamo doesn't know how to trace id()
+                # call with args
+                # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+                # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+                # objects from outside the compiled region.
+                # Hint: It may be possible to write Dynamo tracing rules for this code.
                 d = cache.__class__()
-                d.update(cache.key_cache[0] + 1, cache.value_cache[0] + 2, 0)
-                d.update(cache.key_cache[0] + 3, cache.value_cache[0] + 5, 1)
+                d.update(cc.key_cache[0] + 1, cc.value_cache[0] + 2, 0)
+                d.update(cc.key_cache[0] + 3, cc.value_cache[0] + 5, 1)
                 return d
 
         class SubModel(torch.nn.Module):
             def forward(self, x, cache):
-                return x + cache.key_cache[0] + cache.value_cache[0]
+                cc = CacheKeyValue(cache)
+                return x + cc.key_cache[0] + cc.value_cache[0]
 
         class Model(torch.nn.Module):
             def __init__(self):
@@ -56,7 +68,7 @@ def forward(self, x, cache):
         DYN = torch.export.Dim.DYNAMIC
 
         # patching
-        with torch_export_patches(patch_transformers=True):
+        with torch_export_patches(patch_transformers=True, verbose=10):
             got = model(*inputs)
             self.assertEqualArray(expected, got)
             ep = torch.export.export(
@@ -230,9 +242,10 @@ def test_export_dynamic_cache_cat(self):
 
         class ModelDynamicCache(torch.nn.Module):
             def forward(self, x, dc):
+                cc = CacheKeyValue(dc)
                 y = (
                     (
-                        torch.cat(dc.key_cache, axis=1) + torch.cat(dc.value_cache, axis=1)
+                        torch.cat(cc.key_cache, axis=1) + torch.cat(cc.value_cache, axis=1)
                     ).reshape((-1, x.shape[1]))
                 ).transpose(1, 0)
                 return x @ y
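
The comment in the hunk above records why the wrapper matters: with
``transformers>=4.54``, direct attribute access ends up in code Dynamo cannot
trace. A minimal sketch of the working export pattern, condensed from the test
above::

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
    from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
        torch_export_patches,
    )

    class Model(torch.nn.Module):
        def forward(self, x, cache):
            # Read the key/value tensors through CacheKeyValue,
            # not through cache.key_cache directly.
            cc = CacheKeyValue(cache)
            return x + cc.key_cache[0] + cc.value_cache[0]

    cache = make_dynamic_cache([(torch.rand(2, 4, 3, 7), torch.rand(2, 4, 3, 7))])
    x = torch.rand(2, 4, 3, 7)
    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(Model(), (x, cache))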

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 7 additions & 9 deletions
@@ -6,7 +6,6 @@
     skipif_ci_windows,
     ignore_warnings,
     hide_stdout,
-    has_transformers,
 )
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
@@ -48,7 +47,7 @@ def __init__(self):
         self.assertEqualArrayAny(cache.conv_states, cache2.conv_states)
         self.assertEqualArrayAny(cache.ssm_states, cache2.ssm_states)
 
-    @requires_transformers("4.43")
+    @requires_transformers("4.50")
     @requires_torch("2.7")
     @skipif_ci_windows("not working on Windows")
     @ignore_warnings(UserWarning)
@@ -72,17 +71,16 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
             return x2
 
         cache = MambaCache(_config(), max_batch_size=1, device="cpu")
-        if has_transformers("4.50"):
-            # MambaCache was updated in 4.50
-            self.assertEqual(
-                "MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])",
-                string_type(cache),
-            )
+        # MambaCache was updated in 4.50
+        self.assertEqual(
+            "MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])",
+            string_type(cache),
+        )
         x = torch.ones(2, 8, 16).to(torch.float16)
         model = Model()
         model(x, cache)
 
-        with torch_export_patches(verbose=1):
+        with torch_export_patches(verbose=1, patch_transformers=True):
            cache = MambaCache(_config(), max_batch_size=1, device="cpu")
            torch.export.export(Model(), (x, cache))
