Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions _doc/patches.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ and triggered by ``with torch_export_patches(patch_transformers=True)``.
This function does one class,
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`
does all known classes.
It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister`
It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`
or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`.
Here is the list of supported caches:

Expand All @@ -113,7 +113,10 @@ Here is the list of supported caches:

import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p

print("\n".join(sorted(t.__name__ for t in p.serialization_functions())))
print(
"\n".join(sorted(t.__name__ for t in p.serialization_functions(
patch_transformers=True, patch_diffusers=True)))
)

.. _l-control-flow-rewriting:

Expand Down
5 changes: 4 additions & 1 deletion _doc/status/patches_coverage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ The following code shows the list of serialized classes in transformers.

import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p

print('\n'.join(sorted(t.__name__ for t in p.serialization_functions())))
print(
'\n'.join(sorted(t.__name__ for t in p.serialization_functions(
patch_transformers=True, patch_diffusers=True
))))

Patched Classes
===============
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def register_class_serialization(
) -> bool:
"""
Registers a class.
It can be undone with :func:`unregister`.
It can be undone with
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`.

:param cls: class to register
:param f_flatten: see ``torch.utils._pytree.register_pytree_node``
Expand Down Expand Up @@ -77,7 +78,8 @@ def register_cache_serialization(
patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
) -> Dict[str, bool]:
"""
Registers many classes with :func:`register_class_serialization`.
Registers many classes with
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`.
Returns information needed to undo the registration.

:param patch_transformers: add serialization function for
Expand Down
Loading