Skip to content

Commit c3da823

Browse files
authored
First draft to automate the rewriting of loops (#101)
* first draft * scan * fix rewriting * fix issues * extend test * fix * fix
1 parent a0ef200 commit c3da823

File tree

2 files changed

+431
-20
lines changed

2 files changed

+431
-20
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 145 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
import inspect
33
import textwrap
44
import unittest
5+
import numpy as np
6+
from scipy.spatial.distance import cdist
57
import torch
6-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
8+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_torch, requires_torch
79
from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite
810
from onnx_diagnostic.torch_export_patches.patch_module import (
911
transform_method,
1012
inplace_add_parent,
13+
ShapeFinder,
14+
RewriteControlFlow,
1115
)
1216

1317

@@ -42,7 +46,7 @@ def forward(self, x, y):
4246
hasattr(node, "parent") for node in ast.walk(tree)
4347
), f"Missing parent in {ast.dump(tree, indent=2)}"
4448

45-
def test_rewrite_forward_return1(self):
49+
def test_rewrite_test_in_forward_return1(self):
4650

4751
class Model(torch.nn.Module):
4852
def forward(self, x, y):
@@ -69,7 +73,7 @@ def forward(self, x, y):
6973
self.assertEqualAny(expected_, ep.module()(-x, y))
7074

7175
@hide_stdout()
72-
def test_rewrite_forward_return2(self):
76+
def test_rewrite_test_in_forward_return2(self):
7377

7478
class Model(torch.nn.Module):
7579
def forward(self, x, y):
@@ -95,7 +99,7 @@ def forward(self, x, y):
9599
self.assertEqualAny(expected, ep.module()(x, y))
96100
self.assertEqualAny(expected_, ep.module()(-x, y))
97101

98-
def test_rewrite_forward_assign1(self):
102+
def test_rewrite_test_in_forward_assign1(self):
99103

100104
class Model(torch.nn.Module):
101105
def forward(self, x, y):
@@ -122,7 +126,7 @@ def forward(self, x, y):
122126
self.assertEqualAny(expected, ep.module()(x, y))
123127
self.assertEqualArray(expected_, ep.module()(-x, y))
124128

125-
def test_rewrite_forward_assign2(self):
129+
def test_rewrite_test_in_forward_assign2(self):
126130

127131
class Model(torch.nn.Module):
128132
def forward(self, x, y):
@@ -149,7 +153,7 @@ def forward(self, x, y):
149153
self.assertEqualAny(expected, ep.module()(x, y))
150154
self.assertEqualAny(expected_, ep.module()(-x, y))
151155

152-
def test_rewrite_forward_assign_noelse(self):
156+
def test_rewrite_test_in_forward_assign_noelse(self):
153157

154158
class Model(torch.nn.Module):
155159
def forward(self, x, y):
@@ -174,7 +178,7 @@ def forward(self, x, y):
174178
self.assertEqualAny(expected, ep.module()(x, y))
175179
self.assertEqualAny(expected_, ep.module()(-x, y))
176180

177-
def test_rewrite_forward_return_noelse(self):
181+
def test_rewrite_test_in_forward_return_noelse(self):
178182

179183
class Model(torch.nn.Module):
180184
def forward(self, x, y):
@@ -186,7 +190,7 @@ def forward(self, x, y):
186190
lambda: transform_method(Model.forward, verbose=self.verbose), NotImplementedError
187191
)
188192

189-
def test_rewrite_forward_assign2_in_2(self):
193+
def test_rewrite_test_in_forward_assign2_in_2(self):
190194

191195
class Model(torch.nn.Module):
192196
def forward(self, x, y):
@@ -215,7 +219,7 @@ def forward(self, x, y):
215219
self.assertEqualAny(expected, ep.module()(x, y))
216220
self.assertEqualAny(expected_, ep.module()(-x, y))
217221

218-
def test_rewrite_forward_assign2_in_3(self):
222+
def test_rewrite_test_in_forward_assign2_in_3(self):
219223

220224
class Model(torch.nn.Module):
221225
def forward(self, x, y):
@@ -292,7 +296,7 @@ def torch_cond_else_2(y):
292296
x, y = torch.rand((3, 4)), torch.rand((3, 4))
293297
Model()(x, y)
294298

295-
def test_rewrite_forward_assign_nested(self):
299+
def test_rewrite_test_in_forward_assign_nested(self):
296300

297301
class Model(torch.nn.Module):
298302
def forward(self, x, y):
@@ -341,7 +345,7 @@ def forward(self, x, y):
341345
self.assertEqualAny(expected_0, ep.module()(x, -y))
342346
self.assertEqualAny(expected_1, ep.module()(-x, -y))
343347

344-
def test_rewrite_forward_none(self):
348+
def test_rewrite_test_in_forward_none(self):
345349

346350
class Model(torch.nn.Module):
347351
def forward(self, x, y):
@@ -365,7 +369,7 @@ def forward(self, x, y):
365369
self.assertEqualAny(expected, ep.module()(x, y))
366370
self.assertEqualAny(expected_, ep.module()(-x, y))
367371

