Skip to content

Commit 40cadf5

Browse files
committed
fix issues
1 parent 15ab96e commit 40cadf5

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

_doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def linkcode_resolve(domain, info):
114114
nitpicky = True
115115
# See also scikit-learn/scikit-learn#26761
116116
nitpick_ignore = [
117+
("py:class", "_DimHint"),
118+
("py:class", "KeyPath"),
117119
("py:class", "ast.Node"),
118120
("py:class", "dtype"),
119121
("py:class", "False"),

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,14 @@ def forward(self, x, ind1, ind2):
288288
name="expected shape should be broadcastable to (< 2.9)",
289289
dynamic_shapes=dynamic_shapes,
290290
):
291-
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
292-
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
293-
got = ep.module()(*inputs)
294-
self.assertEqualArray(expected, got)
291+
try:
292+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
293+
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
294+
except RuntimeError as e:
295+
self.assertIn(
296+
"Expected input at *args[2].shape[0] to be equal to 1, but got 1024",
297+
str(e),
298+
)
295299

296300
with self.subTest(name="patch for 0/1", dynamic_shapes=dynamic_shapes):
297301
with torch_export_patches():

0 commit comments

Comments
 (0)