@@ -302,6 +302,8 @@ def validate_model(
302302 )
303303 ),
304304 )
305+ if "ERR_create" in summary :
306+ return summary , data
305307
306308 if drop_inputs :
307309 if verbose :
@@ -364,18 +366,14 @@ def validate_model(
364366 # We make a copy of the input just in case the model modifies them inplace
365367 hash_inputs = string_type (data ["inputs" ], with_shape = True )
366368 inputs = torch_deepcopy (data ["inputs" ])
367- begin = time .perf_counter ()
368- if quiet :
369- try :
370- expected = data ["model" ](** inputs )
371- except Exception as e :
372- summary ["ERR_run" ] = str (e )
373- data ["ERR_run" ] = e
374- summary ["time_run" ] = time .perf_counter () - begin
375- return summary , data
376- else :
377- expected = data ["model" ](** inputs )
378- summary ["time_run" ] = time .perf_counter () - begin
369+ model = data ["model" ]
370+
371+ expected = _quiet_or_not_quiet (
372+ quiet , "run" , summary , data , (lambda m = model , inp = inputs : m (** inp ))
373+ )
374+ if "ERR_run" in summary :
375+ return summary , data
376+
379377 summary ["model_expected" ] = string_type (expected , with_shape = True )
380378 if verbose :
381379 print ("[validate_model] done (run)" )
@@ -417,18 +415,18 @@ def validate_model(
417415
418416 # We make a copy of the input just in case the model modifies them inplace
419417 inputs = torch_deepcopy (data ["inputs_export" ])
420- begin = time . perf_counter ()
421- if quiet :
422- try :
423- expected = data [ "model" ]( ** inputs )
424- except Exception as e :
425- summary [ "ERR_run_patched" ] = str ( e )
426- data [ "ERR_run_patched" ] = e
427- summary [ "time_run_patched" ] = time . perf_counter () - begin
428- return summary , data
429- else :
430- expected = data [ "model" ]( ** inputs )
431- summary [ "time_run_patched" ] = time . perf_counter () - begin
418+ model = data [ "model" ]
419+
420+ expected = _quiet_or_not_quiet (
421+ quiet ,
422+ "run_patched" ,
423+ summary ,
424+ data ,
425+ ( lambda m = model , inp = inputs : m ( ** inp )),
426+ )
427+ if "ERR_run_patched" in summary :
428+ return summary , data
429+
432430 disc = max_diff (data ["expected" ], expected )
433431 for k , v in disc .items ():
434432 summary [f"disc_patched_{ k } " ] = v
0 commit comments