22import inspect
33import textwrap
44import unittest
5+ import numpy as np
6+ from scipy .spatial .distance import cdist
57import torch
6- from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout
8+ from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout , has_torch , requires_torch
79from onnx_diagnostic .torch_export_patches import torch_export_patches , torch_export_rewrite
810from onnx_diagnostic .torch_export_patches .patch_module import (
911 transform_method ,
1012 inplace_add_parent ,
13+ ShapeFinder ,
14+ RewriteControlFlow ,
1115)
1216
1317
@@ -42,7 +46,7 @@ def forward(self, x, y):
4246 hasattr (node , "parent" ) for node in ast .walk (tree )
4347 ), f"Missing parent in { ast .dump (tree , indent = 2 )} "
4448
45- def test_rewrite_forward_return1 (self ):
49+ def test_rewrite_test_in_forward_return1 (self ):
4650
4751 class Model (torch .nn .Module ):
4852 def forward (self , x , y ):
@@ -69,7 +73,7 @@ def forward(self, x, y):
6973 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
7074
7175 @hide_stdout ()
72- def test_rewrite_forward_return2 (self ):
76+ def test_rewrite_test_in_forward_return2 (self ):
7377
7478 class Model (torch .nn .Module ):
7579 def forward (self , x , y ):
@@ -95,7 +99,7 @@ def forward(self, x, y):
9599 self .assertEqualAny (expected , ep .module ()(x , y ))
96100 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
97101
98- def test_rewrite_forward_assign1 (self ):
102+ def test_rewrite_test_in_forward_assign1 (self ):
99103
100104 class Model (torch .nn .Module ):
101105 def forward (self , x , y ):
@@ -122,7 +126,7 @@ def forward(self, x, y):
122126 self .assertEqualAny (expected , ep .module ()(x , y ))
123127 self .assertEqualArray (expected_ , ep .module ()(- x , y ))
124128
125- def test_rewrite_forward_assign2 (self ):
129+ def test_rewrite_test_in_forward_assign2 (self ):
126130
127131 class Model (torch .nn .Module ):
128132 def forward (self , x , y ):
@@ -149,7 +153,7 @@ def forward(self, x, y):
149153 self .assertEqualAny (expected , ep .module ()(x , y ))
150154 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
151155
152- def test_rewrite_forward_assign_noelse (self ):
156+ def test_rewrite_test_in_forward_assign_noelse (self ):
153157
154158 class Model (torch .nn .Module ):
155159 def forward (self , x , y ):
@@ -174,7 +178,7 @@ def forward(self, x, y):
174178 self .assertEqualAny (expected , ep .module ()(x , y ))
175179 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
176180
177- def test_rewrite_forward_return_noelse (self ):
181+ def test_rewrite_test_in_forward_return_noelse (self ):
178182
179183 class Model (torch .nn .Module ):
180184 def forward (self , x , y ):
@@ -186,7 +190,7 @@ def forward(self, x, y):
186190 lambda : transform_method (Model .forward , verbose = self .verbose ), NotImplementedError
187191 )
188192
189- def test_rewrite_forward_assign2_in_2 (self ):
193+ def test_rewrite_test_in_forward_assign2_in_2 (self ):
190194
191195 class Model (torch .nn .Module ):
192196 def forward (self , x , y ):
@@ -215,7 +219,7 @@ def forward(self, x, y):
215219 self .assertEqualAny (expected , ep .module ()(x , y ))
216220 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
217221
218- def test_rewrite_forward_assign2_in_3 (self ):
222+ def test_rewrite_test_in_forward_assign2_in_3 (self ):
219223
220224 class Model (torch .nn .Module ):
221225 def forward (self , x , y ):
@@ -292,7 +296,7 @@ def torch_cond_else_2(y):
292296 x , y = torch .rand ((3 , 4 )), torch .rand ((3 , 4 ))
293297 Model ()(x , y )
294298
295- def test_rewrite_forward_assign_nested (self ):
299+ def test_rewrite_test_in_forward_assign_nested (self ):
296300
297301 class Model (torch .nn .Module ):
298302 def forward (self , x , y ):
@@ -341,7 +345,7 @@ def forward(self, x, y):
341345 self .assertEqualAny (expected_0 , ep .module ()(x , - y ))
342346 self .assertEqualAny (expected_1 , ep .module ()(- x , - y ))
343347
344- def test_rewrite_forward_none (self ):
348+ def test_rewrite_test_in_forward_none (self ):
345349
346350 class Model (torch .nn .Module ):
347351 def forward (self , x , y ):
@@ -365,7 +369,7 @@ def forward(self, x, y):
365369 self .assertEqualAny (expected , ep .module ()(x , y ))
366370 self .assertEqualAny (expected_ , ep .module ()(- x , y ))
367371
368- def test_rewrite_PLBartEncoderLayer (self ):
372+ def test_rewrite_test_in_PLBartEncoderLayer (self ):
369373 from transformers .models .plbart .modeling_plbart import PLBartEncoderLayer
370374
371375 rewritten = transform_method (PLBartEncoderLayer .forward , verbose = self .verbose )
@@ -443,6 +447,135 @@ def forward(self, x, y):
443447 got = ep .module ()(x , y )
444448 self .assertEqualArray (expected , got )
445449
450+ def test_shape_finder (self ):
451+ expr = "range(x.shape[0])"
452+ node = ast .parse (expr )
453+ sh = ShapeFinder ()
454+ sh .visit (node )
455+ self .assertEqual ({"x" }, sh .found_shape )
456+
457+ def test__find_loop_vars (self ):
458+ code = textwrap .dedent (
459+ """
460+ for i in range(x.shape[0]):
461+ z[i, :] = ((x[i : i + 1, :] - y) ** 2).sum(dim=-1)
462+ """
463+ )
464+ node = ast .parse (code )
465+ tr = RewriteControlFlow ()
466+ vars = tr ._find_loop_vars (node .body [0 ])
467+ self .assertEqual (
468+ {
469+ "init" : ["z" ],
470+ "input" : ["y" ],
471+ "loop" : ["i" ],
472+ "output" : [],
473+ "scan" : [],
474+ "scan_shape" : ["x" ],
475+ },
476+ vars ,
477+ )
478+
479+ @requires_torch ("2.8" )
480+ def test_rewrite_loop (self ):
481+
482+ class Model (torch .nn .Module ):
483+ def forward (self , x , y ):
484+ z = torch .empty ((x .shape [0 ], y .shape [0 ]))
485+ for i in range (x .shape [0 ]):
486+ z [i , :] = ((x [i , :] - y ) ** 2 ).sum (dim = - 1 )
487+ return z
488+
489+ x , y = torch .rand ((3 , 4 )), torch .rand ((5 , 4 ))
490+ expected = Model ()(x , y )
491+ self .assertEqualArray (
492+ expected .numpy (),
493+ cdist (x .numpy (), y .numpy (), metric = "sqeuclidean" ).astype (np .float32 ),
494+ atol = 1e-5 ,
495+ )
496+
497+ class RewrittenModel (torch .nn .Module ):
498+ def forward (self , x , y ):
499+ def loop_body_0 (x , y ):
500+ x = x .reshape ((- 1 , * x .shape ))
501+ z = ((x - y ) ** 2 ).sum (dim = - 1 )
502+ return (z ,)
503+
504+ z = torch .ops .higher_order .scan (loop_body_0 , [], [x ], [y ])
505+ return z [0 ]
506+
507+ rewritten_expected = RewrittenModel ()(x , y )
508+ self .assertEqualArray (expected , rewritten_expected )
509+
510+ DYN = torch .export .Dim .DYNAMIC
511+ ds = ({0 : DYN , 1 : DYN }, {0 : DYN , 1 : DYN })
512+ torch .export .export (RewrittenModel (), (x , y ), dynamic_shapes = ds )
513+
514+ class RewrittenModelLoop (torch .nn .Module ):
515+ def forward (self , z , iv , x , y ):
516+ z = z .clone ()
517+ i = iv .item ()
518+ z [i , :] = ((x [i , :] - y ) ** 2 ).sum (dim = - 1 )
519+ return (z , iv )
520+
521+ inputs = (
522+ torch .empty ((x .shape [0 ], y .shape [0 ])),
523+ torch .tensor ([2 ], dtype = torch .int64 ),
524+ x ,
525+ y ,
526+ )
527+ RewrittenModelLoop ()(* inputs )
528+ try :
529+ from experimental_experiment .torch_interpreter .tracing import CustomTracer
530+ except ImportError :
531+ CustomTracer = None
532+ if CustomTracer :
533+ graph = CustomTracer ().trace (RewrittenModelLoop ())
534+ self .assertNotEmpty (graph )
535+
536+ # does not wiork
537+ # dsl = ({0: DYN, 1: DYN}, {}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN})
538+ # torch.export.export(RewrittenModelLoop(), inputs, dynamic_shapes=dsl)
539+
540+ class RewrittenModel2 (torch .nn .Module ):
541+ def forward (self , x , y ):
542+ def loop_body_1 (z , iv , x , y ):
543+ z = z .clone ()
544+ i = iv .item ()
545+ z [i , :] = ((x [i , :] - y ) ** 2 ).sum (dim = - 1 )
546+ return (z , iv )
547+
548+ z = torch .empty ((x .shape [0 ], y .shape [0 ]))
549+ r = torch .ops .higher_order .scan (
550+ loop_body_1 ,
551+ [z ],
552+ [torch .arange (x .shape [0 ], dtype = torch .int64 ).reshape ((- 1 , 1 ))],
553+ [x , y ],
554+ )
555+ return r [0 ]
556+
557+ rewritten = transform_method (Model .forward , verbose = self .verbose )
558+ self .assertIn ("torch.ops.higher_order.scan(" , rewritten .code )
559+ Model .forward = rewritten .func
560+ self .assertEqualAny (expected , Model ()(x , y ))
561+
562+ rewritten_expected2 = RewrittenModel2 ()(x , y )
563+ self .assertEqualArray (expected , rewritten_expected2 )
564+
565+ if not has_torch ("2.9" ):
566+ raise unittest .SkipTest ("skipped export, torch must be >= 2.9" )
567+
568+ torch .export .export (RewrittenModel2 (), (x , y ), dynamic_shapes = ds , strict = False )
569+ ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds , strict = False )
570+ self .assertEqualAny (expected , ep .module ()(x , y ))
571+
572+ """
573+ position_encodings = torch.cat(
574+ [weight[:, :required_pos_encodings_columns]
575+ for weight in broadcasted_weights], dim=-1
576+ )
577+ """
578+
446579
447580if __name__ == "__main__" :
448581 unittest .main (verbosity = 2 )
0 commit comments