Skip to content

Commit db6bb79

Browse files
committed
enables export with fake tensors
1 parent 68d71cf commit db6bb79

File tree

3 files changed, with 41 additions and 4 deletions.

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,13 +10,14 @@
1010
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
1111
from onnx_diagnostic.torch_export_patches import torch_export_patches
1212
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
13+
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
1314

1415

1516
class TestTasksTextGeneration(ExtTestCase):
1617
@hide_stdout()
1718
@requires_transformers("4.53")
1819
@requires_torch("2.7.99")
19-
def test_tet_generation_gemma3_for_causallm(self):
20+
def test_text_generation_gemma3_for_causallm(self):
2021
mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
2122
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
2223
self.assertEqual(data["task"], "text-generation")
@@ -31,20 +32,38 @@ def test_tet_generation_gemma3_for_causallm(self):
3132
@hide_stdout()
3233
@requires_transformers("4.53")
3334
@requires_torch("2.7.99")
34-
def test_itext_generation_phi_3_mini_128k_instruct(self):
35+
def test_text_generation_phi_3_mini_128k_instruct(self):
3536
mid = "microsoft/Phi-3-mini-128k-instruct"
3637
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
3738
self.assertEqual(data["task"], "text-generation")
3839
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
39-
print("--", self.string_type(inputs, with_shape=True))
40-
print("--", self.string_type(ds))
4140
model(**torch_deepcopy(inputs))
4241
model(**data["inputs2"])
4342
with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
4443
torch.export.export(
4544
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
4645
)
4746

47+
@hide_stdout()
48+
@requires_transformers("4.53")
49+
@requires_torch("2.7.99")
50+
def test_text_generation_tiny_llm(self):
51+
mid = "arnir0/Tiny-LLM"
52+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
53+
self.assertEqual(data["task"], "text-generation")
54+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
55+
expected = model(**torch_deepcopy(inputs))
56+
model(**data["inputs2"])
57+
fake = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes=ds)[0]
58+
with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
59+
ep = torch.export.export(
60+
model, (), kwargs=fake, dynamic_shapes=use_dyn_not_str(ds), strict=False
61+
)
62+
# print(ep)
63+
got = ep.module()(**inputs)
64+
self.assertEqualAny(expected.past_key_values, got.past_key_values)
65+
self.assertEqualArray(expected.logits, got.logits)
66+
4867

4968
if __name__ == "__main__":
5069
unittest.main(verbosity=2)

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,21 @@ def make_dynamic_cache(
169169
)
170170
print(string_type(past_key_values, with_shape=True))
171171
"""
172+
if key_value_pairs and isinstance(
173+
key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor
174+
):
175+
cache = transformers.cache_utils.DynamicCache()
176+
cache.layers.extend(
177+
[transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
178+
)
179+
for i, layer in enumerate(cache.layers):
180+
k, v = key_value_pairs[i][0], key_value_pairs[i][1]
181+
layer.dtype = k.dtype
182+
layer.device = k.device
183+
layer.keys = k
184+
layer.values = v
185+
return finalize_cache(cache)
186+
172187
cache = transformers.cache_utils.DynamicCache(key_value_pairs)
173188
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
174189
# The cache constructor contains the two following lines

onnx_diagnostic/helpers/helper.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -463,6 +463,7 @@ def string_type(
463463
if verbose:
464464
print(f"[string_type] F2:{type(obj)}")
465465
return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
466+
466467
if isinstance(obj, torch.Tensor):
467468
from .torch_helper import torch_dtype_to_onnx_dtype
468469

@@ -783,6 +784,8 @@ def string_type(
783784
obj, ultralytics.engine.results.Results
784785
), f"Unexpected type={type(obj)}"
785786
return f"ultralytics.{obj.__class__.__name__}(...)"
787+
if obj.__class__.__name__ == "FakeTensorMode":
788+
return f"{obj}"
786789

787790
if verbose:
788791
print(f"[string_type] END:{type(obj)}")

0 commit comments

Comments
 (0)