@@ -1,7 +1,7 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.cache_helper import make_dynamic_cache
-from ..helpers.config_helper import update_config, check_hasattr, _pick
+# from ..helpers.cache_helper import make_dynamic_cache
+from ..helpers.config_helper import update_config  # , check_hasattr, _pick

 __TASK__ = "MoE"

@@ -43,7 +43,7 @@ def get_inputs(
     **kwargs,  # unused
 ):
     """
-    Generates input for task ``text-generation``.
+    Generates input for task ``MoE``.

     :param model: model to get the missing information
     :param config: configuration used to generate the model
@@ -59,55 +59,7 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
     """
-    batch = torch.export.Dim("batch", min=1, max=1024)
-    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
-    images = "images"  # torch.export.Dim("images", min=1, max=4096)
-
-    shapes = {
-        "input_ids": {0: batch, 1: seq_length},
-        "attention_mask": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-        ],
-        "pixel_values": {0: batch, 1: images},
-        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
-    }
-    inputs = dict(
-        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
-        ),
-        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-            torch.int64
-        ),
-        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
-        .to(torch.int64)
-        .expand((batch_size, -1)),
-        past_key_values=make_dynamic_cache(
-            [
-                (
-                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
-                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
-                )
-                for i in range(num_hidden_layers)
-            ]
-        ),
-        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-            torch.int64
-        ),
-        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
-            torch.int64
-        ),
-    )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")


 def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
@@ -116,39 +68,6 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

     If the configuration is None, the function selects typical dimensions.
     """
-    if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_key_value_heads", "num_attention_heads"),
-            "intermediate_size",
-            "hidden_size",
-            "vision_config",
-            "audio_processor",
-        )
-        check_hasattr(config.vision_config, "image_size", "num_channels")
-    kwargs = dict(
-        batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
-        head_dim=(
-            16
-            if config is None
-            else getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        ),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=4 if config is None else config.num_hidden_layers,
-        num_key_value_heads=(
-            8
-            if config is None
-            else _pick(config, "num_key_value_heads", "num_attention_heads")
-        ),
-        intermediate_size=1024 if config is None else config.intermediate_size,
-        hidden_size=512 if config is None else config.hidden_size,
-        width=224 if config is None else config.vision_config.image_size,
-        height=224 if config is None else config.vision_config.image_size,
-        num_channels=3 if config is None else config.vision_config.num_channels,
+    raise NotImplementedError(
+        f"random_input_kwargs not yet implemented for task {__TASK__!r}."
     )
-    return kwargs, get_inputs
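
The removed `get_inputs` body follows this module family's dynamic-shape convention: each input tensor is paired with an `{axis: dimension}` mapping, where a dimension is either a `torch.export.Dim` or a bare string naming a symbolic axis. A minimal, self-contained sketch of that pattern against plain `torch.export.export` follows; the toy module and the concrete sizes are assumptions for illustration, not code from this commit.

    import torch

    class Tiny(torch.nn.Module):
        def forward(self, input_ids):
            # stand-in for a real forward pass
            return input_ids * 2

    # one Dim object per symbolic axis, as in the removed body
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)

    example = (torch.randint(0, 31999, (2, 30), dtype=torch.int64),)
    ep = torch.export.export(
        Tiny(),
        example,
        dynamic_shapes={"input_ids": {0: batch, 1: seq_length}},  # axis -> Dim
    )
    print(ep.graph_module.graph)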
0 commit comments
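Since both functions now raise, any harness enumerating tasks must be prepared for unsupported ones. A hypothetical caller (the names `model` and `config` are assumptions; the `(kwargs, get_inputs)` return pair matches the removed `return kwargs, get_inputs`) might do:

    try:
        kwargs, fn = random_input_kwargs(config)  # (kwargs, get_inputs) once implemented
        data = fn(model, config, **kwargs)
    except NotImplementedError as exc:
        print(f"skipping task: {exc}")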