Skip to content

Commit 103bf64

Browse files
pianpwkpytorchmergebot
authored andcommitted
[export] refactor _Dim into Dim (pytorch#149891)
Summary: forward fix T218515233 Test Plan: test_export Differential Revision: D71769231 Pull Request resolved: pytorch#149891 Approved by: https://github.com/jingsh, https://github.com/angelayi
1 parent f649ee7 commit 103bf64

File tree

7 files changed

+120
-111
lines changed

7 files changed

+120
-111
lines changed

docs/source/export.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ API Reference
790790
.. autofunction:: save
791791
.. autofunction:: load
792792
.. autofunction:: register_dataclass
793-
.. autofunction:: torch.export.dynamic_shapes.Dim
793+
.. autoclass:: torch.export.dynamic_shapes.Dim
794794
.. autofunction:: torch.export.exported_program.default_decompositions
795795
.. autofunction:: dims
796796
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection

test/export/test_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3852,7 +3852,7 @@ def forward(self, x):
38523852

38533853
dynamic_shapes = (
38543854
{"k": {"k": dim}},
3855-
) # ValueError: Node type mismatch; expected <class 'list'>, but got .*_Dim.*.
3855+
) # ValueError: Node type mismatch; expected <class 'list'>, but got .*Dim.*.
38563856
with self.assertRaisesRegex(
38573857
torch._dynamo.exc.UserError,
38583858
re.escape(
@@ -12362,7 +12362,7 @@ def test_dynamic_shapes_serdes_user_errors(self):
1236212362

1236312363
self.assertExpectedInline(
1236412364
_load_dynamic_shapes(spec, from_dict=False),
12365-
"""[[<class 'torch._export.serde.dynamic_shapes.dx'>]]""",
12365+
"""[[Dim('dx', min=4, max=16)]]""",
1236612366
)
1236712367

1236812368
# check incorrect info in dims

test/onnx/exporter/test_small_models_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def forward(
450450
)
451451

452452
dynamic_shapes = (
453-
{0: torch.export.Dim("dim_x", min=3)}, # _Dim
453+
{0: torch.export.Dim("dim_x", min=3)}, # Dim
454454
[("custom_name_axis_ys_0",), (torch.export.Dim.AUTO,)], # custom name
455455
{
456456
"a": {0: torch.export.Dim.AUTO},

torch/_export/serde/dynamic_shapes.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from torch.export.dynamic_shapes import (
77
_check_dynamic_shapes,
88
_DerivedDim,
9-
_Dim,
109
_DimHint,
1110
_tree_map_with_path,
1211
Dim,
@@ -19,7 +18,7 @@
1918
@dataclasses.dataclass
2019
class RootDim:
2120
"""
22-
This represents a _Dim object.
21+
This represents a Dim object.
2322
"""
2423

2524
min: int
@@ -150,7 +149,7 @@ def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def]
150149
return out
151150

152151
def _track_dim_from_dims(
153-
val: Union[None, int, _DimHint, _Dim]
152+
val: Union[None, int, _DimHint, Dim]
154153
) -> Union[None, int, str]:
155154
"""
156155
Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
@@ -160,7 +159,7 @@ def _track_dim_from_dims(
160159
if isinstance(val, _DimHint): # store enum as string
161160
return val.__class__.__name__ + "." + val.type.name
162161

163-
assert isinstance(val, _Dim)
162+
assert isinstance(val, Dim)
164163

165164
# track root dim
166165
root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined]
@@ -297,7 +296,7 @@ def _load_dynamic_shapes(
297296

298297
def deserialize_shape(
299298
val: Union[None, int, str]
300-
) -> Union[None, int, _Dim, _DimHint]:
299+
) -> Union[None, int, Dim, _DimHint]:
301300
if val is None or isinstance(val, int):
302301
return val
303302
elif val == "_DimHint.AUTO":

0 commit comments

Comments
 (0)