
Commit 56e0b4c

fix _make_causal

1 parent 1ed180a

2 files changed (+30, -14 lines)

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 2 additions & 1 deletion
@@ -102,7 +102,8 @@ def test_get_untrained_model_with_inputs_clip_vit(self):
         mid = "openai/clip-vit-base-patch16"
         data = get_untrained_model_with_inputs(mid, verbose=1)
         model, inputs = data["model"], data["inputs"]
-        model(**inputs)
+        with bypass_export_some_errors(patch_transformers=True):
+            model(**inputs)
         # different expected value for different version of transformers
         self.assertIn((data["size"], data["n_weights"]), [(188872708, 47218177)])
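For reference, the pattern the test now uses can be reproduced outside the test suite. A minimal sketch, assuming `bypass_export_some_errors` is importable from `onnx_diagnostic.torch_export_patches` and `get_untrained_model_with_inputs` from `onnx_diagnostic.torch_models.hghub` (import paths are not shown in this diff):

from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

# Build an untrained copy of the model together with example inputs.
data = get_untrained_model_with_inputs("openai/clip-vit-base-patch16", verbose=1)
model, inputs = data["model"], data["inputs"]

# patch_transformers=True installs the patches defined in patch_transformers.py,
# including the patched _make_causal_mask fixed by this commit.
with bypass_export_some_errors(patch_transformers=True):
    model(**inputs)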

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 28 additions & 13 deletions
@@ -55,20 +55,35 @@ class patched_AttentionMaskConverter:
 
     @staticmethod
     def _make_causal_mask(
-        input_ids_shape: torch.Size,
-        dtype: torch.dtype,
-        device: torch.device,
-        past_key_values_length: int = 0,
-        sliding_window: Optional[int] = None,
+        *args,
+        **kwargs,
+        # input_ids_shape: torch.Size,
+        # dtype: torch.dtype,
+        # device: torch.device,
+        # past_key_values_length: int = 0,
+        # sliding_window: Optional[int] = None,
     ):
-        """Patched method."""
-        return _patch_make_causal_mask(
-            input_ids_shape=input_ids_shape,
-            dtype=dtype,
-            device=device,
-            past_key_values_length=past_key_values_length,
-            sliding_window=sliding_window,
-        )
+        """
+        Patched method.
+
+        This static method may be called as ``AttentionMaskConverter._make_causal_mask``
+        or as ``self._make_causal_mask``, which changes the arguments it receives.
+        That should not matter, but both call styles have to be handled.
+        """
+        if args:
+            index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
+            names = [
+                "input_ids_shape",
+                "dtype",
+                "device",
+                "past_key_values_length",
+                "sliding_window",
+            ]
+            for i, a in enumerate(args):
+                if i < index:
+                    continue
+                kwargs[names[i - index]] = a
+        return _patch_make_causal_mask(**kwargs)
 
 
 class patched_DynamicCache:
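The remapping exists because, after patching, the function can be reached both as a plain static call (the first positional argument is `input_ids_shape`) and through an instance (the first positional argument is the converter itself). A standalone sketch of the same dispatch idea, with a hypothetical `target` standing in for `_patch_make_causal_mask`:

from typing import Optional
import torch

def target(input_ids_shape, dtype, device, past_key_values_length=0,
           sliding_window: Optional[int] = None):
    # Stand-in for _patch_make_causal_mask: just echo what it received.
    return input_ids_shape, dtype, device, past_key_values_length, sliding_window

def dispatch(*args, **kwargs):
    if args:
        # If the first positional argument is not a shape, the call came
        # through an instance and args[0] is that instance: skip it.
        index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
        names = ["input_ids_shape", "dtype", "device",
                 "past_key_values_length", "sliding_window"]
        for i, a in enumerate(args):
            if i < index:
                continue
            kwargs[names[i - index]] = a
    return target(**kwargs)

class Converter:
    pass

# Called as a plain function: positional arguments start with the shape.
print(dispatch(torch.Size([2, 8]), torch.float32, torch.device("cpu"), 0))
# Called "through an instance": args[0] is the instance and is skipped.
print(dispatch(Converter(), torch.Size([2, 8]), torch.float32, torch.device("cpu"), 0))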
