Skip to content

Commit ff91678

Browse files
committed
fix example
1 parent 94d69b8 commit ff91678

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

_doc/recipes/plot_dynamic_shapes_max.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,13 @@ def register(fct, fct_shape, namespace, fname):
126126
# %%
127127
# Now everything is registered. Let's rewrite the model.
128128

129+
129130
class RewrittenModel(torch.nn.Module):
130131
def forward(self, x, y, fact):
131132
z = torch.ops.mylib.copy_max_dimensions(x, y)
132133
return z * fact
133134

135+
134136
# %%
135137
# And check it works.
136138

_doc/recipes/plot_dynamic_shapes_nonzero.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Half certain nonzero
33
====================
44
5-
:func:`torch.nonzero` returns the indices or the first zero found
5+
:func:`torch.nonzero` returns the indices of the first zero found
66
in a tensor. The output shape is unknown in the generic case
77
but... If you have a 2D tensor with at least a nonzero value
88
in every row, you can guess the dimension. But :func:`torch.export.export`
@@ -49,7 +49,7 @@ def forward(self, x):
4949
# ++++++
5050

5151
DYN = torch.export.Dim.DYNAMIC
52-
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
52+
ep = torch.export.export(model, (x,), dynamic_shapes=(({0: DYN, 1: DYN}),))
5353
print(ep)
5454

5555

_doc/recipes/plot_dynamic_shapes_python_int.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def forward(self, x):
3636
# ++++++
3737

3838
DYN = torch.export.Dim.DYNAMIC
39-
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
39+
ep = torch.export.export(model, (x,), dynamic_shapes=(({0: DYN, 1: DYN}),))
4040
print(ep)
4141

4242
# %%
@@ -65,7 +65,7 @@ def forward(self, x):
6565
# Export
6666
# ++++++
6767

68-
ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=((DYN, DYN),))
68+
ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
6969
print(ep)
7070

7171

@@ -79,7 +79,7 @@ def forward(self, x):
7979

8080

8181
with bypass_export_some_errors(stop_if_static=True):
82-
ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
82+
ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
8383
print(ep)
8484

8585
# %%

0 commit comments

Comments
 (0)