Skip to content

Commit 8d5c8cf

Browse files
committed
scan
1 parent 9607f38 commit 8d5c8cf

File tree

2 files changed

+170
-40
lines changed

2 files changed

+170
-40
lines changed

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,15 @@ def test__find_loop_vars(self):
465465
tr = RewriteControlFlow()
466466
vars = tr._find_loop_vars(node.body[0])
467467
self.assertEqual(
468-
{"loop": ["i"], "scan": ["x"], "input": ["y"], "output": ["z"], "init": []}, vars
468+
{
469+
"init": ["z"],
470+
"input": ["y"],
471+
"loop": ["i"],
472+
"output": [],
473+
"scan": [],
474+
"scan_shape": ["x"],
475+
},
476+
vars,
469477
)
470478

471479
def test_rewrite_loop(self):
@@ -474,9 +482,17 @@ class Model(torch.nn.Module):
474482
def forward(self, x, y):
475483
z = torch.empty((x.shape[0], y.shape[0]))
476484
for i in range(x.shape[0]):
477-
z[i, :] = ((x[i : i + 1, :] - y) ** 2).sum(dim=-1)
485+
z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
478486
return z
479487

488+
x, y = torch.rand((3, 4)), torch.rand((5, 4))
489+
expected = Model()(x, y)
490+
self.assertEqualArray(
491+
expected.numpy(),
492+
cdist(x.numpy(), y.numpy(), metric="sqeuclidean").astype(np.float32),
493+
atol=1e-5,
494+
)
495+
480496
class RewrittenModel(torch.nn.Module):
481497
def forward(self, x, y):
482498
def loop_body_0(x, y):
@@ -487,40 +503,40 @@ def loop_body_0(x, y):
487503
z = torch.ops.higher_order.scan(loop_body_0, [], [x], [y])
488504
return z[0]
489505

506+
rewritten_expected = RewrittenModel()(x, y)
507+
self.assertEqualArray(expected, rewritten_expected)
508+
509+
DYN = torch.export.Dim.DYNAMIC
510+
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
511+
torch.export.export(RewrittenModel(), (x, y), dynamic_shapes=ds)
512+
490513
class RewrittenModel2(torch.nn.Module):
491514
def forward(self, x, y):
492-
def loop_body_1(z, i, x, y):
515+
def loop_body_1(z, iv, x, y):
493516
z = z.clone()
517+
i = iv.item()
494518
z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
495-
return [z, i]
519+
return [z, iv]
496520

497521
z = torch.empty((x.shape[0], y.shape[0]))
498522
r = torch.ops.higher_order.scan(
499523
loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y]
500524
)
501525
return r[0]
502526

503-
x, y = torch.rand((3, 4)), torch.rand((5, 4))
504-
expected = Model()(x, y)
505-
self.assertEqualArray(
506-
expected.numpy(),
507-
cdist(x.numpy(), y.numpy(), metric="sqeuclidean").astype(np.float32),
508-
atol=1e-5,
509-
)
510-
rewritten_expected = RewrittenModel()(x, y)
511-
self.assertEqualArray(expected, rewritten_expected)
512527
rewritten_expected2 = RewrittenModel2()(x, y)
513528
self.assertEqualArray(expected, rewritten_expected2)
529+
torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds)
514530

515531
rewritten = transform_method(Model.forward, verbose=self.verbose)
532+
print("-------")
516533
print(rewritten.code)
534+
print("-------")
517535

518536
self.assertIn("torch.ops.higher_order.scan(", rewritten.code)
519537
Model.forward = rewritten.func
520538
self.assertEqualAny(expected, Model()(x, y))
521539

