Skip to content

Commit 8cad654

Browse files
committed
fix issues
1 parent fbcd0ab commit 8cad654

File tree

1 file changed

+40
-8
lines changed

1 file changed

+40
-8
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from scipy.spatial.distance import cdist
77
import torch
8-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
8+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_torch
99
from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite
1010
from onnx_diagnostic.torch_export_patches.patch_module import (
1111
transform_method,
@@ -498,7 +498,7 @@ def forward(self, x, y):
498498
def loop_body_0(x, y):
499499
x = x.reshape((-1, *x.shape))
500500
z = ((x - y) ** 2).sum(dim=-1)
501-
return [z]
501+
return (z,)
502502

503503
z = torch.ops.higher_order.scan(loop_body_0, [], [x], [y])
504504
return z[0]
@@ -510,29 +510,61 @@ def loop_body_0(x, y):
510510
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
511511
torch.export.export(RewrittenModel(), (x, y), dynamic_shapes=ds)
512512

513+
class RewrittenModelLoop(torch.nn.Module):
514+
def forward(self, z, iv, x, y):
515+
z = z.clone()
516+
i = iv.item()
517+
z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
518+
return (z, iv)
519+
520+
inputs = (
521+
torch.empty((x.shape[0], y.shape[0])),
522+
torch.tensor([2], dtype=torch.int64),
523+
x,
524+
y,
525+
)
526+
RewrittenModelLoop()(*inputs)
527+
try:
528+
from experimental_experiment.torch_interpreter.tracing import CustomTracer
529+
except ImportError:
530+
CustomTracer = None
531+
if CustomTracer:
532+
graph = CustomTracer().trace(RewrittenModelLoop())
533+
self.assertNotEmpty(graph)
534+
535+
# does not wiork
536+
# dsl = ({0: DYN, 1: DYN}, {}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN})
537+
# torch.export.export(RewrittenModelLoop(), inputs, dynamic_shapes=dsl)
538+
513539
class RewrittenModel2(torch.nn.Module):
514540
def forward(self, x, y):
515541
def loop_body_1(z, iv, x, y):
516542
z = z.clone()
517543
i = iv.item()
518544
z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
519-
return [z, iv]
545+
return (z, iv)
520546

521547
z = torch.empty((x.shape[0], y.shape[0]))
522548
r = torch.ops.higher_order.scan(
523-
loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y]
549+
loop_body_1,
550+
[z],
551+
[torch.arange(x.shape[0], dtype=torch.int64).reshape((-1, 1))],
552+
[x, y],
524553
)
525554
return r[0]
526555

527-
rewritten_expected2 = RewrittenModel2()(x, y)
528-
self.assertEqualArray(expected, rewritten_expected2)
529-
torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
530-
531556
rewritten = transform_method(Model.forward, verbose=self.verbose)
532557
self.assertIn("torch.ops.higher_order.scan(", rewritten.code)
533558
Model.forward = rewritten.func
534559
self.assertEqualAny(expected, Model()(x, y))
535560

561+
rewritten_expected2 = RewrittenModel2()(x, y)
562+
self.assertEqualArray(expected, rewritten_expected2)
563+
564+
if not has_torch("2.9"):
565+
raise unittest.SkipTest("skipped export, torch must be >= 2.9")
566+
567+
torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
536568
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds, strict=False)
537569
self.assertEqualAny(expected, ep.module()(x, y))
538570

0 commit comments

Comments
 (0)