File tree Expand file tree Collapse file tree 3 files changed +10
-7
lines changed
torch_export_patches/patches Expand file tree Collapse file tree 3 files changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -753,14 +753,18 @@ def change_dynamic_dimensions(self):
753753 :showcode:
754754
755755 import torch
756+ from onnx_diagnostic.helpers import string_type
756757 from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
757758
758- T3x1 = torch.rand((3, 1))
759+ T3x15 = torch.rand((3, 15))
760+ T3x20 = torch.rand((3, 20))
759761 T3x4 = torch.rand((3, 4))
760762 ds_batch = {0: "batch"}
761763 ds_batch_seq = {0: "batch", 1: "seq"}
762- kwargs = {"A": T3x4, "B": (T3x1, T3x1 )}
764+ kwargs = {"A": T3x4, "B": (T3x15, T3x20 )}
763765 ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
764- print(CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimension())
766+ new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions()
767+ print("before:", string_type(kwargs, with_shape=True))
768+ print("-after:", string_type(new_kwargs, with_shape=True))
765769 """
766770 return self ._generic_walker (self .ChangeDimensionProcessor ())
Original file line number Diff line number Diff line change @@ -47,7 +47,7 @@ def _patch_make_causal_mask(
4747if sys .version_info [:2 ] <= (3 , 11 ):
4848
4949 @dataclass
50- class kkpatched_AttentionMaskConverter :
50+ class patched_AttentionMaskConverter :
5151 """
5252 Patches
5353 ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
@@ -72,7 +72,7 @@ def _make_causal_mask(
7272else :
7373
7474 @dataclass
75- class kkpatched_AttentionMaskConverter :
75+ class patched_AttentionMaskConverter :
7676 """
7777 Patches
7878 ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
Original file line number Diff line number Diff line change @@ -277,8 +277,7 @@ def validate_model(
277277 if verbose :
278278 print (f"[validate_model] new inputs: { string_type (data ['inputs' ])} " )
279279 print (
280- f"[validate_model] new dynnamic_shapes: "
281- f"{ _ds_clean (data ['dynamic_shapes' ])} "
280+ f"[validate_model] new dynamic_hapes: { _ds_clean (data ['dynamic_shapes' ])} "
282281 )
283282
284283 if not empty (dtype ):
You can’t perform that action at this time.
0 commit comments