Skip to content

Commit fbcd0ab

Browse files
committed
fix rewriting
1 parent 8d5c8cf commit fbcd0ab

File tree

1 file changed

+2
-14
lines changed

1 file changed

+2
-14
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -526,28 +526,16 @@ def loop_body_1(z, iv, x, y):
526526

527527
rewritten_expected2 = RewrittenModel2()(x, y)
528528
self.assertEqualArray(expected, rewritten_expected2)
529-
torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds)
529+
torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
530530

531531
rewritten = transform_method(Model.forward, verbose=self.verbose)
532-
print("-------")
533-
print(rewritten.code)
534-
print("-------")
535-
536532
self.assertIn("torch.ops.higher_order.scan(", rewritten.code)
537533
Model.forward = rewritten.func
538534
self.assertEqualAny(expected, Model()(x, y))
539535

540-
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
536+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds, strict=False)
541537
self.assertEqualAny(expected, ep.module()(x, y))
542538

543-
"""
544-
z = torch.empty((x.shape[0], y.shape[0]))
545-
def loop_body_0(i, x_row, y, z):
546-
z[i, :] = ((x_row - y) ** 2).sum(dim=-1)
547-
return z
548-
z = torch.ops.higher_order.scan(loop_body_0, [x], [y], [])
549-
"""
550-
551539

552540
if __name__ == "__main__":
553541
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)