@@ -109,9 +109,12 @@ def _make_folder_name(
109109 optimization : Optional [str ] = None ,
110110 dtype : Optional [Union [str , torch .dtype ]] = None ,
111111 device : Optional [Union [str , torch .device ]] = None ,
112+ subfolder : Optional [str ] = None ,
112113) -> str :
113114 "Creates a filename unique based on the given options."
114115 els = [model_id .replace ("/" , "_" )]
116+ if subfolder :
117+ els .append (subfolder .replace ("/" , "_" ))
115118 if exporter :
116119 els .append (exporter )
117120 if optimization :
@@ -224,6 +227,7 @@ def validate_model(
224227 ortfusiontype : Optional [str ] = None ,
225228 input_options : Optional [Dict [str , Any ]] = None ,
226229 model_options : Optional [Dict [str , Any ]] = None ,
230+ subfolder : Optional [str ] = None ,
227231) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
228232 """
229233 Validates a model.
@@ -256,11 +260,11 @@ def validate_model(
256260 used to export
257261 :param model_options: additional options when creating the model such as
258262 ``num_hidden_layers`` or ``attn_implementation``
263+ :param subfolder: version or subfolders to uses when retrieving a model id
259264 :return: two dictionaries, one with some metrics,
260265 another one with whatever the function produces
261266 """
262267 summary = version_summary ()
263-
264268 summary .update (
265269 dict (
266270 version_model_id = model_id ,
@@ -282,7 +286,7 @@ def validate_model(
282286 folder_name = None
283287 if dump_folder :
284288 folder_name = _make_folder_name (
285- model_id , exporter , optimization , dtype = dtype , device = device
289+ model_id , exporter , optimization , dtype = dtype , device = device , subfolder = subfolder
286290 )
287291 dump_folder = os .path .join (dump_folder , folder_name )
288292 if not os .path .exists (dump_folder ):
@@ -293,11 +297,15 @@ def validate_model(
293297 print (f"[validate_model] dump into { folder_name !r} " )
294298
295299 if verbose :
296- print (f"[validate_model] validate model id { model_id !r} " )
300+ if subfolder :
301+ print (f"[validate_model] validate model id { model_id !r} , subfolder={ subfolder !r} " )
302+ else :
303+ print (f"[validate_model] validate model id { model_id !r} " )
297304 if model_options :
298305 print (f"[validate_model] model_options={ model_options !r} " )
299306 print (f"[validate_model] get dummy inputs with input_options={ input_options } ..." )
300307 summary ["model_id" ] = model_id
308+ summary ["model_subfolder" ] = subfolder
301309
302310 iop = input_options or {}
303311 mop = model_options or {}
@@ -307,14 +315,15 @@ def validate_model(
307315 summary ,
308316 None ,
309317 (
310- lambda mid = model_id , v = verbose , task = task , tr = trained , iop = iop : (
318+ lambda mid = model_id , v = verbose , task = task , tr = trained , iop = iop , sub = subfolder : (
311319 get_untrained_model_with_inputs (
312320 mid ,
313321 verbose = v ,
314322 task = task ,
315323 same_as_pretrained = tr ,
316324 inputs_kwargs = iop ,
317325 model_kwargs = mop ,
326+ subfolder = sub ,
318327 )
319328 )
320329 ),
0 commit comments