22import inspect
33import textwrap
44import unittest
5+ import numpy as np
6+ from scipy .spatial .distance import cdist
57import torch
68from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout
79from onnx_diagnostic .torch_export_patches import torch_export_patches , torch_export_rewrite
810from onnx_diagnostic .torch_export_patches .patch_module import (
911 transform_method ,
1012 inplace_add_parent ,
13+ ShapeFinder ,
14+ RewriteControlFlow ,
1115)
1216
1317
@@ -42,7 +46,7 @@ def forward(self, x, y):
4246 hasattr (node , "parent" ) for node in ast .walk (tree )
4347 ), f"Missing parent in { ast .dump (tree , indent = 2 )} "
4448
45- def test_rewrite_forward_return1 (self ):
49+ def test_rewrite_test_in_forward_return1 (self ):
4650
4751 class Model (torch .nn .Module ):
4852 def forward (self , x , y ):
@@ -69,7 +73,7 @@ def forward(self, x, y):
6973 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
7074
7175 @hide_stdout ()
72- def test_rewrite_forward_return2 (self ):
76+ def test_rewrite_test_in_forward_return2 (self ):
7377
7478 class Model (torch .nn .Module ):
7579 def forward (self , x , y ):
@@ -95,7 +99,7 @@ def forward(self, x, y):
9599 self .assertEqualAny (expected , ep .module ()(x , y ))
96100 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
97101
98- def test_rewrite_forward_assign1 (self ):
102+ def test_rewrite_test_in_forward_assign1 (self ):
99103
100104 class Model (torch .nn .Module ):
101105 def forward (self , x , y ):
@@ -122,7 +126,7 @@ def forward(self, x, y):
122126 self .assertEqualAny (expected , ep .module ()(x , y ))
123127 self .assertEqualArray (expected_ , ep .module ()(- x , y ))
124128
125- def test_rewrite_forward_assign2 (self ):
129+ def test_rewrite_test_in_forward_assign2 (self ):
126130
127131 class Model (torch .nn .Module ):
128132 def forward (self , x , y ):
@@ -149,7 +153,7 @@ def forward(self, x, y):
149153 self .assertEqualAny (expected , ep .module ()(x , y ))
150154 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
151155
152- def test_rewrite_forward_assign_noelse (self ):
156+ def test_rewrite_test_in_forward_assign_noelse (self ):
153157
154158 class Model (torch .nn .Module ):
155159 def forward (self , x , y ):
@@ -174,7 +178,7 @@ def forward(self, x, y):
174178 self .assertEqualAny (expected , ep .module ()(x , y ))
175179 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
176180
177- def test_rewrite_forward_return_noelse (self ):
181+ def test_rewrite_test_in_forward_return_noelse (self ):
178182
179183 class Model (torch .nn .Module ):
180184 def forward (self , x , y ):
@@ -186,7 +190,7 @@ def forward(self, x, y):
186190 lambda : transform_method (Model .forward , verbose = self .verbose ), NotImplementedError
187191 )
188192
189- def test_rewrite_forward_assign2_in_2 (self ):
193+ def test_rewrite_test_in_forward_assign2_in_2 (self ):
190194
191195 class Model (torch .nn .Module ):
192196 def forward (self , x , y ):
@@ -215,7 +219,7 @@ def forward(self, x, y):
215219 self .assertEqualAny (expected , ep .module ()(x , y ))
216220 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
217221
218- def test_rewrite_forward_assign2_in_3 (self ):
222+ def test_rewrite_test_in_forward_assign2_in_3 (self ):
219223
220224 class Model (torch .nn .Module ):
221225 def forward (self , x , y ):
@@ -292,7 +296,7 @@ def torch_cond_else_2(y):
292296 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
293297 Model ()(x , y )
294298
295- def test_rewrite_forward_assign_nested (self ):
299+ def test_rewrite_test_in_forward_assign_nested (self ):
296300
297301 class Model (torch .nn .Module ):
298302 def forward (self , x , y ):
@@ -341,7 +345,7 @@ def forward(self, x, y):
341345 self .assertEqualAny (expected_0 , ep .module ()(x , - y ))
342346 self .assertEqualAny (expected_1 , ep .module ()(- x , - y ))
343347
344- def test_rewrite_forward_none (self ):
348+ def test_rewrite_test_in_forward_none (self ):
345349
346350 class Model (torch .nn .Module ):
347351 def forward (self , x , y ):
@@ -365,7 +369,7 @@ def forward(self, x, y):
365369 self .assertEqualAny (expected , ep .module ()(x , y ))
366370 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
367371
368- def test_rewrite_PLBartEncoderLayer (self ):
372+ def test_rewrite_test_in_PLBartEncoderLayer (self ):
369373 from transformers .models .plbart .modeling_plbart import PLBartEncoderLayer
370374
371375 rewritten = transform_method (PLBartEncoderLayer .forward , verbose = self .verbose )
@@ -443,6 +447,91 @@ def forward(self, x, y):
443447 got = ep .module ()(x , y )
444448 self .assertEqualArray (expected , got )
445449
450+ def test_shape_finder (self ):
451+ expr = "range(x.shape[0])"
452+ node = ast .parse (expr )
453+ sh = ShapeFinder ()
454+ sh .visit (node )
455+ self .assertEqual ({"x" }, sh .found_shape )
456+
457+ def test__find_loop_vars (self ):
458+ code = textwrap .dedent (
459+ """
460+ for i in range(x.shape[0]):
461+ z[i, :] = ((x[i : i + 1, :] - y) ** 2).sum(dim=-1)
462+ """
463+ )
464+ node = ast .parse (code )
465+ tr = RewriteControlFlow ()
466+ vars = tr ._find_loop_vars (node .body [0 ])
467+ self .assertEqual (
468+ {"loop" : ["i" ], "scan" : ["x" ], "input" : ["y" ], "output" : ["z" ], "init" : []}, vars
469+ )
470+
471+ def test_rewrite_loop (self ):
472+
473+ class Model (torch .nn .Module ):
474+ def forward (self , x , y ):
475+ z = torch .empty ((x .shape [0 ], y .shape [0 ]))
476+ for i in range (x .shape [0 ]):
477+ z [i , :] = ((x [i : i + 1 , :] - y ) ** 2 ).sum (dim = - 1 )
478+ return z
479+
480+ class RewrittenModel (torch .nn .Module ):
481+ def forward (self , x , y ):
482+ def loop_body_0 (x , y ):
483+ x = x .reshape ((- 1 , * x .shape ))
484+ z = ((x - y ) ** 2 ).sum (dim = - 1 )
485+ return [z ]
486+
487+ z = torch .ops .higher_order .scan (loop_body_0 , [], [x ], [y ])
488+ return z [0 ]
489+
490+ class RewrittenModel2 (torch .nn .Module ):
491+ def forward (self , x , y ):
492+ def loop_body_1 (z , i , x , y ):
493+ z = z .clone ()
494+ z [i , :] = ((x [i , :] - y ) ** 2 ).sum (dim = - 1 )
495+ return [z , i ]
496+
497+ z = torch .empty ((x .shape [0 ], y .shape [0 ]))
498+ r = torch .ops .higher_order .scan (
499+ loop_body_1 , [z ], [torch .arange (x .shape [0 ], dtype = torch .int64 )], [x , y ]
500+ )
501+ return r [0 ]
502+
503+ x , y = torch .rand ((3 , 4 )), torch .rand ((5 , 4 ))
504+ expected = Model ()(x , y )
505+ self .assertEqualArray (
506+ expected .numpy (),
507+ cdist (x .numpy (), y .numpy (), metric = "sqeuclidean" ).astype (np .float32 ),
508+ atol = 1e-5 ,
509+ )
510+ rewritten_expected = RewrittenModel ()(x , y )
511+ self .assertEqualArray (expected , rewritten_expected )
512+ rewritten_expected2 = RewrittenModel2 ()(x , y )
513+ self .assertEqualArray (expected , rewritten_expected2 )
514+
515+ rewritten = transform_method (Model .forward , verbose = self .verbose )
516+ print (rewritten .code )
517+
518+ self .assertIn ("torch.ops.higher_order.scan(" , rewritten .code )
519+ Model .forward = rewritten .func
520+ self .assertEqualAny (expected , Model ()(x , y ))
521+
522+ DYN = torch .export .Dim .DYNAMIC
523+ ds = ({0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN })
524+ ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds )
525+ self .assertEqualAny (expected , ep .module ()(x , y ))
526+
527+ """
528+ z = torch.empty((x.shape[0], y.shape[0]))
529+ def loop_body_0(i, x_row, y, z):
530+ z[i, :] = ((x_row - y) ** 2).sum(dim=-1)
531+ return z
532+ z = torch.ops.higher_order.scan(loop_body_0, [x], [y], [])
533+ """
534+
446535
447536if __name__ == "__main__" :
448537 unittest .main (verbosity = 2 )
0 commit comments