@@ -109,9 +109,12 @@ def _make_folder_name(
109109 optimization : Optional [str ] = None ,
110110 dtype : Optional [Union [str , torch .dtype ]] = None ,
111111 device : Optional [Union [str , torch .device ]] = None ,
112+ subfolder : Optional [str ] = None ,
112113) -> str :
113114 "Creates a filename unique based on the given options."
114115 els = [model_id .replace ("/" , "_" )]
116+ if subfolder :
117+ els .append (subfolder .replace ("/" , "_" ))
115118 if exporter :
116119 els .append (exporter )
117120 if optimization :
@@ -224,6 +227,7 @@ def validate_model(
224227 ortfusiontype : Optional [str ] = None ,
225228 input_options : Optional [Dict [str , Any ]] = None ,
226229 model_options : Optional [Dict [str , Any ]] = None ,
230+ subfolder : Optional [str ] = None ,
227231) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
228232 """
229233 Validates a model.
@@ -256,11 +260,11 @@ def validate_model(
256260 used to export
257261 :param model_options: additional options when creating the model such as
258262 ``num_hidden_layers`` or ``attn_implementation``
263+ :param subfolder: version or subfolders to uses when retrieving a model id
259264 :return: two dictionaries, one with some metrics,
260265 another one with whatever the function produces
261266 """
262267 summary = version_summary ()
263-
264268 summary .update (
265269 dict (
266270 version_model_id = model_id ,
@@ -282,7 +286,7 @@ def validate_model(
282286 folder_name = None
283287 if dump_folder :
284288 folder_name = _make_folder_name (
285- model_id , exporter , optimization , dtype = dtype , device = device
289+ model_id , exporter , optimization , dtype = dtype , device = device , subfolder = subfolder
286290 )
287291 dump_folder = os .path .join (dump_folder , folder_name )
288292 if not os .path .exists (dump_folder ):
@@ -293,11 +297,15 @@ def validate_model(
293297 print (f"[validate_model] dump into { folder_name !r} " )
294298
295299 if verbose :
296- print (f"[validate_model] validate model id { model_id !r} " )
300+ if subfolder :
301+ print (f"[validate_model] validate model id { model_id !r} , subfolder={ subfolder !r} " )
302+ else :
303+ print (f"[validate_model] validate model id { model_id !r} " )
297304 if model_options :
298305 print (f"[validate_model] model_options={ model_options !r} " )
299306 print (f"[validate_model] get dummy inputs with input_options={ input_options } ..." )
300307 summary ["model_id" ] = model_id
308+ summary ["model_subfolder" ] = subfolder or ""
301309
302310 iop = input_options or {}
303311 mop = model_options or {}
@@ -307,14 +315,15 @@ def validate_model(
307315 summary ,
308316 None ,
309317 (
310- lambda mid = model_id , v = verbose , task = task , tr = trained , iop = iop : (
318+ lambda mid = model_id , v = verbose , task = task , tr = trained , iop = iop , sub = subfolder : (
311319 get_untrained_model_with_inputs (
312320 mid ,
313321 verbose = v ,
314322 task = task ,
315323 same_as_pretrained = tr ,
316324 inputs_kwargs = iop ,
317325 model_kwargs = mop ,
326+ subfolder = sub ,
318327 )
319328 )
320329 ),
@@ -1060,15 +1069,16 @@ def call_torch_export_custom(
10601069 assert (
10611070 optimization in available
10621071 ), f"unexpected value for optimization={ optimization } , available={ available } "
1063- assert exporter in {
1072+ available = {
10641073 "custom" ,
10651074 "custom-strict" ,
1066- "custom-strict-dec " ,
1075+ "custom-strict-default " ,
10671076 "custom-strict-all" ,
10681077 "custom-nostrict" ,
1069- "custom-nostrict-dec " ,
1078+ "custom-nostrict-default " ,
10701079 "custom-nostrict-all" ,
1071- }, f"Unexpected value for exporter={ exporter !r} "
1080+ }
1081+ assert exporter in available , f"Unexpected value for exporter={ exporter !r} in { available } "
10721082 assert "model" in data , f"model is missing from data: { sorted (data )} "
10731083 assert "inputs_export" in data , f"inputs_export is missing from data: { sorted (data )} "
10741084 summary : Dict [str , Union [str , int , float ]] = {}
@@ -1100,7 +1110,7 @@ def call_torch_export_custom(
11001110 export_options = ExportOptions (
11011111 strict = strict ,
11021112 decomposition_table = (
1103- "dec " if "-dec " in exporter else ("all" if "-all" in exporter else None )
1113+ "default " if "-default " in exporter else ("all" if "-all" in exporter else None )
11041114 ),
11051115 )
11061116 options = OptimizationOptions (patterns = optimization ) if optimization else None
0 commit comments