@@ -55,20 +55,35 @@ class patched_AttentionMaskConverter:
 
     @staticmethod
     def _make_causal_mask(
-        input_ids_shape: torch.Size,
-        dtype: torch.dtype,
-        device: torch.device,
-        past_key_values_length: int = 0,
-        sliding_window: Optional[int] = None,
+        *args,
+        **kwargs,
+        # input_ids_shape: torch.Size,
+        # dtype: torch.dtype,
+        # device: torch.device,
+        # past_key_values_length: int = 0,
+        # sliding_window: Optional[int] = None,
     ):
-        """Patched method."""
-        return _patch_make_causal_mask(
-            input_ids_shape=input_ids_shape,
-            dtype=dtype,
-            device=device,
-            past_key_values_length=past_key_values_length,
-            sliding_window=sliding_window,
-        )
66+ """
67+ Patched method.
68+
69+ This static method may be called with ``AttentionMaskConverter._make_causal_mask``
70+ or ``self._make_causal_mask``. That changes this argument is receives.
71+ That should not matter but...
72+ """
+        if args:
+            index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
+            names = [
+                "input_ids_shape",
+                "dtype",
+                "device",
+                "past_key_values_length",
+                "sliding_window",
+            ]
+            for i, a in enumerate(args):
+                if i < index:
+                    continue
+                kwargs[names[i - index]] = a
+        return _patch_make_causal_mask(**kwargs)
 
 
 class patched_DynamicCache:
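A minimal, self-contained sketch of what the realignment buys. It assumes only the five keyword parameters visible in the diff; make_causal_mask, DummyConverter, and the stub body of _patch_make_causal_mask are illustrative, not the real patch.

import torch
from typing import Optional


def _patch_make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
):
    # Stub standing in for the real replacement function: it only echoes
    # the keyword arguments so the realignment can be verified.
    return (input_ids_shape, dtype, device, past_key_values_length, sliding_window)


def make_causal_mask(*args, **kwargs):
    # Same realignment logic as in the patch above.
    if args:
        # Skip the first positional argument when it is not a shape:
        # it is then assumed to be the converter instance (``self``).
        index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
        names = [
            "input_ids_shape",
            "dtype",
            "device",
            "past_key_values_length",
            "sliding_window",
        ]
        for i, a in enumerate(args):
            if i < index:
                continue
            kwargs[names[i - index]] = a
    return _patch_make_causal_mask(**kwargs)


class DummyConverter:
    # Illustrative stand-in for AttentionMaskConverter.
    pass


shape = torch.Size([2, 8])
# Staticmethod-style call: the shape is the first positional argument.
out1 = make_causal_mask(shape, torch.float32, torch.device("cpu"))
# Instance-first call: the extra leading argument is detected and skipped.
out2 = make_causal_mask(DummyConverter(), shape, torch.float32, torch.device("cpu"))
assert out1 == out2

Both call shapes end up with the same keyword arguments, which is why the patched signature can afford to take *args and **kwargs instead of the original explicit parameters.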