From 034d8759cb6c8c61a616b5321393030aa27d0a8c Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 1 Jul 2025 10:44:56 +0200 Subject: [PATCH 1/2] doc --- _doc/patches.rst | 5 ++++- _doc/status/patches_coverage.rst | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/_doc/patches.rst b/_doc/patches.rst index bb589431..7ca82456 100644 --- a/_doc/patches.rst +++ b/_doc/patches.rst @@ -113,7 +113,10 @@ Here is the list of supported caches: import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p - print("\n".join(sorted(t.__name__ for t in p.serialization_functions()))) + print( + "\n".join(sorted(t.__name__ for t in p.serialization_functions( + patch_transformers=True, patch_diffusers=True))) + ) .. _l-control-flow-rewriting: diff --git a/_doc/status/patches_coverage.rst b/_doc/status/patches_coverage.rst index 1a3ac8b7..61d5b775 100644 --- a/_doc/status/patches_coverage.rst +++ b/_doc/status/patches_coverage.rst @@ -14,7 +14,10 @@ The following code shows the list of serialized classes in transformers. import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p - print('\n'.join(sorted(t.__name__ for t in p.serialization_functions()))) + print( + '\n'.join(sorted(t.__name__ for t in p.serialization_functions( + patch_transformers=True, patch_diffusers=True + )))) Patched Classes =============== From 834b896e5e10171c8ff5ab4bef6b5e3c331c3652 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 1 Jul 2025 10:53:15 +0200 Subject: [PATCH 2/2] documentation --- _doc/patches.rst | 2 +- .../torch_export_patches/onnx_export_serialization.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/_doc/patches.rst b/_doc/patches.rst index 7ca82456..a2dde062 100644 --- a/_doc/patches.rst +++ b/_doc/patches.rst @@ -104,7 +104,7 @@ and triggered by ``with torch_export_patches(patch_transformers=True)``. This function does one class, :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization` does all known classes. -It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister` +It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization` or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`. Here is the list of supported caches: diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index e669a129..4c4d2507 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -28,7 +28,8 @@ def register_class_serialization( ) -> bool: """ Registers a class. - It can be undone with :func:`unregister`. + It can be undone with + :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`. :param cls: class to register :param f_flatten: see ``torch.utils._pytree.register_pytree_node`` @@ -77,7 +78,8 @@ def register_cache_serialization( patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0 ) -> Dict[str, bool]: """ - Registers many classes with :func:`register_class_serialization`. + Registers many classes with + :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`. Returns information needed to undo the registration. :param patch_transformers: add serialization function for