@@ -263,7 +263,7 @@ def validate_model(
263263 use_pretrained : bool = False ,
264264 optimization : Optional [str ] = None ,
265265 quiet : bool = False ,
266- patch : bool = False ,
266+ patch : Union [ bool , str , Dict [ str , bool ]] = False ,
267267 rewrite : bool = False ,
268268 stop_if_static : int = 1 ,
269269 dump_folder : Optional [str ] = None ,
@@ -301,8 +301,10 @@ def validate_model(
301301 :param optimization: optimization to apply to the exported model,
302302 depend on the the exporter
303303 :param quiet: if quiet, catches exception if any issue
304- :param patch: applies patches (``patch_transformers=True``) before exporting,
305- see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
304+ :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
305+ if True before exporting
306+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
307+ a string can be used to specify only one of them
306308 :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
307309 see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
308310 :param stop_if_static: stops if a dynamic dimension becomes static,
@@ -346,8 +348,24 @@ def validate_model(
346348 exported model returns the same outputs as the original one, otherwise,
347349 :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
348350 """
349- assert not rewrite or patch , (
350- f"rewrite={ rewrite } , patch={ patch } , patch must be True to enable rewriting, "
351+ if isinstance (patch , bool ):
352+ patch_kwargs = (
353+ dict (patch_transformers = True , patch_diffusers = True , patch = True )
354+ if patch
355+ else dict (patch = False )
356+ )
357+ elif isinstance (patch , str ):
358+ patch_kwargs = {"patch" : True , ** {p : True for p in patch .split ("," )}} # noqa: C420
359+ else :
360+ assert isinstance (patch , dict ), f"Unable to interpret patch={ patch !r} "
361+ patch_kwargs = patch .copy ()
362+ if "patch" not in patch_kwargs :
363+ if any (patch_kwargs .values ()):
364+ patch_kwargs ["patch" ] = True
365+
366+ assert not rewrite or patch_kwargs .get ("patch" , False ), (
367+ f"rewrite={ rewrite } , patch={ patch } , patch_kwargs={ patch_kwargs } "
368+ f"patch must be True to enable rewriting, "
351369 f"if --no-patch was specified on the command line, --no-rewrite must be added."
352370 )
353371 summary = version_summary ()
@@ -362,6 +380,7 @@ def validate_model(
362380 version_optimization = optimization or "" ,
363381 version_quiet = str (quiet ),
364382 version_patch = str (patch ),
383+ version_patch_kwargs = str (patch_kwargs ).replace (" " , "" ),
365384 version_rewrite = str (rewrite ),
366385 version_dump_folder = dump_folder or "" ,
367386 version_drop_inputs = str (list (drop_inputs or "" )),
@@ -397,7 +416,7 @@ def validate_model(
397416 print (f"[validate_model] model_options={ model_options !r} " )
398417 print (f"[validate_model] get dummy inputs with input_options={ input_options } ..." )
399418 print (
400- f"[validate_model] rewrite={ rewrite } , patch= { patch } , "
419+ f"[validate_model] rewrite={ rewrite } , patch_kwargs= { patch_kwargs } , "
401420 f"stop_if_static={ stop_if_static } "
402421 )
403422 print (f"[validate_model] exporter={ exporter !r} , optimization={ optimization !r} " )
@@ -573,18 +592,18 @@ def validate_model(
573592 f"[validate_model] -- export the model with { exporter !r} , "
574593 f"optimization={ optimization !r} "
575594 )
576- if patch :
595+ if patch_kwargs :
577596 if verbose :
578597 print (
579598 f"[validate_model] applies patches before exporting "
580599 f"stop_if_static={ stop_if_static } "
581600 )
582601 with torch_export_patches ( # type: ignore
583- patch_transformers = True ,
584602 stop_if_static = stop_if_static ,
585603 verbose = max (0 , verbose - 1 ),
586604 rewrite = data .get ("rewrite" , None ),
587605 dump_rewriting = (os .path .join (dump_folder , "rewrite" ) if dump_folder else None ),
606+ ** patch_kwargs , # type: ignore[arg-type]
588607 ) as modificator :
589608 data ["inputs_export" ] = modificator (data ["inputs" ]) # type: ignore
590609
0 commit comments