368-
def test_rewrite_PLBartEncoderLayer(self):
372+
def test_rewrite_test_in_PLBartEncoderLayer(self):
369373
from transformers.models.plbart.modeling_plbart import PLBartEncoderLayer
370374

371375
rewritten = transform_method(PLBartEncoderLayer.forward, verbose=self.verbose)
@@ -443,6 +447,135 @@ def forward(self, x, y):
443447
got = ep.module()(x, y)
444448
self.assertEqualArray(expected, got)
445449

450+
def test_shape_finder(self):
451+
expr = "range(x.shape[0])"
452+
node = ast.parse(expr)
453+
sh = ShapeFinder()
454+
sh.visit(node)
455+
self.assertEqual({"x"}, sh.found_shape)
456+
457+
def test__find_loop_vars(self):
458+
code = textwrap.dedent(
459+
"""
460+
for i in range(x.shape[0]):
461+
z[i, :] = ((x[i : i + 1, :] - y) ** 2).sum(dim=-1)
462+
"""
463+
)
464+
node = ast.parse(code)
465+
tr = RewriteControlFlow()
466+
vars = tr._find_loop_vars(node.body[0])
467+
self.assertEqual(
468+
{
469+
"init": ["z"],
470+
"input": ["y"],
471+
"loop": ["i"],
472+
"output": [],
473+
"scan": [],
474+
"scan_shape": ["x"],
475+
},
476+
vars,
477+
)
478+
479+
@requires_torch("2.8")
480+
def test_rewrite_loop(self):
481+
482+
class Model(torch.nn.Module):
483+
def forward(self, x, y):
484+
z = torch.empty((x.shape[0], y.shape[0]))
485+
for i in range(x.shape[0]):
486+
z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
487+
return z
488+
489+
x, y = torch.rand((3, 4)), torch.rand((5, 4))
490+
expected = Model()(x, y)
491+
self.assertEqualArray(
492+
expected.numpy(),
493+
cdist(x.numpy(), y.numpy(), metric="sqeuclidean").astype(np.float32),
494+
atol=1e-5,
495+
)
496+
497+
class RewrittenModel(torch.nn.Module):
498+
def forward(self, x, y):
499+
def loop_body_0(x, y):
500+
x = x.reshape((-1, *x.shape))
501+
z = ((x - y) ** 2).sum(dim=-1)
502+
return (z,)
503+
504+
z = torch.ops.higher_order.scan(loop_body_0, [], [x], [y])
505+
return z[0]
506+
507+
rewritten_expected = RewrittenModel()(x, y)
508+
self.assertEqualArray(expected, rewritten_expected)
509+
510+
DYN = torch.export.Dim.DYNAMIC
511+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
512+
torch.export.export(RewrittenModel(), (x, y), dynamic_shapes=ds)
513+
514+
class RewrittenModelLoop(torch.nn.Module):
515+
def forward(self, z, iv, x, y):
516+
z = z.clone()
517+
i = iv.item()
518+
z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
519+
return (z, iv)
520+
521+
inputs = (
522+
torch.empty((x.shape[0], y.shape[0])),
523+
torch.tensor([2], dtype=torch.int64),
524+
x,
525+
y,
526+
)
527+
RewrittenModelLoop()(*inputs)
528+
try:
529+
from experimental_experiment.torch_interpreter.tracing import CustomTracer
530+
except ImportError:
531+
CustomTracer = None
532+
if CustomTracer:
533+
graph = CustomTracer().trace(RewrittenModelLoop())
534+
self.assertNotEmpty(graph)
535+
536+
# does not wiork
537+
# dsl = ({0: DYN, 1: DYN}, {}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN})
538+
# torch.export.export(RewrittenModelLoop(), inputs, dynamic_shapes=dsl)
539+
540+
class RewrittenModel2(torch.nn.Module):
541+
def forward(self, x, y):
542+
def loop_body_1(z, iv, x, y):
543+
z = z.clone()
544+
i = iv.item()
545+
z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
546+
return (z, iv)
547+
548+
z = torch.empty((x.shape[0], y.shape[0]))
549+
r = torch.ops.higher_order.scan(
550+
loop_body_1,
551+
[z],
552+
[torch.arange(x.shape[0], dtype=torch.int64).reshape((-1, 1))],
553+
[x, y],
554+
)
555+
return r[0]
556+
557+
rewritten = transform_method(Model.forward, verbose=self.verbose)
558+
self.assertIn("torch.ops.higher_order.scan(", rewritten.code)
559+
Model.forward = rewritten.func
560+
self.assertEqualAny(expected, Model()(x, y))
561+
562+
rewritten_expected2 = RewrittenModel2()(x, y)
563+
self.assertEqualArray(expected, rewritten_expected2)
564+
565+
if not has_torch("2.9"):
566+
raise unittest.SkipTest("skipped export, torch must be >= 2.9")
567+
568+
torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
569+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds, strict=False)
570+
self.assertEqualAny(expected, ep.module()(x, y))
571+
572+
"""
573+
position_encodings = torch.cat(
574+
[weight[:, :required_pos_encodings_columns]
575+
for weight in broadcasted_weights], dim=-1
576+
)
577+
"""
578+
446579

447580
if __name__ == "__main__":
448581
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)