Skip to content

Commit ab0f3de

Browse files
committed
add change dynamic
1 parent 735b373 commit ab0f3de

File tree

2 files changed

+199
-1
lines changed

2 files changed

+199
-1
lines changed

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,50 @@ def test_couple_input_ds_args_kwargs_1(self):
605605
).invalid_paths(),
606606
)
607607

608+
def test_couple_input_ds_replace_string(self):
609+
T3x1 = torch.rand((3, 1))
610+
T3x4 = torch.rand((3, 4))
611+
T5x1 = torch.rand((5, 1))
612+
ds_batch = {0: "batch"}
613+
ds_batch_seq = {0: "batch", 1: "seq"}
614+
args = (T5x1,)
615+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
616+
Cls = CoupleInputsDynamicShapes
617+
self.assertEqual(
618+
{"X": {0: "DYN"}, "A": {0: "DYN"}, "B": ({0: "DYN"}, {0: "DYN"})},
619+
Cls(
620+
args,
621+
kwargs,
622+
{"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)},
623+
args_names=["X"],
624+
).replace_string_by(value="DYN"),
625+
)
626+
self.assertEqual(
627+
{
628+
"A": {0: "DYN"},
629+
"B": ({0: "DYN"}, {0: "DYN", 1: "DYN"}),
630+
"X": {0: "DYN", 1: "DYN"},
631+
},
632+
Cls(
633+
args,
634+
kwargs,
635+
{"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)},
636+
args_names=["X"],
637+
).replace_string_by(value="DYN"),
638+
)
639+
640+
def test_couple_input_ds_change_dynamic_dimensions(self):
641+
T257 = torch.arange(2 * 5 * 7).reshape((2, 5, 7))
642+
T29 = torch.arange(2 * 9).reshape((2, 9))
643+
inst = CoupleInputsDynamicShapes(
644+
(),
645+
{"A": T257, "B": T29},
646+
{"A": {0: "batch", 2: "last"}, "B": {0: "batch", 1: "seq"}},
647+
)
648+
new_input = inst.change_dynamic_dimensions()
649+
self.assertEqual((3, 5, 8), new_input["A"].shape)
650+
self.assertEqual((3, 10), new_input["B"].shape)
651+
608652

609653
if __name__ == "__main__":
610654
unittest.main(verbosity=2)

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 155 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,15 +488,73 @@ def __str__(self) -> str:
488488
]
489489
)
490490

491+
def replace_string_by(self, value: Any = None):
492+
"""
493+
Replaces string by the value ``torch.export.Dim.DYNAMIC``
494+
(default) or any other value specified by value.
495+
496+
Example:
497+
498+
.. runpython::
499+
:showcode:
500+
501+
import torch
502+
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
503+
504+
T3x1 = torch.rand((3, 1))
505+
T3x4 = torch.rand((3, 4))
506+
ds_batch = {0: "batch"}
507+
ds_batch_seq = {0: "batch", 1: "seq"}
508+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
509+
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
510+
print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by())
511+
"""
512+
return self._generic_walker(
513+
lambda inputs, ds, value=value: self._replace_string_dim_tensor(
514+
inputs, ds, value=value
515+
)
516+
)
517+
518+
@classmethod
519+
def _replace_string_dim_tensor(cls, inputs, ds, value=None):
520+
assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
521+
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
522+
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
523+
f"a dictionary is expected to specify a dimension dimension"
524+
)
525+
if value is None:
526+
value = torch.export.Dim.DYNAMIC
527+
new_ds = ds.copy()
528+
for i, v in ds.items():
529+
if isinstance(v, str):
530+
new_ds[i] = value
531+
return new_ds
532+
491533
def invalid_paths(self):
492534
"""
493-
Tells the inputs are valid based on the dynamic shapes definition.
535+
Tells if the inputs are valid based on the dynamic shapes definition.
494536
The method assumes that all custom classes can be serialized.
495537
If some patches were applied to export, they should enabled while
496538
calling this method if the inputs contains such classes.
497539
498540
The function checks that a dynamic dimension does not receive a value
499541
of 0 or 1. It returns a list of invalid path.
542+
543+
Example:
544+
545+
.. runpython::
546+
:showcode:
547+
548+
import torch
549+
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
550+
551+
T3x1 = torch.rand((3, 1))
552+
T3x4 = torch.rand((3, 4))
553+
ds_batch = {0: "batch"}
554+
ds_batch_seq = {0: "batch", 1: "seq"}
555+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
556+
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
557+
print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_paths())
500558
"""
501559
return self._generic_walker(self._valid_shapes_tensor)
502560

