Skip to content

Commit bd3b352

Browse files
authored
fix documentation (#121)
1 parent 49a41ef commit bd3b352

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.6.1
55
+++++
66

7+
* :pr:`120`: enables TorchOnnxEvaluator in command line ``python -m onnx_diagnostic validate ...``
78
* :pr:`115`, :pr:`116`, :pr:`117`, :pr:`118`, :pr:`119`:
89
first steps for TorchOnnxEvaluator
910
* :pr:`114`: extends the list of known rewritings

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class TorchOnnxEvaluator:
5858
5959
The class is not multithreaded. `runtime_info` gets updated
6060
by the the class. The list of available kernels is returned by function
61-
:func:`get_kernels`.
61+
:func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
6262
"""
6363

6464
def __init__(

onnx_diagnostic/torch_export_patches/patch_module_helper.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
6767
:showcode:
6868
6969
import pprint
70-
from onnx_diagnostic.torch_export_patches.patch_model_helper import (
71-
known_transformers_rewritings,
70+
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
71+
known_transformers_rewritings_clamp_float16,
7272
)
7373
74-
pprint.pprint(known_transformers_rewritings())
74+
pprint.pprint(known_transformers_rewritings_clamp_float16())
7575
"""
7676
_alias = {
7777
"AutoformerEncoder": "AutoformerEncoderLayer",
@@ -130,13 +130,13 @@ def rewritings_transformers_clamp_float16(cls_name) -> List[type]:
130130
:showcode:
131131
132132
import pprint
133-
from onnx_diagnostic.torch_export_patches.patch_model_helper import (
133+
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
134134
_rewrite_forward_clamp_float16,
135135
)
136136
137-
pprint.pprint(_rewrite_forward_clamp_float16()
137+
pprint.pprint(_rewrite_forward_clamp_float16())
138138
139-
Function :func:`known_transformers_rewritings` collects
139+
Function `_rewrite_forward_clamp_float16` collects
140140
all model classes using those layers.
141141
"""
142142
_known = _rewrite_forward_clamp_float16()
@@ -161,7 +161,7 @@ def _add(f):
161161
def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
162162
"""
163163
Returns a known list of classes mapped to a known rewritings
164-
because of control flow. See :func:`registered_transformers_rewritings`.
164+
because of control flow. See :func:`known_transformers_rewritings_clamp_float16`.
165165
166166
:param cls_name: name of the class
167167
:return: a list of rewriting

0 commit comments

Comments
 (0)