8383
8484
8585def export_model (
86- model , dynamic_shapes , inputs , cache = False , oblivious = False , rt = False , cache_patch = False
86+ model ,
87+ dynamic_shapes ,
88+ inputs ,
89+ cache = False ,
90+ oblivious = False ,
91+ rt = False ,
92+ cache_patch = False ,
93+ strict = False ,
8794):
8895 if cache and not cache_patch :
8996 with register_additional_serialization_functions (patch_transformers = True ):
90- return export_model (model , dynamic_shapes , inputs , oblivious = oblivious , rt = rt )
97+ return export_model (
98+ model , dynamic_shapes , inputs , oblivious = oblivious , rt = rt , strict = strict
99+ )
91100 if cache_patch :
92101 with torch_export_patches (
93102 patch_torch = cache_patch in ("all" , "torch" , True , 1 ),
94103 patch_transformers = cache_patch in ("all" , "transformers" , True , 1 ),
95104 ):
96- return export_model (model , dynamic_shapes , inputs , oblivious = oblivious , rt = rt )
105+ return export_model (
106+ model , dynamic_shapes , inputs , oblivious = oblivious , rt = rt , strict = strict
107+ )
97108 if oblivious :
98109 with torch .fx .experimental ._config .patch (backed_size_oblivious = True ):
99- return export_model (model , dynamic_shapes , inputs , rt = rt )
110+ return export_model (model , dynamic_shapes , inputs , rt = rt , strict = strict )
100111 return torch .export .export (
101112 model ,
102113 (),
103114 inputs ,
104115 dynamic_shapes = dynamic_shapes ,
116+ strict = strict ,
105117 prefer_deferred_runtime_asserts_over_guards = rt ,
106118 )
107119
108120
109121def try_export_model (
110- model , dynamic_shapes , inputs , cache = False , oblivious = False , rt = False , cache_patch = False
122+ model ,
123+ dynamic_shapes ,
124+ inputs ,
125+ cache = False ,
126+ oblivious = False ,
127+ rt = False ,
128+ cache_patch = False ,
129+ strict = False ,
111130):
112131 try :
113132 return export_model (
@@ -118,6 +137,7 @@ def try_export_model(
118137 oblivious = oblivious ,
119138 rt = rt ,
120139 cache_patch = cache_patch ,
140+ strict = strict ,
121141 )
122142 except Exception as e :
123143 return e
@@ -140,14 +160,16 @@ def validation(ep, input_sets, expected):
140160
141161results = []
142162
143- possibilities = [* [[0 , 1 ] for _ in range (4 )], list (input_sets )]
163+ possibilities = [* [[0 , 1 ] for _ in range (5 )], list (input_sets )]
144164possibilities [1 ] = [0 , "all" , "torch" , "transformers" ]
145165with tqdm (list (itertools .product (* possibilities ))) as pbar :
146- for cache , cache_patch , oblivious , rt , inputs in pbar :
166+ for cache , cache_patch , strict , oblivious , rt , inputs in pbar :
147167 if cache_patch and not cache :
148168 # patches include caches.
149169 continue
150- kwargs = dict (cache = cache , cache_patch = cache_patch , oblivious = oblivious , rt = rt )
170+ kwargs = dict (
171+ cache = cache , cache_patch = cache_patch , strict = strict , oblivious = oblivious , rt = rt
172+ )
151173 legend = "-" .join (
152174 (k if isinstance (v , int ) else f"{ k } :{ v } " ) for k , v in kwargs .items () if v
153175 )
@@ -203,7 +225,7 @@ def validation(ep, input_sets, expected):
203225# The validation failures.
204226
205227invalid = df [(df .EXPORT == 1 ) & (df .WORKS == 0 )].pivot (
206- index = ["cache" , "cache_patch" , "oblivious" , "rt" , "export_with" ],
228+ index = ["cache" , "cache_patch" , "strict" , " oblivious" , "rt" , "export_with" ],
207229 columns = ["run_with" ],
208230 values = ["WORKS" , "ERR-RUN" ],
209231)
@@ -213,7 +235,7 @@ def validation(ep, input_sets, expected):
213235# %% Successes.
214236
215237success = df [(df .EXPORT == 1 ) & (df .WORKS == 1 )].pivot (
216- index = ["cache" , "cache_patch" , "oblivious" , "rt" , "export_with" ],
238+ index = ["cache" , "cache_patch" , "strict" , " oblivious" , "rt" , "export_with" ],
217239 columns = ["run_with" ],
218240 values = ["WORKS" ],
219241)
0 commit comments