@@ -610,3 +668,99 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
610668
)
611669
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
612670
return cls._generic_walker_step(processor, flat, ds)
671+
672+
class ChangeDimensionProcessor:
673+
def __init__(self):
674+
self.mapping = {}
675+
676+
def _build_new_shape(
677+
self, shape: Tuple[int, ...], ds: Dict[int, Any]
678+
) -> Tuple[int, ...]:
679+
new_shape = list(shape)
680+
for i in range(len(shape)):
681+
if i in ds:
682+
if isinstance(ds[i], str):
683+
d = ds[i]
684+
elif isinstance(
685+
ds[i],
686+
(
687+
torch.export.dynamic_shapes._DerivedDim,
688+
torch.export.dynamic_shapes._Dim,
689+
),
690+
):
691+
d = str(ds[i])
692+
elif not isinstance(ds[i], int):
693+
raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
694+
if d in self.mapping:
695+
new_dim = self.mapping[d]
696+
else:
697+
new_dim = shape[i] + 1
698+
self.mapping[d] = new_dim
699+
new_shape[i] = new_dim
700+
return tuple(new_shape)
701+
702+
def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
703+
rank = len(tensor.shape)
704+
for i in range(len(tensor.shape)):
705+
d0 = tensor.shape[i]
706+
d1 = new_shape[i]
707+
if d0 == d1:
708+
continue
709+
alt_shape = list(tensor.shape)
710+
alt_shape[i] = d1
711+
new_tensor = torch.zeros(
712+
tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
713+
)
714+
mind = min(d0, d1)
715+
indices = [slice(None) for _ in range(rank)]
716+
indices[i] = slice(0, mind)
717+
ind = tuple(indices)
718+
new_tensor[ind] = tensor[ind]
719+
if d1 > mind:
720+
for k in range(d1 - mind):
721+
indices0 = [slice(None) for _ in range(rank)]
722+
indices1 = [slice(None) for _ in range(rank)]
723+
indices1[i] = mind + k
724+
indices0[i] = k % mind
725+
new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
726+
tensor = new_tensor
727+
return tensor
728+
729+
def __call__(self, inputs, ds):
730+
assert isinstance(
731+
inputs, torch.Tensor
732+
), f"unexpected type for inputs {type(inputs)}"
733+
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
734+
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
735+
f"a dictionary is expected to specify a dimension dimension"
736+
)
737+
new_shape = self._build_new_shape(inputs.shape, ds)
738+
return self._build_new_tensor(inputs, new_shape)
739+
740+
def change_dynamic_dimensions(self):
741+
"""
742+
A model exported with dynamic shapes is not necessarily dynamic
743+
just because the user specified dynamic shapes. The algorithm
744+
may discover that a dimension cannot be dynamic and then continues
745+
the export making the assumption it is static. That may lead a wrong
746+
model. This function produces a new set of inputs with different values
747+
for the dimension than the first ones, assuming they were used to export
748+
the model.
749+
750+
Example:
751+
752+
.. runpython::
753+
:showcode:
754+
755+
import torch
756+
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
757+
758+
T3x1 = torch.rand((3, 1))
759+
T3x4 = torch.rand((3, 4))
760+
ds_batch = {0: "batch"}
761+
ds_batch_seq = {0: "batch", 1: "seq"}
762+
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
763+
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
764+
print(CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimension())
765+
"""
766+
return self._generic_walker(self.ChangeDimensionProcessor())

0 commit comments

Comments
 (0)