Commit edc2b53

Commit message: change
1 parent c563272 commit edc2b53
4 files changed: +94 -15 lines changed

_unittests/ut_tasks/try_tasks.py

Lines changed: 44 additions & 13 deletions

@@ -1,3 +1,4 @@
+import os
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
@@ -799,17 +800,19 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
         from transformers import AutoProcessor, Gemma3ForConditionalGeneration

         model_id = "google/gemma-3-4b-it"
-        # model_id = "google/gemma-3n-e4b-it"
-        # model_id = "qnaug/gemma-3-4b-med"
-        # model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
-        # data = get_untrained_model_with_inputs(
-        #     model_id, verbose=1, add_second_input=True,
-        #     same_as_pretrained=True, use_pretrained=True
-        # )
-        # model = data["model"]
-        model = Gemma3ForConditionalGeneration.from_pretrained(
-            model_id, device_map="cpu"
-        ).eval()
+        if os.environ.get("PRETRAINED", ""):
+            model = Gemma3ForConditionalGeneration.from_pretrained(
+                model_id, device_map="cpu"
+            ).eval()
+        else:
+            data = get_untrained_model_with_inputs(
+                model_id,
+                verbose=1,
+                add_second_input=True,
+                # same_as_pretrained=True, #use_pretrained=True
+            )
+            model = data["model"]
+
         print(f"-- model.device={model.device}")
         processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
         print(f"-- processor={type(processor)}")
@@ -845,11 +848,39 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
         # inputs.pop("token_type_ids", None)
         print(f"-- inputs={self.string_type(inputs)}")

+        # iteration 1
+        # cache_position:T7s281,
+        # past_key_values:StaticCache(key_cache=#0[], value_cache=#0[]),
+        # input_ids:T7s1x281,
+        # inputs_embeds:None,
+        # token_type_ids:T7s1x281,
+        # attention_mask:dict(sliding_attention:T9s1x1x281x580,
+        #                     full_attention:T9s1x1x281x580),
+        # position_ids:None,
+        # use_cache:bool,
+        # logits_to_keep:None,
+        # pixel_values:T16s1x3x896x896,
+        # return_dict:bool)
+        # iteration 3
+        # cache_position:T7s1,
+        # past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+        #                             value_cache=#34[T1s1x4x580x256,...]),
+        # input_ids:T7s1x1,
+        # inputs_embeds:None,
+        # token_type_ids:T7s1x1,
+        # attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        # position_ids:None,
+        # use_cache:bool,logits_to_keep:None,return_dict:bool)
+
         print()
         # steal forward creates a bug...
-        with steal_forward(model):  # , torch.inference_mode():
+        with steal_forward(
+            model,
+            dump_file=self.get_dump_file("test_imagetext2text_generation_gemma3_4b_it.onnx"),
+            dump_drop={"attention_mask", "past_key_values", "pixel_values"},
+        ):
             generated_ids = model.generate(
-                **inputs, max_new_tokens=300, do_sample=False, cache_implementation="hybrid"
+                **inputs, max_new_tokens=282, do_sample=False, cache_implementation="static"
             )
             output_text = processor.decode(
                 generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
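
The new dump_file / dump_drop arguments make steal_forward write the captured inputs and outputs into an ONNX file. A minimal sketch of reading that file back, assuming create_input_tensors_from_onnx_model (the restore helper named in the steal_forward docstring below) accepts the dumped file path; the path here is illustrative, get_dump_file decides the actual location:

# Sketch only: restore the tensors captured by steal_forward(..., dump_file=...).
from onnx_diagnostic.helpers.mini_onnx_builder import (
    create_input_tensors_from_onnx_model,
)

restored = create_input_tensors_from_onnx_model(
    "test_imagetext2text_generation_gemma3_4b_it.onnx"  # illustrative path
)
print(type(restored))  # structure mirrors what was stored during generate()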

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 17 additions & 0 deletions

@@ -381,6 +381,23 @@ def _mk(k):
             else:
                 for p, o in _flatten_iterator(getattr(obj, att), sep):
                     yield f"DynamicCache_{att}{sep}{p}", o
+    elif obj.__class__.__name__ == "StaticCache":
+        # transformers
+        import transformers
+        from .cache_helper import CacheKeyValue
+
+        assert isinstance(
+            obj, transformers.cache_utils.StaticCache
+        ), f"Unexpected type {type(obj)}"
+        obj = CacheKeyValue(obj)
+        atts = ["key_cache", "value_cache"]
+        for i, att in enumerate(atts):
+            if i == len(atts) - 1:
+                for p, o in _flatten_iterator(getattr(obj, att), sep):
+                    yield f"StaticCache._{att}{sep}{p}", o
+            else:
+                for p, o in _flatten_iterator(getattr(obj, att), sep):
+                    yield f"StaticCache_{att}{sep}{p}", o
     else:
         raise NotImplementedError(f"Unexpected type {type(obj)}")
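
For a StaticCache holding a single layer, the new branch would yield names of roughly the following shape (a sketch, assuming sep is "." and that _flatten_iterator enumerates the per-layer tensor list by index; the "._" prefix on the last attribute mirrors the DynamicCache branch just above):

# Hypothetical flattened pairs for a one-layer StaticCache (not taken from the commit):
#   ("StaticCache_key_cache.0",    key tensor of layer 0)
#   ("StaticCache._value_cache.0", value tensor of layer 0)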

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 12 additions & 2 deletions

@@ -5,7 +5,7 @@
 import sys
 import warnings
 from collections.abc import Iterable
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import onnx
 from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
@@ -283,6 +283,7 @@ def steal_forward(
     ],
     fprint: Callable = string_type,
     dump_file: Optional[str] = None,
+    dump_drop: Optional[Set[str]] = None,
     submodules: bool = False,
     verbose: int = 0,
     storage_limit: int = 2**27,
@@ -303,6 +304,7 @@
     :param dump_file: dumps stolen inputs and outputs in an onnx model,
         they can be restored with :func:`create_input_tensors_from_onnx_model
        <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
+    :param dump_drop: to drop some inputs too big (only if dump_file is specified)
     :param submodules: if True and model is a module, the list extended with all the submodules
         the module contains
     :param verbose: verbosity
@@ -411,6 +413,9 @@ def forward(self, x, y):
     if verbose:
         size = torch_tensor_size(storage)
         print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
+    if dump_drop:
+        print(string_type(dump_drop))
+        stop
     proto = create_onnx_model_from_input_tensors(storage)
     if verbose:
         print("-- dumps stored objects")
@@ -794,9 +799,14 @@ def torch_deepcopy(value: Any) -> Any:
         from .cache_helper import CacheKeyValue

         ca = CacheKeyValue(value)
+        if len(ca.key_cache) == 0:
+            # Use of deepcopy.
+            import copy
+
+            return copy.deepcopy(value)
         return make_static_cache(
             torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))),
-            max_cache_len=value.max_cache_len,
+            max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
         )
     if value.__class__.__name__ == "HybridCache":
         from .cache_helper import CacheKeyValue
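
A small worked example of what the max_cache_len change amounts to, using hypothetical numbers shaped like the traces recorded in try_tasks.py above (T1s1x4x580x256 denotes a tensor of shape 1x4x580x256, so axis 2 is the cache length); the intent, presumably, is that the rebuilt StaticCache is always long enough to hold the tensors copied into it:

# Hypothetical values, not taken from the commit.
reported_max_cache_len = 281                # what value.max_cache_len might report
key_cache_shapes = [(1, 4, 580, 256)] * 34  # 34 layers, axis 2 is the cache axis

# The new expression takes the larger of the reported length and the stored length.
max_cache_len = max([reported_max_cache_len, *[s[2] for s in key_cache_shapes]])
print(max_cache_len)  # 580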

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 21 additions & 0 deletions

@@ -14,6 +14,27 @@
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
+    if (
+        hasattr(config, "architectures")
+        and config.architectures
+        and config.architectures[0] == "Gemma3ForConditionalGeneration"
+    ):
+        if hasattr(config, "vision_config"):
+            if hasattr(config.vision_config, "num_hidden_layers"):
+                config.vision_config.num_hidden_layers = min(
+                    config.vision_config.num_hidden_layers, nhl()
+                )
+        if hasattr(config, "text_config"):
+            if hasattr(config.text_config, "intermediate_size"):
+                config.text_config.intermediate_size = min(
+                    config.text_config.intermediate_size, 10240 // 10 * 5 // 2
+                )
+                config.text_config.hidden_size = min(
+                    config.text_config.hidden_size, 2560 // 10 * 5 // 2
+                )
+        update_config(config, kwargs)
+        return kwargs
+
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     if hasattr(config, "mm_tokens_per_image"):