 import time
 import torch
 from ..helpers import max_diff, string_type, string_diff
+from ..helpers.helper import flatten_object
+from ..helpers.ort_session import make_feeds
 from ..helpers.torch_test_helper import to_any, torch_deepcopy
 from ..torch_export_patches import bypass_export_some_errors
 from .hghub import get_untrained_model_with_inputs
@@ -259,6 +261,14 @@ def validate_model(
                 f.write(str(ep.graph))
             if verbose:
                 print("[validate_model] done (dump ep)")
+        if "onnx_program" in data:
+            epo = data["onnx_program"]
+            if verbose:
+                print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
+            onnx_file_name = os.path.join(dump_folder, f"{folder_name}.onnx")
+            epo.save(onnx_file_name)
+            if verbose:
+                print("[validate_model] done (dump onnx)")
         if verbose:
             print(f"[validate_model] dumps statistics in {dump_folder!r}...")
         with open(os.path.join(dump_folder, f"{folder_name}.stats"), "w") as f:
@@ -267,6 +277,15 @@ def validate_model(
         if verbose:
             print("[validate_model] done (dump)")
 
+    if exporter and exporter.startswith("onnx-") and do_run:
+        summary_valid, data = validate_onnx_model(
+            data=data,
+            quiet=quiet,
+            verbose=verbose,
+            optimization=optimization,
+        )
+        summary.update(summary_valid)
+
     if verbose:
         print("[validate_model] done (final)")
     return summary, data
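
The hunk above dumps the ONNXProgram when one is present and, for "onnx-*" exporters with do_run, passes the result to validate_onnx_model. A minimal standalone sketch of the dump step, assuming only public torch.onnx APIs on a recent PyTorch with the dynamo-based exporter; TinyModel and the file name are made up for illustration:

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1

# torch.onnx.export(..., dynamo=True) returns an ONNXProgram,
# the object stored in data["onnx_program"] by call_torch_export_onnx
epo = torch.onnx.export(TinyModel(), (torch.randn(2, 3),), dynamo=True)
epo.save("tiny_model.onnx")  # same call as epo.save(onnx_file_name) above
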
@@ -288,7 +307,6 @@ def call_exporter(
     :param exporter: exporter to call
     :param quiet: catch exception or not
     :param verbose: verbosity
-    :param patch: apply patches
     :param optimization: optimization to do
     :param do_run: runs and compute discrepancies
     :return: two dictionaries, one with some metrics,
@@ -305,6 +323,16 @@ def call_exporter(
             do_run=do_run,
         )
         return summary, data
+    if exporter.startswith("onnx-"):
+        # export to ONNX with torch.onnx.export
+        summary, data = call_torch_export_onnx(
+            exporter=exporter,
+            data=data,
+            quiet=quiet,
+            verbose=verbose,
+            optimization=optimization,
+        )
+        return summary, data
     raise NotImplementedError(
         f"export with {exporter!r} and optimization={optimization!r} not implemented yet"
     )
@@ -331,19 +359,23 @@ def call_torch_export_export(
     do_run: bool = False,
 ):
     """
-    Calls an exporter on a model;
+    Exports a model with :func:`torch.export.export`.
     If a patch must be applied, it should be applied before calling this function.
 
-    :param data: dictionary with all the necessary inputs
+    :param data: dictionary with all the necessary inputs, the dictionary must
+        contain the keys ``model`` and ``inputs_export``
     :param exporter: exporter to call
     :param quiet: catch exception or not
     :param verbose: verbosity
-    :param patch: apply patches
     :param optimization: optimization to do
     :param do_run: runs and compute discrepancies
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
+    assert exporter in {
+        "export-strict",
+        "export-nostrict",
+    }, f"Unexpected value for exporter={exporter!r}"
     assert "model" in data, f"model is missing from data: {sorted(data)}"
     assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
     summary: Dict[str, Union[str, int, float]] = {}
@@ -355,8 +387,8 @@ def call_torch_export_export(
355387 f"[call_torch_export_export] exporter={ exporter !r} , "
356388 f"strict={ strict } , optimization={ optimization !r} "
357389 )
358- print (f"[call_torch_export_export] args={ string_type (args )} " )
359- print (f"[call_torch_export_export] kwargs={ string_type (kwargs )} " )
390+ print (f"[call_torch_export_export] args={ string_type (args , with_shape = True )} " )
391+ print (f"[call_torch_export_export] kwargs={ string_type (kwargs , with_shape = True )} " )
360392 print (f"[call_torch_export_export] dynamic_shapes={ _ds_clean (ds )} " )
361393 print ("[call_torch_export_export] export..." )
362394 summary ["export_exporter" ] = exporter
@@ -431,3 +463,205 @@ def call_torch_export_export(
431463 f" after: { string_type (data ['inputs_export' ], with_shape = True )} "
432464 )
433465 return summary , data
+
+
+def call_torch_export_onnx(
+    data: Dict[str, Any],
+    exporter: str,
+    quiet: bool = False,
+    verbose: int = 0,
+    optimization: Optional[str] = None,
+):
+    """
+    Exports a model into ONNX with :func:`torch.onnx.export`.
+    If a patch must be applied, it should be applied before calling this function.
+
+    :param data: dictionary with all the necessary inputs, the dictionary must
+        contain the keys ``model`` and ``inputs_export``
+    :param exporter: exporter to call
+    :param quiet: catch exception or not
+    :param verbose: verbosity
+    :param optimization: optimization to do
+    :return: two dictionaries, one with some metrics,
+        another one with whatever the function produces
+    """
+    assert optimization in {
+        "",
+        "ir",
+        None,
+    }, f"unexpected value for optimization={optimization!r}"
+    assert exporter in {
+        "onnx-dynamo",
+        "onnx-script",
+    }, f"Unexpected value for exporter={exporter!r}"
+    assert "model" in data, f"model is missing from data: {sorted(data)}"
+    assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
+    summary: Dict[str, Union[str, int, float]] = {}
+    dynamo = "dynamo" in exporter  # dynamo-based exporter for "onnx-dynamo", torchscript-based for "onnx-script"
+    args, kwargs = split_args_kwargs(data["inputs_export"])
+    ds = data.get("dynamic_shapes", None)
+    if verbose:
+        print(
+            f"[call_torch_export_onnx] exporter={exporter!r}, "
+            f"optimization={optimization!r}"
+        )
+        print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
+        print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
+        print(f"[call_torch_export_onnx] dynamic_shapes={_ds_clean(ds)}")
+        print("[call_torch_export_onnx] export...")
+    summary["export_exporter"] = exporter
+    summary["export_optimization"] = optimization or ""
+    summary["export_dynamo"] = dynamo
+    summary["export_args"] = string_type(args, with_shape=True)
+    summary["export_kwargs"] = string_type(kwargs, with_shape=True)
+
+    begin = time.perf_counter()
+    if quiet:
+        try:
+            epo = torch.onnx.export(
+                data["model"],
+                args,
+                kwargs=kwargs,
+                dynamic_shapes=ds,
+                dynamo=dynamo,
+            )
+        except Exception as e:
+            summary["ERR_export_export"] = str(e)
+            data["ERR_export_export"] = e
+            summary["time_export_export"] = time.perf_counter() - begin
+            return summary, data
+    else:
+        epo = torch.onnx.export(
+            data["model"],
+            args,
+            kwargs=kwargs,
+            dynamic_shapes=ds,
+            dynamo=dynamo,
+        )
+
+    summary["time_export_export"] = time.perf_counter() - begin
+    assert epo is not None, "no onnx export was found"
+    if verbose:
+        print("[call_torch_export_onnx] done (export)")
+    data["onnx_program"] = epo
+    if verbose > 1:
+        print("[call_torch_export_onnx] -- ONNXProgram")
+        print(epo)
+        print("[call_torch_export_onnx] -- End of ONNXProgram")
+
+    begin = time.perf_counter()
+    if optimization == "ir":
+        if verbose:
+            print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
+        if quiet:
+            try:
+                epo.optimize()
+            except Exception as e:
+                summary["ERR_export_optimize_ir"] = str(e)
+                data["ERR_export_optimize_ir"] = e
+                summary["time_export_optimize_ir"] = time.perf_counter() - begin
+                return summary, data
+        else:
+            epo.optimize()
+        summary["time_export_optimize_ir"] = time.perf_counter() - begin
+        if verbose:
+            print("[call_torch_export_onnx] done (optimization)")
+
+    return summary, data
+
+
+def validate_onnx_model(
+    data: Dict[str, Any],
+    quiet: bool = False,
+    verbose: int = 0,
+    optimization: Optional[str] = None,
+):
+    """
+    Verifies that an ONNX model produces the expected outputs.
+
+    :param data: dictionary with all the necessary inputs, the dictionary must
+        contain the keys ``inputs``, ``expected`` and either ``onnx_program``
+        or ``onnx_file_name``
+    :param quiet: catch exception or not
+    :param verbose: verbosity
+    :param optimization: optimization to do
+    :return: two dictionaries, one with some metrics,
+        another one with whatever the function produces
+    """
+    import onnxruntime
+
+    summary = {}
+    flat_inputs = flatten_object(data["inputs"], drop_keys=True)
+    d = flat_inputs[0].get_device()
+    providers = (
+        ["CPUExecutionProvider"]
+        if d < 0
+        else ["CUDAExecutionProvider", "CPUExecutionProvider"]
+    )
+    if "onnx_file_name" in data:
+        source = data["onnx_file_name"]
+        summary["onnx_filename"] = source
+        summary["onnx_size"] = os.stat(source).st_size
+    else:
+        assert (
+            "onnx_program" in data
+        ), f"onnx_program is missing from data which has {sorted(data)}"
+        source = data["onnx_program"].model_proto.SerializeToString()
+        assert len(source) < 2**31, f"The model is bigger than 2 Gb: {len(source) / 2**30} Gb"
+        summary["onnx_size"] = len(source)
+    if verbose:
+        print(f"[validate_onnx_model] verify onnx model with providers {providers}...")
+
+    begin = time.perf_counter()
+    if quiet:
+        try:
+            sess = onnxruntime.InferenceSession(source, providers=providers)
+        except Exception as e:
+            summary["ERR_onnx_ort_create"] = str(e)
+            data["ERR_onnx_ort_create"] = e
+            summary["time_onnx_ort_create"] = time.perf_counter() - begin
+            return summary, data
+    else:
+        sess = onnxruntime.InferenceSession(source, providers=providers)
+
+    summary["time_onnx_ort_create"] = time.perf_counter() - begin
+    data["onnx_ort_sess"] = sess
+    if verbose:
+        print("[validate_onnx_model] done (ort_session)")
+
+    # make_feeds
+    if verbose:
+        print("[validate_onnx_model] make_feeds...")
+        print(f"[validate_onnx_model] inputs={string_type(data['inputs'], with_shape=True)}")
+    feeds = make_feeds([i.name for i in sess.get_inputs()], data["inputs"], use_numpy=True)
+    if verbose:
+        print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
+    summary["onnx_ort_inputs"] = string_type(feeds, with_shape=True)
+    if verbose:
+        print("[validate_onnx_model] done (make_feeds)")
+
+    # run ort
+    if verbose:
+        print("[validate_onnx_model] run session...")
+    begin = time.perf_counter()
+    if quiet:
+        try:
+            got = sess.run(None, feeds)
+        except Exception as e:
+            summary["ERR_onnx_ort_run"] = str(e)
+            data["ERR_onnx_ort_run"] = e
+            summary["time_onnx_ort_run"] = time.perf_counter() - begin
+            return summary, data
+    else:
+        got = sess.run(None, feeds)
+    summary["time_onnx_ort_run"] = time.perf_counter() - begin
+    if verbose:
+        print("[validate_onnx_model] done (run)")
+        print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}")
+
+    # compute discrepancies
+    disc = max_diff(data["expected"], got, flatten=True)
+    if verbose:
+        print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
+    for k, v in disc.items():
+        summary[f"disc_onnx_ort_run_{k}"] = v
+    return summary, data
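
For reference, a simplified stand-in for the new export-then-validate path, assuming only public torch.onnx and onnxruntime APIs on a recent PyTorch; TinyModel is made up, and plain numpy replaces the project helpers make_feeds and max_diff:

import numpy as np
import onnxruntime
import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1

model, x = TinyModel(), torch.randn(2, 3)
expected = model(x)

# export with the dynamo-based exporter and serialize in memory,
# mirroring data["onnx_program"].model_proto.SerializeToString()
epo = torch.onnx.export(model, (x,), dynamo=True)
source = epo.model_proto.SerializeToString()

sess = onnxruntime.InferenceSession(source, providers=["CPUExecutionProvider"])
feeds = {sess.get_inputs()[0].name: x.numpy()}
got = sess.run(None, feeds)

# discrepancy check, a plain-numpy stand-in for max_diff(expected, got, flatten=True)
print("max abs diff:", float(np.abs(expected.detach().numpy() - got[0]).max()))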