Skip to content

Commit 9f9b7d3

Browse files
committed
fix bug in change_dimnension
1 parent 345b783 commit 9f9b7d3

File tree

4 files changed

+35
-3
lines changed

4 files changed

+35
-3
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import torch
3+
import transformers
34
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
45
from onnx_diagnostic.helpers import string_type
56
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
@@ -742,6 +743,17 @@ def test_couple_input_ds_change_dynamic_dimensions_fixed(self):
742743
self.assertEqual((1, 5, 8), new_input["A"].shape)
743744
self.assertEqual((1, 50), new_input["B"].shape)
744745

746+
def test_couple_input_ds_change_dynamic_dimensions_dynamic_cache(self):
747+
inst = CoupleInputsDynamicShapes(
748+
(),
749+
{"A": make_dynamic_cache([(torch.ones((2, 2, 2, 2)), torch.ones((2, 2, 2, 2)))])},
750+
{"A": [[{0: "batch", 2: "last"}], [{0: "batch", 2: "last"}]]},
751+
)
752+
new_inputs = inst.change_dynamic_dimensions()
753+
self.assertIsInstance(new_inputs["A"], transformers.cache_utils.DynamicCache)
754+
self.assertEqual((3, 2, 3, 2), new_inputs["A"].key_cache[0].shape)
755+
self.assertEqual((3, 2, 3, 2), new_inputs["A"].value_cache[0].shape)
756+
745757
@requires_transformers("4.51")
746758
def test_dynamic_cache_replace_by_string(self):
747759
n_layers = 2

_unittests/ut_helpers/test_helper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
dtype_to_tensor_dtype,
4040
)
4141
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
42+
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
43+
4244

4345
TFLOAT = onnx.TensorProto.FLOAT
4446

@@ -484,6 +486,11 @@ def test_flatten_encoder_decoder_cache(self):
484486
s = string_type(inputs)
485487
self.assertIn("EncoderDecoderCache", s)
486488

489+
def test_string_typeçconfig(self):
490+
conf = get_pretrained_config("microsoft/phi-2")
491+
s = string_type(conf)
492+
self.assertStartsWith("PhiConfig(**{", s)
493+
487494

488495
if __name__ == "__main__":
489496
unittest.main(verbosity=2)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,16 +363,20 @@ def _generic_walker_step(
363363
)
364364
if flatten_unflatten:
365365
flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
366-
return cls._generic_walker_step(
366+
res = cls._generic_walker_step(
367367
processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
368368
)
369-
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
369+
# Should we restore the original class?
370+
return res
371+
flat, spec = torch.utils._pytree.tree_flatten(inputs)
370372
if all(isinstance(t, torch.Tensor) for t in flat):
371373
# We need to flatten dynamic shapes as well
372374
ds = flatten_dynamic_shapes(ds)
373-
return cls._generic_walker_step(
375+
res = cls._generic_walker_step(
374376
processor, flat, ds, flatten_unflatten=flatten_unflatten
375377
)
378+
# Then we restore the original class.
379+
return torch.utils._pytree.tree_unflatten(res, spec)
376380

377381
class ChangeDimensionProcessor:
378382
def __init__(self, desired_values):

onnx_diagnostic/helpers/helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,15 @@ def string_type(
666666
print(f"[string_type] CACHE4:{type(obj)}")
667667
return f"{obj.__class__.__name__}(...)"
668668

669+
if obj.__class__.__name__.endswith("Config"):
670+
import transformers.configuration_utils as tcu
671+
672+
if isinstance(obj, tcu.PretrainedConfig):
673+
if verbose:
674+
print(f"[string_type] CONFIG:{type(obj)}")
675+
s = str(obj.to_diff_dict()).replace("\n", "").replace(" ", "")
676+
return f"{obj.__class__.__name__}(**{s})"
677+
669678
if verbose:
670679
print(f"[string_type] END:{type(obj)}")
671680
raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")

0 commit comments

Comments
 (0)