Skip to content

Commit c2c98cc

Browse files
authored
Add desired_values to change_dynamic_dimensions (#45)
* Add desired_values to change_dynamic_dimensions * change
1 parent d7e5976 commit c2c98cc

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.4.0
5+
+++++
6+
7+
* :pr:`45`: improve change_dynamic_dimension to fix some dimensions
8+
49
0.3.0
510
+++++
611

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,18 @@ def test_couple_input_ds_change_dynamic_dimensions(self):
691691
self.assertEqual((3, 5, 8), new_input["A"].shape)
692692
self.assertEqual((3, 10), new_input["B"].shape)
693693

694+
def test_couple_input_ds_change_dynamic_dimensions_fixed(self):
695+
T257 = torch.arange(2 * 5 * 7).reshape((2, 5, 7))
696+
T29 = torch.arange(2 * 9).reshape((2, 9))
697+
inst = CoupleInputsDynamicShapes(
698+
(),
699+
{"A": T257, "B": T29},
700+
{"A": {0: "batch", 2: "last"}, "B": {0: "batch", 1: "seq"}},
701+
)
702+
new_input = inst.change_dynamic_dimensions({"seq": 50, "batch": 1})
703+
self.assertEqual((1, 5, 8), new_input["A"].shape)
704+
self.assertEqual((1, 50), new_input["B"].shape)
705+
694706

695707
if __name__ == "__main__":
696708
unittest.main(verbosity=2)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
243243
return cls._generic_walker_step(processor, flat, ds)
244244

245245
class ChangeDimensionProcessor:
246-
def __init__(self):
247-
self.mapping = {}
246+
def __init__(self, desired_values):
247+
self.mapping = desired_values or {}
248248

249249
def _build_new_shape(
250250
self, shape: Tuple[int, ...], ds: Dict[int, Any]
@@ -310,7 +310,7 @@ def __call__(self, inputs, ds):
310310
new_shape = self._build_new_shape(inputs.shape, ds)
311311
return self._build_new_tensor(inputs, new_shape)
312312

313-
def change_dynamic_dimensions(self):
313+
def change_dynamic_dimensions(self, desired_values: Optional[Dict[str, int]] = None):
314314
"""
315315
A model exported with dynamic shapes is not necessarily dynamic
316316
just because the user specified dynamic shapes. The algorithm
@@ -320,6 +320,9 @@ def change_dynamic_dimensions(self):
320320
for the dimension than the first ones, assuming they were used to export
321321
the model.
322322
323+
:param desired_values: to fixed named dimension to have the desired value
324+
:return: new inputs
325+
323326
Example:
324327
325328
.. runpython::
@@ -340,7 +343,7 @@ def change_dynamic_dimensions(self):
340343
print("before:", string_type(kwargs, with_shape=True))
341344
print("-after:", string_type(new_kwargs, with_shape=True))
342345
"""
343-
return self._generic_walker(self.ChangeDimensionProcessor())
346+
return self._generic_walker(self.ChangeDimensionProcessor(desired_values))
344347

345348

346349
class ModelInputs:

0 commit comments

Comments
 (0)