Skip to content

Commit 9607f38

Browse files
committed
first draft
1 parent d19f48a commit 9607f38

File tree

2 files changed

+270
-17
lines changed

2 files changed

+270
-17
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
import inspect
33
import textwrap
44
import unittest
5+
import numpy as np
6+
from scipy.spatial.distance import cdist
57
import torch
68
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
79
from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite
810
from onnx_diagnostic.torch_export_patches.patch_module import (
911
transform_method,
1012
inplace_add_parent,
13+
ShapeFinder,
14+
RewriteControlFlow,
1115
)
1216

1317

@@ -42,7 +46,7 @@ def forward(self, x, y):
4246
hasattr(node, "parent") for node in ast.walk(tree)
4347
), f"Missing parent in {ast.dump(tree, indent=2)}"
4448

45-
def test_rewrite_forward_return1(self):
49+
def test_rewrite_test_in_forward_return1(self):
4650

4751
class Model(torch.nn.Module):
4852
def forward(self, x, y):
@@ -69,7 +73,7 @@ def forward(self, x, y):
6973
self.assertEqualAny(expected_, ep.module()(-x, y))
7074

7175
@hide_stdout()
72-
def test_rewrite_forward_return2(self):
76+
def test_rewrite_test_in_forward_return2(self):
7377

7478
class Model(torch.nn.Module):
7579
def forward(self, x, y):
@@ -95,7 +99,7 @@ def forward(self, x, y):
9599
self.assertEqualAny(expected, ep.module()(x, y))
96100
self.assertEqualAny(expected_, ep.module()(-x, y))
97101

98-
def test_rewrite_forward_assign1(self):
102+
def test_rewrite_test_in_forward_assign1(self):
99103

100104
class Model(torch.nn.Module):
101105
def forward(self, x, y):
@@ -122,7 +126,7 @@ def forward(self, x, y):
122126
self.assertEqualAny(expected, ep.module()(x, y))
123127
self.assertEqualArray(expected_, ep.module()(-x, y))
124128

125-
def test_rewrite_forward_assign2(self):
129+
def test_rewrite_test_in_forward_assign2(self):
126130

127131
class Model(torch.nn.Module):
128132
def forward(self, x, y):
@@ -149,7 +153,7 @@ def forward(self, x, y):
149153
self.assertEqualAny(expected, ep.module()(x, y))
150154
self.assertEqualAny(expected_, ep.module()(-x, y))
151155

152-
def test_rewrite_forward_assign_noelse(self):
156+
def test_rewrite_test_in_forward_assign_noelse(self):
153157

154158
class Model(torch.nn.Module):
155159
def forward(self, x, y):
@@ -174,7 +178,7 @@ def forward(self, x, y):
174178
self.assertEqualAny(expected, ep.module()(x, y))
175179
self.assertEqualAny(expected_, ep.module()(-x, y))
176180

177-
def test_rewrite_forward_return_noelse(self):
181+
def test_rewrite_test_in_forward_return_noelse(self):
178182

179183
class Model(torch.nn.Module):
180184
def forward(self, x, y):
@@ -186,7 +190,7 @@ def forward(self, x, y):
186190
lambda: transform_method(Model.forward, verbose=self.verbose), NotImplementedError
187191
)
188192

189-
def test_rewrite_forward_assign2_in_2(self):
193+
def test_rewrite_test_in_forward_assign2_in_2(self):
190194

191195
class Model(torch.nn.Module):
192196
def forward(self, x, y):
@@ -215,7 +219,7 @@ def forward(self, x, y):
215219
self.assertEqualAny(expected, ep.module()(x, y))
216220
self.assertEqualAny(expected_, ep.module()(-x, y))
217221

218-
def test_rewrite_forward_assign2_in_3(self):
222+
def test_rewrite_test_in_forward_assign2_in_3(self):
219223

220224
class Model(torch.nn.Module):
221225
def forward(self, x, y):
@@ -292,7 +296,7 @@ def torch_cond_else_2(y):
292296
x, y = torch.rand((3, 4)), torch.rand((3, 4))
293297
Model()(x, y)
294298

295-
def test_rewrite_forward_assign_nested(self):
299+
def test_rewrite_test_in_forward_assign_nested(self):
296300

297301
class Model(torch.nn.Module):
298302
def forward(self, x, y):
@@ -341,7 +345,7 @@ def forward(self, x, y):
341345
self.assertEqualAny(expected_0, ep.module()(x, -y))
342346
self.assertEqualAny(expected_1, ep.module()(-x, -y))
343347

