diff --git a/_doc/patches.rst b/_doc/patches.rst index bb589431..a2dde062 100644 --- a/_doc/patches.rst +++ b/_doc/patches.rst @@ -104,7 +104,7 @@ and triggered by ``with torch_export_patches(patch_transformers=True)``. This function does one class, :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization` does all known classes. -It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister` +It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization` or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`. Here is the list of supported caches: @@ -113,7 +113,10 @@ Here is the list of supported caches: import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p - print("\n".join(sorted(t.__name__ for t in p.serialization_functions()))) + print( + "\n".join(sorted(t.__name__ for t in p.serialization_functions( + patch_transformers=True, patch_diffusers=True))) + ) .. _l-control-flow-rewriting: diff --git a/_doc/status/patches_coverage.rst b/_doc/status/patches_coverage.rst index 1a3ac8b7..61d5b775 100644 --- a/_doc/status/patches_coverage.rst +++ b/_doc/status/patches_coverage.rst @@ -14,7 +14,10 @@ The following code shows the list of serialized classes in transformers. import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p - print('\n'.join(sorted(t.__name__ for t in p.serialization_functions()))) + print( + '\n'.join(sorted(t.__name__ for t in p.serialization_functions( + patch_transformers=True, patch_diffusers=True + )))) Patched Classes =============== diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index e669a129..4c4d2507 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -28,7 +28,8 @@ def register_class_serialization( ) -> bool: """ Registers a class. - It can be undone with :func:`unregister`. + It can be undone with + :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`. :param cls: class to register :param f_flatten: see ``torch.utils._pytree.register_pytree_node`` @@ -77,7 +78,8 @@ def register_cache_serialization( patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0 ) -> Dict[str, bool]: """ - Registers many classes with :func:`register_class_serialization`. + Registers many classes with + :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`. Returns information needed to undo the registration. :param patch_transformers: add serialization function for