 import numpy as np
 from scipy.spatial.distance import cdist
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_torch
 from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite
 from onnx_diagnostic.torch_export_patches.patch_module import (
     transform_method,
@@ -498,7 +498,7 @@ def forward(self, x, y):
                 def loop_body_0(x, y):
                     x = x.reshape((-1, *x.shape))
                     z = ((x - y) ** 2).sum(dim=-1)
-                    return [z]
+                    return (z,)

                 z = torch.ops.higher_order.scan(loop_body_0, [], [x], [y])
                 return z[0]
@@ -510,29 +510,61 @@ def loop_body_0(x, y):
         ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
         torch.export.export(RewrittenModel(), (x, y), dynamic_shapes=ds)

+        class RewrittenModelLoop(torch.nn.Module):
+            def forward(self, z, iv, x, y):
+                z = z.clone()
+                i = iv.item()
+                z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
+                return (z, iv)
+
+        inputs = (
+            torch.empty((x.shape[0], y.shape[0])),
+            torch.tensor([2], dtype=torch.int64),
+            x,
+            y,
+        )
+        RewrittenModelLoop()(*inputs)
+        try:
+            from experimental_experiment.torch_interpreter.tracing import CustomTracer
+        except ImportError:
+            CustomTracer = None
+        if CustomTracer:
+            graph = CustomTracer().trace(RewrittenModelLoop())
+            self.assertNotEmpty(graph)
+
+        # does not work
+        # dsl = ({0: DYN, 1: DYN}, {}, {0: DYN, 1: DYN}, {0: DYN, 1: DYN})
+        # torch.export.export(RewrittenModelLoop(), inputs, dynamic_shapes=dsl)
+
         class RewrittenModel2(torch.nn.Module):
             def forward(self, x, y):
                 def loop_body_1(z, iv, x, y):
                     z = z.clone()
                     i = iv.item()
                     z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
-                    return [z, iv]
+                    return (z, iv)

                 z = torch.empty((x.shape[0], y.shape[0]))
                 r = torch.ops.higher_order.scan(
-                    loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y]
+                    loop_body_1,
+                    [z],
+                    [torch.arange(x.shape[0], dtype=torch.int64).reshape((-1, 1))],
+                    [x, y],
                 )
                 return r[0]

-        rewritten_expected2 = RewrittenModel2()(x, y)
-        self.assertEqualArray(expected, rewritten_expected2)
-        torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
-
         rewritten = transform_method(Model.forward, verbose=self.verbose)
         self.assertIn("torch.ops.higher_order.scan(", rewritten.code)
         Model.forward = rewritten.func
         self.assertEqualAny(expected, Model()(x, y))

+        rewritten_expected2 = RewrittenModel2()(x, y)
+        self.assertEqualArray(expected, rewritten_expected2)
+
+        if not has_torch("2.9"):
+            raise unittest.SkipTest("skipped export, torch must be >= 2.9")
+
+        torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
         ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds, strict=False)
         self.assertEqualAny(expected, ep.module()(x, y))
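For reference, a minimal standalone sketch of the pattern this commit tests: a row-by-row squared-distance loop expressed with torch.ops.higher_order.scan, whose loop body returns a tuple rather than a list. The sketch is not part of the commit; the class and variable names are illustrative, and it assumes a torch version where the scan higher-order operator also runs eagerly, which the test above relies on when it calls the rewritten model directly.

import torch
from scipy.spatial.distance import cdist


class ScanSquaredDistance(torch.nn.Module):
    # Illustrative only: mirrors the RewrittenModel pattern shown in the diff.
    def forward(self, x, y):
        def loop_body(xi, y):
            # xi is one row of x; reshape it so it broadcasts against y
            xi = xi.reshape((-1, *xi.shape))
            # the loop body must return a tuple, not a list
            return (((xi - y) ** 2).sum(dim=-1),)

        # scan(combine_fn, initial_carries, scanned_inputs, additional_inputs)
        r = torch.ops.higher_order.scan(loop_body, [], [x], [y])
        return r[0]


x = torch.rand((5, 3))
y = torch.rand((7, 3))
expected = torch.from_numpy(cdist(x.numpy(), y.numpy()) ** 2).to(torch.float32)
got = ScanSquaredDistance()(x, y)
assert torch.allclose(got, expected, atol=1e-5)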