@@ -526,28 +526,16 @@ def loop_body_1(z, iv, x, y):
526526
527527 rewritten_expected2 = RewrittenModel2 ()(x , y )
528528 self .assertEqualArray (expected , rewritten_expected2 )
529- torch .export .export (RewrittenModel2 (), (x , y ), dynamic_shapes = ds )
529+ torch .export .export (RewrittenModel2 (), (x , y ), dynamic_shapes = ds , strict = False )
530530
531531 rewritten = transform_method (Model .forward , verbose = self .verbose )
532- print ("-------" )
533- print (rewritten .code )
534- print ("-------" )
535-
536532 self .assertIn ("torch.ops.higher_order.scan(" , rewritten .code )
537533 Model .forward = rewritten .func
538534 self .assertEqualAny (expected , Model ()(x , y ))
539535
540- ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds )
536+ ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ds , strict = False )
541537 self .assertEqualAny (expected , ep .module ()(x , y ))
542538
543- """
544- z = torch.empty((x.shape[0], y.shape[0]))
545- def loop_body_0(i, x_row, y, z):
546- z[i, :] = ((x_row - y) ** 2).sum(dim=-1)
547- return z
548- z = torch.ops.higher_order.scan(loop_body_0, [x], [y], [])
549- """
550-
551539
552540if __name__ == "__main__" :
553541 unittest .main (verbosity = 2 )
0 commit comments