Skip to content

Commit 78eded8

Browse files
titaiwangmspytorchmergebot
authored andcommitted
[ONNX] Use torch.export.Dim.AUTO in dynamo_export (pytorch#144356)
Align to the changes in pytorch#143158 Pull Request resolved: pytorch#144356 Approved by: https://github.com/justinchuby
1 parent 90e81a1 commit 78eded8

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

torch/onnx/__init__.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -480,29 +480,20 @@ def forward(self, x, bias=None):
480480
)
481481

482482
if export_options is not None and export_options.dynamic_shapes:
483-
# Make all shapes dynamic
484-
def _to_dynamic_shapes_mapper():
485-
arg_order = 0
486-
487-
def _to_dynamic_shape(x):
488-
nonlocal arg_order
489-
if isinstance(x, torch.Tensor):
490-
rank = len(x.shape)
491-
dynamic_shape = {}
492-
for i in range(rank):
493-
dynamic_shape[i] = torch.export.Dim(
494-
f"arg_{arg_order}_dim_{i}"
495-
)
496-
arg_order += 1
497-
return dynamic_shape
498-
else:
499-
return None
500-
501-
return _to_dynamic_shape
483+
# Make all shapes dynamic if it's possible
484+
def _to_dynamic_shape(x):
485+
if isinstance(x, torch.Tensor):
486+
rank = len(x.shape)
487+
dynamic_shape = {}
488+
for i in range(rank):
489+
dynamic_shape[i] = torch.export.Dim.AUTO # type: ignore[attr-defined]
490+
return dynamic_shape
491+
else:
492+
return None
502493

503494
# model_args could be nested
504495
dynamic_shapes = _pytree.tree_map(
505-
_to_dynamic_shapes_mapper(),
496+
_to_dynamic_shape,
506497
model_args,
507498
)
508499
else:

0 commit comments

Comments
 (0)