 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache

 __TASK__ = "feature-extraction"


 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
-    kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers, 2),
-        num_attention_heads=min(config.num_attention_heads, 4),
-    )
+    check_hasattr(config, "num_hidden_layers")
+    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, 2))
     update_config(config, kwargs)
     return kwargs

@@ -22,6 +20,12 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
+    sequence_length2: int = 3,
+    decoder_attention_heads: Optional[int] = None,
+    encoder_attention_heads: Optional[int] = None,
+    encoder_ffn_dim: Optional[int] = None,
+    decoder_ffn_dim: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
     add_second_input: int = 1,
     **kwargs,  # unused
 ):
@@ -50,6 +54,66 @@ def get_inputs(
         ),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
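+    # When all cache dimensions are provided, also populate past_key_values
+    # with an encoder-decoder cache built from random tensors.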
+    if (
+        encoder_attention_heads
+        and decoder_attention_heads
+        and encoder_ffn_dim
+        and decoder_ffn_dim
+        and num_hidden_layers
+    ):
+        inputs["past_key_values"] = make_encoder_decoder_cache(
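+            # first cache: one random (key, value) pair per layer, shaped
+            # (batch_size, encoder_attention_heads, sequence_length, encoder_ffn_dim)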
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                    )
+                    for _ in range(num_hidden_layers)
+                ]
+            ),
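+            # second cache: same layout, but using the decoder parameters and
+            # sequence_length2 on the sequence axis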
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                    )
+                    for _ in range(num_hidden_layers)
+                ]
+            ),
+        )
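+        # names of the dynamic dimensions used below for the cache axes (axis 2)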
+        cache_length = "cache_length_key"
+        cache_length2 = "cache_length_val"
+        shapes["past_key_values"] = [  # type: ignore[assignment]
+            [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+            [
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+            ],
+        ]
+
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
@@ -61,6 +125,12 @@ def get_inputs(
             batch_size=batch_size + 1,
             sequence_length=sequence_length + add_second_input,
             dummy_max_token_id=dummy_max_token_id,
+            sequence_length2=sequence_length2,
+            decoder_attention_heads=decoder_attention_heads,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_ffn_dim=encoder_ffn_dim,
+            decoder_ffn_dim=decoder_ffn_dim,
+            num_hidden_layers=num_hidden_layers,
             add_second_input=0,
             **kwargs,
         )["inputs"]
@@ -80,4 +150,15 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         sequence_length=30,
         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
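+    # forward the cache-related dimensions when the configuration defines them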
+    for att in [
+        "decoder_attention_heads",
+        "encoder_attention_heads",
+        "encoder_ffn_dim",
+        "decoder_ffn_dim",
+        "num_hidden_layers",
+    ]:
+        if hasattr(config, att):
+            kwargs[att] = getattr(config, att)
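+    # override the feed-forward dimensions with a small value to keep the dummy cache light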
+    kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
     return kwargs, get_inputs