@@ -21,9 +21,11 @@ def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
     dummy_max_token_id: int,
-    num_key_value_heads: int,
+    num_key_value_heads_encoder: int,
+    num_key_value_heads_decoder: int,
     num_hidden_layers: int,
-    head_dim: int,
+    head_dim_encoder: int,
+    head_dim_decoder: int,
     encoder_dim: int,
     batch_size: int = 2,
     sequence_length: int = 30,
@@ -36,7 +38,10 @@ def get_inputs(

     :param model: model to get the missing information
     :param config: configuration used to generate the model
-    :param head_dim: last dimension of the cache
+    :param head_dim_encoder: last dimension of the cache for the encoder
+    :param head_dim_decoder: last dimension of the cache for the decoder
+    :param num_key_value_heads_encoder: number of heads for the encoder
+    :param num_key_value_heads_decoder: number of heads for the decoder
     :param dummy_max_token_id: dummy max token id
     :param batch_size: batch size
     :param encoder_dim: last dimension of encoder_last_hidden_state
@@ -83,6 +88,7 @@ def get_inputs(
         # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
         # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
     }
+
     inputs = dict(
         input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
             torch.int64
@@ -99,10 +105,16 @@ def get_inputs(
                 [
                     (
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size,
+                            num_key_value_heads_encoder,
+                            sequence_length,
+                            head_dim_encoder,
                         ),
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size,
+                            num_key_value_heads_encoder,
+                            sequence_length,
+                            head_dim_encoder,
                         ),
                     )
                     for i in range(num_hidden_layers)
@@ -112,10 +124,16 @@ def get_inputs(
                 [
                     (
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length2, head_dim
+                            batch_size,
+                            num_key_value_heads_decoder,
+                            sequence_length2,
+                            head_dim_decoder,
                         ),
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length2, head_dim
+                            batch_size,
+                            num_key_value_heads_decoder,
+                            sequence_length2,
+                            head_dim_decoder,
                         ),
                     )
                     for i in range(num_hidden_layers)
@@ -132,9 +150,11 @@ def get_inputs(
         model=model,
         config=config,
         dummy_max_token_id=dummy_max_token_id,
-        num_key_value_heads=num_key_value_heads,
+        num_key_value_heads_encoder=num_key_value_heads_encoder,
+        num_key_value_heads_decoder=num_key_value_heads_decoder,
         num_hidden_layers=num_hidden_layers,
-        head_dim=head_dim,
+        head_dim_encoder=head_dim_encoder,
+        head_dim_decoder=head_dim_decoder,
         encoder_dim=encoder_dim,
         batch_size=batch_size + 1,
         sequence_length=sequence_length + 1,
@@ -173,20 +193,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         batch_size=2,
         sequence_length=30,
         sequence_length2=3,
-        head_dim=16 if config is None else (config.d_kv if hasattr(config, "d_kv") else 1),
+        head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
+        head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
         dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
         num_hidden_layers=(
             8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
         ),
-        num_key_value_heads=(
+        num_key_value_heads_encoder=(
+            16
+            if config is None
+            else _pick(
+                config,
+                "encoder_attention_heads",
+                "num_key_value_heads",
+                "num_heads",
+            )
+        ),
+        num_key_value_heads_decoder=(
             16
             if config is None
             else _pick(
                 config,
+                "decoder_attention_heads",
                 "num_key_value_heads",
                 "num_heads",
-                (sum, "encoder_attention_heads", "decoder_attention_heads"),
-                # exceptions=exceptions,
             )
         ),
         encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
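With this change, random_input_kwargs resolves the encoder and decoder head counts separately through _pick instead of summing encoder_attention_heads and decoder_attention_heads. The _pick helper itself is not shown in the diff; below is a plausible sketch of the plain attribute-name form it is called with here (the real helper apparently also accepts the (callable, name, ...) tuple form used by the removed line), so treat it as an assumption rather than the module's actual implementation:

    from types import SimpleNamespace

    def _pick(config, *names):
        # Hypothetical re-implementation covering only the plain-string form
        # used by the new code: return the first attribute found on config.
        for name in names:
            if hasattr(config, name):
                return getattr(config, name)
        raise AttributeError(f"none of {names} found on {type(config).__name__}")

    # A BART-like config exposes encoder_attention_heads, a T5-like config
    # exposes num_heads; both resolve without sharing an attribute name.
    print(_pick(SimpleNamespace(encoder_attention_heads=16),
                "encoder_attention_heads", "num_heads"))  # 16
    print(_pick(SimpleNamespace(num_heads=8),
                "encoder_attention_heads", "num_heads"))  # 8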