344-
def test_rewrite_forward_none(self):
348+
def test_rewrite_test_in_forward_none(self):
345349

346350
class Model(torch.nn.Module):
347351
def forward(self, x, y):
@@ -365,7 +369,7 @@ def forward(self, x, y):
365369
self.assertEqualAny(expected, ep.module()(x, y))
366370
self.assertEqualAny(expected_, ep.module()(-x, y))
367371

368-
def test_rewrite_PLBartEncoderLayer(self):
372+
def test_rewrite_test_in_PLBartEncoderLayer(self):
369373
from transformers.models.plbart.modeling_plbart import PLBartEncoderLayer
370374

371375
rewritten = transform_method(PLBartEncoderLayer.forward, verbose=self.verbose)
@@ -443,6 +447,91 @@ def forward(self, x, y):
443447
got = ep.module()(x, y)
444448
self.assertEqualArray(expected, got)
445449

450+
def test_shape_finder(self):
451+
expr = "range(x.shape[0])"
452+
node = ast.parse(expr)
453+
sh = ShapeFinder()
454+
sh.visit(node)
455+
self.assertEqual({"x"}, sh.found_shape)
456+
457+
def test__find_loop_vars(self):
458+
code = textwrap.dedent(
459+
"""
460+
for i in range(x.shape[0]):
461+
z[i, :] = ((x[i : i + 1, :] - y) ** 2).sum(dim=-1)
462+
"""
463+
)
464+
node = ast.parse(code)
465+
tr = RewriteControlFlow()
466+
vars = tr._find_loop_vars(node.body[0])
467+
self.assertEqual(
468+
{"loop": ["i"], "scan": ["x"], "input": ["y"], "output": ["z"], "init": []}, vars
469+
)
470+
471+
def test_rewrite_loop(self):
472+
473+
class Model(torch.nn.Module):
474+
def forward(self, x, y):
475+
z = torch.empty((x.shape[0], y.shape[0]))
476+
for i in range(x.shape[0]):
477+
z[i, :] = ((x[i : i + 1, :] - y) ** 2).sum(dim=-1)
478+
return z
479+
480+
class RewrittenModel(torch.nn.Module):
481+
def forward(self, x, y):
482+
def loop_body_0(x, y):
483+
x = x.reshape((-1, *x.shape))
484+
z = ((x - y) ** 2).sum(dim=-1)
485+
return [z]
486+
487+
z = torch.ops.higher_order.scan(loop_body_0, [], [x], [y])
488+
return z[0]
489+
490+
class RewrittenModel2(torch.nn.Module):
491+
def forward(self, x, y):
492+
def loop_body_1(z, i, x, y):
493+
z = z.clone()
494+
z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
495+
return [z, i]
496+
497+
z = torch.empty((x.shape[0], y.shape[0]))
498+
r = torch.ops.higher_order.scan(
499+
loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y]
500+
)
501+
return r[0]
502+
503+
x, y = torch.rand((3, 4)), torch.rand((5, 4))
504+
expected = Model()(x, y)
505+
self.assertEqualArray(
506+
expected.numpy(),
507+
cdist(x.numpy(), y.numpy(), metric="sqeuclidean").astype(np.float32),
508+
atol=1e-5,
509+
)
510+
rewritten_expected = RewrittenModel()(x, y)
511+
self.assertEqualArray(expected, rewritten_expected)
512+
rewritten_expected2 = RewrittenModel2()(x, y)
513+
self.assertEqualArray(expected, rewritten_expected2)
514+
515+
rewritten = transform_method(Model.forward, verbose=self.verbose)
516+
print(rewritten.code)
517+
518+
self.assertIn("torch.ops.higher_order.scan(", rewritten.code)
519+
Model.forward = rewritten.func
520+
self.assertEqualAny(expected, Model()(x, y))
521+
522+
DYN = torch.export.Dim.DYNAMIC
523+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
524+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
525+
self.assertEqualAny(expected, ep.module()(x, y))
526+
527+
"""
528+
z = torch.empty((x.shape[0], y.shape[0]))
529+
def loop_body_0(i, x_row, y, z):
530+
z[i, :] = ((x_row - y) ** 2).sum(dim=-1)
531+
return z
532+
z = torch.ops.higher_order.scan(loop_body_0, [x], [y], [])
533+
"""
534+
446535

447536
if __name__ == "__main__":
448537
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)