@@ -86,13 +86,13 @@ def _cmd_line(script_name: str, **kwargs: Dict[str, Union[str, int, float]]) ->
8686 return args
8787
8888
89- def _extract_metrics (text : str ) -> Dict [str , str ]:
89+ def _extract_metrics (text : str ) -> Dict [str , Union [ str , int , float ] ]:
9090 reg = re .compile (":(.*?),(.*.?);" )
9191 res = reg .findall (text )
9292 if len (res ) == 0 :
9393 return {}
9494 kw = dict (res )
95- new_kw = {}
95+ new_kw : Dict [ str , Any ] = {}
9696 for k , w in kw .items ():
9797 assert isinstance (k , str ) and isinstance (
9898 w , str
@@ -159,7 +159,7 @@ def run_benchmark(
159159 summary : Optional [Callable ] = None ,
160160 timeout : int = 600 ,
161161 missing : Optional [Dict [str , Union [str , Callable ]]] = None ,
162- ) -> List [Dict [str , Union [str , int , float , Tuple [ int , int ] ]]]:
162+ ) -> List [Dict [str , Union [str , int , float ]]]:
163163 """
164164 Runs a script multiple times and extract information from the output
165165 following the pattern ``:<metric>,<value>;``.
@@ -188,7 +188,7 @@ def run_benchmark(
188188 else :
189189 loop = configs
190190
191- data : List [Dict [str , Union [str , int , float , Tuple [ int , int ] ]]] = []
191+ data : List [Dict [str , Union [str , int , float ]]] = []
192192 for iter_loop , config in enumerate (loop ):
193193 if iter_loop < start :
194194 continue
@@ -266,23 +266,32 @@ def run_benchmark(
266266 metrics .update (config )
267267 if filename_out and os .path .exists (filename_out ):
268268 if "model_name" in metrics :
269+ assert isinstance (
270+ metrics ["model_name" ], str
271+ ), f"unexpected type { type (metrics ['model_name' ])} "
269272 new_name = f"{ filename_out } .{ _clean_string (metrics ['model_name' ])} "
270273 os .rename (filename_out , new_name )
271274 filename_out = new_name
272275 metrics ["file.stdout" ] = filename_out
273276 if filename_err and os .path .exists (filename_err ):
274277 if "model_name" in metrics :
278+ assert isinstance (
279+ metrics ["model_name" ], str
280+ ), f"unexpected type { type (metrics ['model_name' ])} "
275281 new_name = f"{ filename_err } .{ _clean_string (metrics ['model_name' ])} "
276282 os .rename (filename_err , new_name )
277283 filename_err = new_name
278284 metrics ["file.stderr" ] = filename_err
279285 metrics ["DATE" ] = f"{ datetime .now ():%Y-%m-%d} "
280- metrics ["ITER" ] = iter_loop
286+ metrics ["ITER" ] = str ( iter_loop )
281287 metrics ["TIME_ITER" ] = time .perf_counter () - begin
282288 metrics ["ERROR" ] = _clean_string (serr )
283289 metrics ["ERR_stdout" ] = _clean_string (sout )
284290 if metrics ["ERROR" ]:
285291 metrics ["ERR_std" ] = metrics ["ERROR" ]
292+ assert isinstance (
293+ metrics ["ERROR" ], str
294+ ), f"unexpected type { type (metrics ['ERROR' ])} "
286295 if "CUDA out of memory" in metrics ["ERROR" ]:
287296 metrics ["ERR_CUDA_OOM" ] = 1
288297 if "Cannot access gated repo for url" in metrics ["ERROR" ]:
@@ -348,8 +357,8 @@ def make_configs(
348357 drop : Optional [Set [str ]] = None ,
349358 replace : Optional [Dict [str , str ]] = None ,
350359 last : Optional [List [str ]] = None ,
351- filter_function : Optional [Callable [Dict [str , Any ], bool ]] = None ,
352- ) -> List [Dict [str , Any ]]:
360+ filter_function : Optional [Callable [[ Dict [str , Union [ str , int , float ]] ], bool ]] = None ,
361+ ) -> List [Dict [str , Union [ str , int , float ] ]]:
353362 """
354363 Creates all the configurations based on the command line arguments.
355364
@@ -383,14 +392,14 @@ def make_configs(
383392 for k in last :
384393 if k not in kwargs_ :
385394 continue
386- v = kwargs [k ]
395+ v = kwargs [k ] # type: ignore
387396 if isinstance (v , str ):
388397 args .append ([(k , s ) for s in v .split ("," )])
389398 else :
390399 args .append ([(k , v )])
391400
392401 configs = list (itertools .product (* args ))
393- confs = [dict (c ) for c in configs ]
402+ confs : List [ Dict [ str , Union [ int , float , str ]]] = [dict (c ) for c in configs ]
394403 if filter_function :
395404 confs = [c for c in confs if filter_function (c )]
396405 return confs
0 commit comments