@@ -349,22 +349,23 @@ def validate_model(
349349 :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
350350 """
351351 if isinstance (patch , bool ):
352- patch = (
352+ patch_kwargs = (
353353 dict (patch_transformers = True , patch_diffusers = True , patch = True )
354354 if patch
355355 else dict (patch = False )
356356 )
357357 elif isinstance (patch , str ):
358- patch = {"patch" : True , ** {p : True for p in patch .split ("," )}} # noqa: C420
358+ patch_kwargs = {"patch" : True , ** {p : True for p in patch .split ("," )}} # noqa: C420
359359 else :
360360 assert isinstance (patch , dict ), f"Unable to interpret patch={ patch !r} "
361- patch = patch .copy ()
362- if "patch" not in patch :
363- if any (patch .values ):
364- patch ["patch" ] = True
365-
366- assert not rewrite or patch , (
367- f"rewrite={ rewrite } , patch={ patch } , patch must be True to enable rewriting, "
361+ patch_kwargs = patch .copy ()
362+ if "patch" not in patch_kwargs :
363+ if any (patch_kwargs .values ()):
364+ patch_kwargs ["patch" ] = True
365+
366+ assert not rewrite or patch_kwargs .get ("patch" , False ), (
367+ f"rewrite={ rewrite } , patch={ patch } , patch_kwargs={ patch_kwargs } "
368+ f"patch must be True to enable rewriting, "
368369 f"if --no-patch was specified on the command line, --no-rewrite must be added."
369370 )
370371 summary = version_summary ()
@@ -379,6 +380,7 @@ def validate_model(
379380 version_optimization = optimization or "" ,
380381 version_quiet = str (quiet ),
381382 version_patch = str (patch ),
383+ version_patch_kwargs = str (patch_kwargs ).replace (" " , "" ),
382384 version_rewrite = str (rewrite ),
383385 version_dump_folder = dump_folder or "" ,
384386 version_drop_inputs = str (list (drop_inputs or "" )),
@@ -414,7 +416,7 @@ def validate_model(
414416 print (f"[validate_model] model_options={ model_options !r} " )
415417 print (f"[validate_model] get dummy inputs with input_options={ input_options } ..." )
416418 print (
417- f"[validate_model] rewrite={ rewrite } , patch= { patch } , "
419+ f"[validate_model] rewrite={ rewrite } , patch_kwargs= { patch_kwargs } , "
418420 f"stop_if_static={ stop_if_static } "
419421 )
420422 print (f"[validate_model] exporter={ exporter !r} , optimization={ optimization !r} " )
@@ -590,7 +592,7 @@ def validate_model(
590592 f"[validate_model] -- export the model with { exporter !r} , "
591593 f"optimization={ optimization !r} "
592594 )
593- if patch :
595+ if patch_kwargs :
594596 if verbose :
595597 print (
596598 f"[validate_model] applies patches before exporting "
@@ -601,7 +603,7 @@ def validate_model(
601603 verbose = max (0 , verbose - 1 ),
602604 rewrite = data .get ("rewrite" , None ),
603605 dump_rewriting = (os .path .join (dump_folder , "rewrite" ) if dump_folder else None ),
604- ** patch ,
606+ ** patch_kwargs ,
605607 ) as modificator :
606608 data ["inputs_export" ] = modificator (data ["inputs" ]) # type: ignore
607609
0 commit comments