@@ -52,7 +52,9 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
     return name, to_patch


-def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
+def patch_module_or_classes(
+    mod, verbose: int = 0, patch_details: Optional[PatchDetails] = None
+) -> Dict[type, Dict[type, Callable]]:
     """
     Applies all patches defined in classes prefixed by ``patched_``
     ``cls._PATCHED_CLASS_`` defines the class to patch,
@@ -62,13 +64,16 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call

     :param mod: module or list of classes to patch
     :param verbose: verbosity
+    :param patch_details: used to store information about the applied patches
     :return: patch info
     """
     if isinstance(mod, list):
         to_patch = mod
         name = "list"
+        list_name = "auto:list"
     else:
         name, to_patch = get_patches(mod, verbose)
+        list_name = f"auto:{mod.__name__}"

     res = {}
     for cls in to_patch:
@@ -81,6 +86,8 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
             if verbose:
                 print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
             setattr(original, f.__name__, cls["patch"])
+            if patch_details:
+                patch_details.append(list_name, original, f)
             continue

         original = cls._PATCHED_CLASS_
@@ -90,6 +97,11 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call

         keep = {n: getattr(original, n, None) for n in methods}
         for n in methods:
+            if patch_details:
+                if hasattr(original, n):
+                    patch_details.append(list_name, getattr(original, n), getattr(cls, n))
+                else:
+                    patch_details.append(list_name, f"{original.__name__} {n}", getattr(cls, n))
             setattr(original, n, getattr(cls, n))
         res[cls] = keep

@@ -343,7 +355,7 @@ def torch_export_patches(
         sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
         if patch_details:
             patch_details.append(
-                ("sympy", f_sympy_name, sympy.core.numbers.IntegerConstant.name)
+                "sympy", f_sympy_name, sympy.core.numbers.IntegerConstant.name
             )

     ###############
@@ -386,14 +398,14 @@ def torch_export_patches(
         f_infer_size = torch._subclasses.fake_impls.infer_size
         torch._subclasses.fake_impls.infer_size = patched_infer_size
         if patch_details:
-            patch_details.append(("torch", f_infer_size, patched_infer_size))
+            patch_details.append("torch", f_infer_size, patched_infer_size)

         # torch._refs._broadcast_shapes
         f__broadcast_shapes = torch._refs._broadcast_shapes
         torch._refs._broadcast_shapes = patched__broadcast_shapes
         torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
         if patch_details:
-            patch_details.append(("torch", f__broadcast_shapes, patched__broadcast_shapes))
+            patch_details.append("torch", f__broadcast_shapes, patched__broadcast_shapes)

         # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
         f___constrain_user_specified_dimhint_range = (
@@ -404,11 +416,9 @@ def torch_export_patches(
         )
         if patch_details:
             patch_details.append(
-                (
-                    "torch",
-                    f___constrain_user_specified_dimhint_range,
-                    patched__constrain_user_specified_dimhint_range,
-                )
+                "torch",
+                f___constrain_user_specified_dimhint_range,
+                patched__constrain_user_specified_dimhint_range,
             )

         # torch._prims._broadcast_in_dim_meta
@@ -422,20 +432,20 @@ def torch_export_patches(
         torch._prims._broadcast_in_dim_meta = _patched_dim_f
         torch._prims.broadcast_in_dim = _patched_dim_f
         if patch_details:
-            patch_details.append(("torch", f_broadcast_in_dim, _patched_dim_f))
+            patch_details.append("torch", f_broadcast_in_dim, _patched_dim_f)

         # torch._refs._maybe_broadcast
         f__maybe_broadcast = torch._refs._maybe_broadcast
         torch._refs._maybe_broadcast = patched__maybe_broadcast
         if patch_details:
-            patch_details.append(("torch", f__maybe_broadcast, patched__maybe_broadcast))
+            patch_details.append("torch", f__maybe_broadcast, patched__maybe_broadcast)

         # ShapeEnv
         f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
         ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
         if patch_details:
             patch_details.append(
-                ("torch", f_shape_env__evaluate_expr, patched_ShapeEnv._evaluate_expr)
+                "torch", f_shape_env__evaluate_expr, patched_ShapeEnv._evaluate_expr
             )

         # torch._export.non_strict_utils.produce_guards_and_solve_constraints
@@ -470,7 +480,7 @@ def torch_export_patches(
         ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
         if patch_details:
             patch_details.append(
-                ("torch", f_shape_env__set_replacement, patched_ShapeEnv._set_replacement)
+                "torch", f_shape_env__set_replacement, patched_ShapeEnv._set_replacement
             )

         if verbose:
@@ -479,7 +489,7 @@ def torch_export_patches(
         ShapeEnv._log_guard = patched_ShapeEnv._log_guard
         if patch_details:
             patch_details.append(
-                ("torch", f_shape_env__log_guard, patched_ShapeEnv._log_guard)
+                "torch", f_shape_env__log_guard, patched_ShapeEnv._log_guard
             )

         if stop_if_static > 1:
@@ -489,7 +499,7 @@ def torch_export_patches(
             ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
             if patch_details:
                 patch_details.append(
-                    ("torch", f_shape_env__check_frozen, ShapeEnv._check_frozen)
+                    "torch", f_shape_env__check_frozen, ShapeEnv._check_frozen
                 )

     ####################
@@ -537,11 +547,9 @@ def torch_export_patches(
         masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
         if patch_details:
             patch_details.append(
-                (
-                    "transformers",
-                    f_transformers__vmap_for_bhqkv,
-                    patch_transformers_list.patched__vmap_for_bhqkv,
-                )
+                "transformers",
+                f_transformers__vmap_for_bhqkv,
+                patch_transformers_list.patched__vmap_for_bhqkv,
             )

         if verbose:
@@ -555,11 +563,9 @@ def torch_export_patches(
         )
         if patch_details:
             patch_details.append(
-                (
-                    "transformers",
-                    f_transformers_sdpa_mask_recent_torch,
-                    patch_transformers_list.patched_sdpa_mask_recent_torch,
-                )
+                "transformers",
+                f_transformers_sdpa_mask_recent_torch,
+                patch_transformers_list.patched_sdpa_mask_recent_torch,
             )
         if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
             if verbose:
@@ -573,11 +579,9 @@ def torch_export_patches(
             )
             if patch_details:
                 patch_details.append(
-                    (
-                        "transformers",
-                        f_transformers_sdpa_mask,
-                        patch_transformers_list.patched_sdpa_mask_recent_torch,
-                    )
+                    "transformers",
+                    f_transformers_sdpa_mask,
+                    patch_transformers_list.patched_sdpa_mask_recent_torch,
                 )
         else:
             f_transformers_sdpa_mask = None
@@ -596,11 +600,9 @@ def torch_export_patches(
         masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
         if patch_details:
             patch_details.append(
-                (
-                    "transformers",
-                    f_transformers_eager_mask,
-                    patch_transformers_list.patched_eager_mask,
-                )
+                "transformers",
+                f_transformers_eager_mask,
+                patch_transformers_list.patched_eager_mask,
             )
         if (
             "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
@@ -662,11 +664,9 @@ def torch_export_patches(
         )
         if patch_details:
             patch_details.append(
-                (
-                    "transformers",
-                    f_sdpa_attention_forward,
-                    patch_transformers_list.patched_sdpa_attention_forward,
-                )
+                "transformers",
+                f_sdpa_attention_forward,
+                patch_transformers_list.patched_sdpa_attention_forward,
             )

     if custom_patches:
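
The change above drops the single tuple argument: PatchDetails.append is now called with three positional values (a source label such as "torch", "transformers" or "auto:<module>", the original callable, and its replacement), and patch_module_or_classes forwards the recorder it receives. Below is a minimal sketch of a recorder with that call shape and of how it might be passed in; the PatchDetails internals shown here are assumptions for illustration, only the append(name, original, patched) signature and the patch_details parameter are taken from the diff.

# Minimal sketch (assumed implementation): a PatchDetails-like recorder whose
# append() matches the three-argument call shape used throughout the diff.
from typing import Any, List, Tuple


class PatchDetails:  # assumed stand-in; the real class lives in the patched package
    def __init__(self) -> None:
        # each entry: (source label, original callable or its name, replacement callable)
        self.patched: List[Tuple[str, Any, Any]] = []

    def append(self, name: str, original: Any, patched: Any) -> None:
        # name: "sympy", "torch", "transformers" or an "auto:<module>" label
        # original: the callable (or a textual name when the attribute did not exist)
        # patched: the replacement installed by the patching code
        self.patched.append((name, original, patched))


# Assumed usage: collect what patch_module_or_classes applied.
# details = PatchDetails()
# patch_module_or_classes(some_module, verbose=1, patch_details=details)
# for name, original, patched in details.patched:
#     print(name, getattr(original, "__name__", original), "->", patched.__name__)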