522-
DYN = torch.export.Dim.DYNAMIC
523-
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
524540
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
525541
self.assertEqualAny(expected, ep.module()(x, y))
526542

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 139 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,13 @@ def _check(
117117
"""
118118
if cls is not None:
119119
if not cond:
120-
raise cls(f"{msg}\n\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}")
120+
smsg = msg if isinstance(msg, str) else msg()
121+
raise cls(f"{smsg}\n\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}")
121122
return
122-
assert cond, f"{msg}\n\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
123+
assert cond, (
124+
f"{msg if isinstance(msg, str) else msg()}\n\n--\n"
125+
f"{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
126+
)
123127

124128
def visit_Name(self, node):
125129
node = self.generic_visit(node)
@@ -362,18 +366,30 @@ def _find_loop_vars(self, node):
362366
assert isinstance(node, ast.For), f"Unexpected type {type(node)} for node"
363367
finder = ShapeFinder()
364368
finder.visit(node.iter)
365-
scan_vars = finder.found_shape
369+
scan_shape_vars = finder.found_shape
370+
scan_vars = set()
366371

367372
finder = UsedVarsFinder()
368373
for stmt in node.body:
369374
finder.visit(stmt)
370375

376+
assigned_in_body = set()
377+
for stmt in node.body:
378+
if isinstance(stmt, ast.Assign):
379+
for tgt in stmt.targets:
380+
if isinstance(tgt, ast.Name) and isinstance(tgt.value.ctx, ast.Store):
381+
assigned_in_body |= {tgt.value.id}
382+
371383
extra_defined = set()
372384
for stmt in node.body:
373385
if isinstance(stmt, ast.Assign):
374386
for tgt in stmt.targets:
375387
if isinstance(tgt, ast.Subscript):
376-
if isinstance(tgt.value, ast.Name):
388+
# It means the target existed before.
389+
if (
390+
isinstance(tgt.value, ast.Name)
391+
and tgt.value.id not in assigned_in_body
392+
):
377393
extra_defined.add(tgt.value.id)
378394

379395
loop_vars = set()
@@ -382,11 +398,21 @@ def _find_loop_vars(self, node):
382398
elif isinstance(node.target, (ast.Tuple, ast.List)):
383399
loop_vars |= {elt.id for elt in node.target.elts if isinstance(elt, ast.Name)}
384400

385-
output_vars = finder.defined | extra_defined
386-
input_vars = finder.used - finder.defined - loop_vars - scan_vars - output_vars
401+
output_vars = finder.defined | assigned_in_body
402+
input_vars = (
403+
finder.used
404+
- finder.defined
405+
- loop_vars
406+
- scan_shape_vars
407+
- scan_vars
408+
- output_vars
409+
- assigned_in_body
410+
- extra_defined
411+
)
387412
return dict(
388-
init=[],
413+
init=sorted(extra_defined),
389414
loop=sorted(loop_vars),
415+
scan_shape=sorted(scan_shape_vars),
390416
scan=sorted(scan_vars),
391417
input=sorted(input_vars),
392418
output=sorted(output_vars),
@@ -397,19 +423,30 @@ def visit_For(self, node):
397423
self.generic_visit(node)
398424
# look for variables, loop, inputs and outputs of the body
399425
vars = self._find_loop_vars(node)
400-
init_vars, loop_vars, scan_vars, input_vars, output_vars = [
401-
vars[k] for k in ["init", "loop", "scan", "input", "output"]
426+
init_vars, loop_vars, scan_shape_vars, scan_vars, input_vars, output_vars = [
427+
vars[k] for k in ["init", "loop", "scan_shape", "scan", "input", "output"]
402428
]
403-
404-
# return, one value or a tuple of values
405-
return_stmt = ast.Return(
406-
value=(
407-
ast.Name(id=output_vars[0], ctx=ast.Load())
408-
if len(output_vars) == 1
409-
else ast.Tuple(
410-
elts=[ast.Name(id=v, ctx=ast.Load()) for v in output_vars], ctx=ast.Load()
411-
)
412-
)
429+
self._check(
430+
len(scan_shape_vars) == len(loop_vars),
431+
node,
432+
lambda: (
433+
f"Inconsistencies between loop_vars={loop_vars} "
434+
f"and scan_shape_vars={scan_shape_vars}"
435+
),
436+
)
437+
self._check(
438+
len(scan_shape_vars) in {0, 1},
439+
node,
440+
lambda: f"Inconsistencies with scan_shape_vars={scan_shape_vars}",
441+
)
442+
self._check(
443+
(len(scan_shape_vars) == 0 or len(scan_vars) == 0)
444+
and (scan_shape_vars or scan_vars),
445+
node,
446+
lambda: (
447+
f"Inconsistencies between scan_vars={scan_vars} "
448+
f"and scan_shape_vars={scan_shape_vars}"
449+
),
413450
)
414451

415452
# creates the function
@@ -419,12 +456,50 @@ def visit_For(self, node):
419456
name=func_name,
420457
args=ast.arguments(
421458
posonlyargs=[],
422-
args=[ast.arg(arg=v) for v in [*init_vars, *scan_vars, *input_vars]],
459+
args=[
460+
ast.arg(arg=v)
461+
for v in [
462+
*init_vars,
463+
*loop_vars,
464+
*scan_vars,
465+
*scan_shape_vars,
466+
*input_vars,
467+
]
468+
],
423469
kwonlyargs=[],
424470
kw_defaults=[],
425471
defaults=[],
426472
),
427-
body=[*node.body, return_stmt],
473+
body=[
474+
*[
475+
ast.Assign(
476+
targets=[ast.Name(id=i, ctx=ast.Load())],
477+
value=[
478+
ast.Call(
479+
func=ast.Attribute(
480+
value=ast.Name(id=i, ctx=ast.Load()),
481+
attr="clone",
482+
ctx=ast.Load(),
483+
),
484+
args=[],
485+
keywords=[],
486+
ctx=ast.Load(),
487+
)
488+
],
489+
)
490+
for i in init_vars
491+
],
492+
*node.body,
493+
ast.Return(
494+
value=ast.List(
495+
[
496+
ast.Name(id=v, ctx=ast.Load())
497+
for v in [*init_vars, *loop_vars, *output_vars]
498+
],
499+
ctx=ast.Load(),
500+
)
501+
),
502+
],
428503
decorator_list=[],
429504
ctx=ast.Store(),
430505
)
@@ -452,17 +527,56 @@ def visit_For(self, node):
452527
elts=[ast.Name(id=v, ctx=ast.Load()) for v in init_vars], ctx=ast.Store()
453528
),
454529
ast.List(
455-
elts=[ast.Name(id=v, ctx=ast.Load()) for v in scan_vars], ctx=ast.Store()
530+
elts=[
531+
*[
532+
ast.Call(
533+
ast.Attribute(
534+
value=ast.Name(id="torch", ctx=ast.Load()),
535+
attr="arange",
536+
ctx=ast.Load(),
537+
),
538+
args=[
539+
ast.Subscript(
540+
value=ast.Attribute(
541+
value=ast.Name(id=v, ctx=ast.Load()),
542+
attr="shape",
543+
ctx=ast.Load(),
544+
),
545+
slice=ast.Constant(value=0, ctx=ast.Load()),
546+
ctx=ast.Load(),
547+
),
548+
],
549+
keywords=[
550+
ast.keyword(
551+
arg="dtype",
552+
value=ast.Attribute(
553+
value=ast.Name(id="torch", ctx=ast.Load()),
554+
attr="int64",
555+
ctx=ast.Load(),
556+
),
557+
)
558+
],
559+
ctx=ast.Load(),
560+
)
561+
for v in scan_shape_vars
562+
],
563+
*[ast.Name(id=v, ctx=ast.Load()) for v in scan_vars],
564+
],
565+
ctx=ast.Store(),
456566
),
457567
ast.List(
458-
elts=[ast.Name(id=v, ctx=ast.Load()) for v in input_vars], ctx=ast.Store()
568+
elts=[
569+
ast.Name(id=v, ctx=ast.Load()) for v in [*scan_shape_vars, *input_vars]
570+
],
571+
ctx=ast.Store(),
459572
),
460573
],
461574
keywords=[],
462575
ctx=ast.Load(),
463576
)
464-
target = ast.List(
465-
[ast.Name(id=v, ctx=ast.Store()) for v in output_vars], ctx=ast.Store()
577+
target = ast.Tuple(
578+
[ast.Name(id=v, ctx=ast.Store()) for v in [*init_vars, *loop_vars, *output_vars]],
579+
ctx=ast.Store(),
466580
)
467581
assign = ast.Assign(targets=[target], value=call)
468582
return [func_def, assign]

0 commit comments

Comments
 (0)