@@ -109,6 +109,34 @@ def filter_inputs(
109109 return new_inputs , dyn
110110
111111
112+ def _make_folder_name (
113+ model_id : str ,
114+ exporter : str ,
115+ optimization : Optional [str ] = None ,
116+ dtype : Optional [Union [str , torch .dtype ]] = None ,
117+ device : Optional [Union [str , torch .device ]] = None ,
118+ ) -> str :
119+ "Creates a filename unique based on the given options."
120+ els = [model_id .replace ("/" , "_" ), exporter ]
121+ if optimization :
122+ els .append (optimization )
123+ if dtype is not None and dtype :
124+ stype = dtype if isinstance (dtype , str ) else str (dtype )
125+ stype = stype .replace ("float" , "f" ).replace ("uint" , "u" ).replace ("int" , "i" )
126+ els .append (stype )
127+ if device is not None and device :
128+ sdev = device if isinstance (device , str ) else str (device )
129+ sdev = sdev .lower ()
130+ if "cpu" in sdev :
131+ sdev = "cpu"
132+ elif "cuda" in sdev :
133+ sdev = "cuda"
134+ else :
135+ raise AssertionError (f"unexpected value for device={ device } , sdev={ sdev !r} " )
136+ els .append (sdev )
137+ return "-" .join (els )
138+
139+
112140def validate_model (
113141 model_id : str ,
114142 task : Optional [str ] = None ,
@@ -152,7 +180,9 @@ def validate_model(
152180 assert not trained , f"trained={ trained } not supported yet"
153181 summary : Dict [str , Union [int , float , str ]] = {}
154182 if dump_folder :
155- folder_name = f"{ model_id .replace ('/' ,'-' )} -{ exporter } -{ optimization or '' } "
183+ folder_name = _make_folder_name (
184+ model_id , exporter , optimization , dtype = dtype , device = device
185+ )
156186 dump_folder = os .path .join (dump_folder , folder_name )
157187 if not os .path .exists (dump_folder ):
158188 os .makedirs (dump_folder )
@@ -353,7 +383,7 @@ def validate_model(
353383 if verbose :
354384 print (f"[validate_model] dumps onnx program in { dump_folder !r} ..." )
355385 onnx_file_name = os .path .join (dump_folder , f"{ folder_name } .onnx" )
356- epo .save (onnx_file_name )
386+ epo .save (onnx_file_name , external_data = True )
357387 if verbose :
358388 print ("[validate_model] done (dump onnx)" )
359389 if verbose :
0 commit comments