11import unittest
2- from typing import Callable
32import torch
4- from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout
3+ from onnx_diagnostic .ext_test_case import ExtTestCase , hide_stdout , ignore_warnings
54from onnx_diagnostic .reference import ExtendedReferenceEvaluator
65from onnx_diagnostic .helpers .torch_test_helper import is_torchdynamo_exporting
76
7+ try :
8+ from experimental_experiment .torch_interpreter import to_onnx
9+ except ImportError :
10+ to_onnx = None
11+
812
913@torch .jit .script_if_tracing
1014def dummy_loop (padded : torch .Tensor , pos : torch .Tensor ):
@@ -15,42 +19,53 @@ def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
1519 return copy
1620
1721
18- def wrap_for_export (f : Callable ) -> Callable :
19-
20- class _wrapped (torch .nn .Module ):
21- def __init__ (self ):
22- super ().__init__ ()
23- self .f = f
22+ def dummy_loop_with_scan (padded : torch .Tensor , pos : torch .Tensor ):
23+ def pad_row (padded , p ):
24+ row = torch .zeros ((padded .shape [0 ],))
25+ torch ._check (p .item () > 0 )
26+ torch ._check (p .item () < padded .shape [0 ])
27+ # this check is not always true, we add it anyway to make this dimension >= 2
28+ # and avoid raising an exception about dynamic dimension in {0, 1}
29+ if is_torchdynamo_exporting ():
30+ torch ._check (p .item () > 1 )
31+ row [: p .item ()] = padded [: p .item ()]
32+ return (row ,)
2433
25- def forward (self , * args , ** kwargs ):
26- return self .f (* args , ** kwargs )
34+ return torch .ops .higher_order .scan (
35+ pad_row ,
36+ [],
37+ [padded , pos ],
38+ [],
39+ )
2740
28- return _wrapped ()
2941
30-
31- def select_when_exporting (mod , f ):
32- if is_torchdynamo_exporting ():
33- return mod
34- return f
42+ def select_when_exporting (f , f_scan ):
43+ return f_scan if is_torchdynamo_exporting () else f
3544
3645
3746class TestJit (ExtTestCase ):
47+ def test_dummy_loop (self ):
48+ x = torch .randn ((5 , 6 ))
49+ y = torch .arange (5 , dtype = torch .int64 ) + 1
50+ res = dummy_loop (x , y )
51+ res_scan = dummy_loop_with_scan (x , y )
52+ self .assertEqualArray (res , res_scan [0 ])
53+
3854 @hide_stdout ()
39- def test_export_loop (self ):
55+ @ignore_warnings (UserWarning )
56+ def test_export_loop_onnxscript (self ):
4057 class Model (torch .nn .Module ):
41- def __init__ (self ):
42- super ().__init__ ()
43- self .wrapped_f = wrap_for_export (dummy_loop )
44-
4558 def forward (self , images , position ):
46- return select_when_exporting (self .wrapped_f , dummy_loop )(images , position )
59+ return select_when_exporting (dummy_loop , dummy_loop_with_scan )(
60+ images , position
61+ )
4762
4863 model = Model ()
4964 x = torch .randn ((5 , 6 ))
5065 y = torch .arange (5 , dtype = torch .int64 ) + 1
5166 expected = model (x , y )
5267
53- name = self .get_dump_file ("test_export_loop .onnx" )
68+ name = self .get_dump_file ("test_export_loop_onnxscript .onnx" )
5469 torch .onnx .export (
5570 model ,
5671 (x , y ),
@@ -68,15 +83,16 @@ def forward(self, images, position):
6883 model ,
6984 (x , y ),
7085 dynamic_shapes = {"images" : {0 : DYN , 1 : DYN }, "position" : {0 : DYN }},
86+ strict = False ,
7187 )
72- print (ep )
88+ self . assertNotEmpty (ep )
7389
74- name2 = self .get_dump_file ("test_export_loop .dynamo.onnx" )
90+ name2 = self .get_dump_file ("test_export_loop_onnxscript .dynamo.onnx" )
7591 torch .onnx .export (
7692 model ,
7793 (x , y ),
7894 name2 ,
79- dynamic_axes = {"images" : {0 : "batch" , 1 : "maxdim" }, "position" : {0 : "batch" }},
95+ dynamic_shapes = {"images" : {0 : "batch" , 1 : "maxdim" }, "position" : {0 : "batch" }},
8096 dynamo = True ,
8197 fallback = False ,
8298 )
@@ -85,6 +101,33 @@ def forward(self, images, position):
85101 got = ref .run (None , feeds )[0 ]
86102 self .assertEqualArray (expected , got )
87103
104+ @hide_stdout ()
105+ @ignore_warnings (UserWarning )
106+ @unittest .skipIf (to_onnx is None , "missing to_onnx" )
107+ def test_export_loop_custom (self ):
108+ class Model (torch .nn .Module ):
109+ def forward (self , images , position ):
110+ return select_when_exporting (dummy_loop , dummy_loop_with_scan )(
111+ images , position
112+ )
113+
114+ model = Model ()
115+ x = torch .randn ((5 , 6 ))
116+ y = torch .arange (5 , dtype = torch .int64 ) + 1
117+ expected = model (x , y )
118+
119+ name2 = self .get_dump_file ("test_export_loop.custom.onnx" )
120+ to_onnx (
121+ model ,
122+ (x , y ),
123+ filename = name2 ,
124+ dynamic_shapes = {"images" : {0 : "batch" , 1 : "maxdim" }, "position" : {0 : "batch" }},
125+ )
126+ ref = ExtendedReferenceEvaluator (name2 )
127+ feeds = dict (images = x .numpy (), position = y .numpy ())
128+ got = ref .run (None , feeds )[0 ]
129+ self .assertEqualArray (expected , got )
130+
88131
89132if __name__ == "__main__" :
90133 unittest .main (verbosity = 2 )
0 commit comments