Skip to content

Commit 6e1c7e6

Browse files
committed
fix documentation
1 parent ab0f3de commit 6e1c7e6

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -753,14 +753,18 @@ def change_dynamic_dimensions(self):
753753
:showcode:
754754
755755
import torch
756+
from onnx_diagnostic.helpers import string_type
756757
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
757758
758-
T3x1 = torch.rand((3, 1))
759+
T3x15 = torch.rand((3, 15))
760+
T3x20 = torch.rand((3, 20))
759761
T3x4 = torch.rand((3, 4))
760762
ds_batch = {0: "batch"}
761763
ds_batch_seq = {0: "batch", 1: "seq"}
762-
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
764+
kwargs = {"A": T3x4, "B": (T3x15, T3x20)}
763765
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
764-
print(CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimension())
766+
new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions()
767+
print("before:", string_type(kwargs, with_shape=True))
768+
print("-after:", string_type(new_kwargs, with_shape=True))
765769
"""
766770
return self._generic_walker(self.ChangeDimensionProcessor())

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _patch_make_causal_mask(
4747
if sys.version_info[:2] <= (3, 11):
4848

4949
@dataclass
50-
class kkpatched_AttentionMaskConverter:
50+
class patched_AttentionMaskConverter:
5151
"""
5252
Patches
5353
``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
@@ -72,7 +72,7 @@ def _make_causal_mask(
7272
else:
7373

7474
@dataclass
75-
class kkpatched_AttentionMaskConverter:
75+
class patched_AttentionMaskConverter:
7676
"""
7777
Patches
7878
``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,7 @@ def validate_model(
277277
if verbose:
278278
print(f"[validate_model] new inputs: {string_type(data['inputs'])}")
279279
print(
280-
f"[validate_model] new dynnamic_shapes: "
281-
f"{_ds_clean(data['dynamic_shapes'])}"
280+
f"[validate_model] new dynamic_hapes: {_ds_clean(data['dynamic_shapes'])}"
282281
)
283282

284283
if not empty(dtype):

0 commit comments

Comments
 (0)