File tree Expand file tree Collapse file tree 2 files changed +15
-5
lines changed
Expand file tree Collapse file tree 2 files changed +15
-5
lines changed Original file line number Diff line number Diff line change @@ -254,6 +254,17 @@ def torch_export_patches(
254254 may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
255255 It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
256256 """
257+ if verbose :
258+ print (f"[torch_export_patches] patch_sympy={ patch_sympy !r} " )
259+ print (f" . patch_torch={ patch_torch !r} " )
260+ print (f" . patch_transformers={ patch_transformers !r} " )
261+ print (f" . patch_diffusers={ patch_diffusers !r} " )
262+ print (f" . catch_constraints={ catch_constraints !r} " )
263+ print (f" . stop_if_static={ stop_if_static !r} " )
264+ print (f" . patch={ patch !r} " )
265+ print (f" . custom_patches={ custom_patches !r} " )
266+ print (f"[torch_export_patches] dump_rewriting={ dump_rewriting !r} " )
267+
257268 if rewrite :
258269 from .patch_module import torch_export_rewrite
259270
Original file line number Diff line number Diff line change @@ -394,12 +394,9 @@ def validate_model(
394394 same_as_pretrained = same_as_pretrained ,
395395 use_pretrained = use_pretrained ,
396396 )
397+ default_patch = dict (patch_transformers = True , patch_diffusers = True , patch = True )
397398 if isinstance (patch , bool ):
398- patch_kwargs = (
399- dict (patch_transformers = True , patch_diffusers = True , patch = True )
400- if patch
401- else dict (patch = False )
402- )
399+ patch_kwargs = default_patch if patch else dict (patch = False )
403400 elif isinstance (patch , str ):
404401 patch_kwargs = {"patch" : True , ** {p : True for p in patch .split ("," )}} # noqa: C420
405402 else :
@@ -408,6 +405,8 @@ def validate_model(
408405 if "patch" not in patch_kwargs :
409406 if any (patch_kwargs .values ()):
410407 patch_kwargs ["patch" ] = True
408+ elif len (patch ) == 1 and patch .get ("patch" , False ):
409+ patch_kwargs .update (default_patch )
411410
412411 assert not rewrite or patch_kwargs .get ("patch" , False ), (
413412 f"rewrite={ rewrite } , patch={ patch } , patch_kwargs={ patch_kwargs } "
You can’t perform that action at this time.
0 commit comments