@@ -102,31 +102,21 @@ def get_untrained_model_with_inputs(
     arch = archs[0]
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
-    cls = config_class_from_architecture(arch, exc=False)
-    if cls is None:
-        if verbose:
-            print(
-                "[get_untrained_model_with_inputs] no found config name in the code, loads it"
-            )
-        config = get_pretrained_config(model_id)
-        cls = config.__class__
+    config = get_pretrained_config(model_id)
+    if verbose:
+        print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
+    task = task_from_arch(arch)
     if verbose:
-        print(f"[get_untrained_model_with_inputs] cls={cls.__name__!r}")
+        print(f"[get_untrained_model_with_inputs] task={task!r}")
 
-    # model creation
-    kwargs: Dict[str, Any] = dict(
-        num_hidden_layers=1,
-    )
+    # model kwargs
     if dynamic_rope is not None:
-        kwargs["rope_scaling"] = (
+        config.rope_scaling = (
             {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None
         )
     if model_kwargs:
-        kwargs.update(model_kwargs)
-    config = cls(**kwargs)
-    task = task_from_arch(arch)
-    if verbose:
-        print(f"[get_untrained_model_with_inputs] task={task!r}")
+        for k, v in model_kwargs.items():
+            setattr(config, k, v)
 
     if task == "text-generation":
         kwargs = dict(
@@ -136,7 +126,7 @@ def get_untrained_model_with_inputs(
             head_dim=getattr(
                 config, "head_dim", config.hidden_size // config.num_attention_heads
             ),
-            max_token_id=config.vocab_size - 1,
+            dummy_max_token_id=config.vocab_size - 1,
             num_hidden_layers=min(config.num_hidden_layers, 2),
             num_key_value_heads=(
                 config.num_key_value_heads
@@ -154,25 +144,29 @@ def get_untrained_model_with_inputs(
                 else config.hidden_size
             ),
         )
-        if inputs_kwargs:
-            kwargs.update(inputs_kwargs)
 
-        _update_config(config, kwargs)
-        model = getattr(transformers, arch)(config)
         fct = get_inputs_for_text_generation
     elif task == "image-classification":
-        kwargs = dict(
-            batch_size=2,
-            width=config.image_size,
-            height=config.image_size,
-            channels=config.num_channels,
-        )
-        if inputs_kwargs:
-            kwargs.update(inputs_kwargs)
+        if isinstance(config.image_size, int):
+            kwargs = dict(
+                batch_size=2,
+                input_width=config.image_size,
+                input_height=config.image_size,
+                input_channels=config.num_channels,
+            )
+        else:
+            kwargs = dict(
+                batch_size=2,
+                input_width=config.image_size[0],
+                input_height=config.image_size[1],
+                input_channels=config.num_channels,
+            )
         fct = get_inputs_for_image_classification
     else:
         raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")
 
+    if inputs_kwargs:
+        kwargs.update(inputs_kwargs)
     true_kwargs = (inputs_kwargs or {}) if same_as_pretrained else kwargs
     _update_config(config, true_kwargs)
     model = getattr(transformers, arch)(config)
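The hunk above changes how the untrained model is built: instead of constructing a fresh config with `cls(**kwargs)`, the code now fetches the pretrained config, mutates it in place, and instantiates the architecture by name. A minimal standalone sketch of that pattern, using `AutoConfig` as a stand-in for the repository's `get_pretrained_config` helper; the model id and the override are illustrative only:

import transformers

# Fetch the real configuration, then override attributes in place, the same way
# the patch applies model_kwargs with setattr.
config = transformers.AutoConfig.from_pretrained("sshleifer/tiny-gpt2")  # illustrative model id
for k, v in dict(num_hidden_layers=1).items():
    setattr(config, k, v)

# Instantiate the architecture named in the config with randomly initialized weights.
arch = config.architectures[0]  # e.g. "GPT2LMHeadModel"
model = getattr(transformers, arch)(config)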
@@ -192,7 +186,7 @@ def compute_model_size(model: torch.nn.Module) -> Tuple[int, int]:
 def get_inputs_for_text_generation(
     model: torch.nn.Module,
     config: Optional[Any],
-    max_token_id: int,
+    dummy_max_token_id: int,
     num_key_value_heads: int,
     num_hidden_layers: int,
     head_dim: int,
@@ -208,6 +202,7 @@ def get_inputs_for_text_generation(
     :param model: model to get the missing information
     :param config: configuration used to generate the model
     :param head_dim: last dimension of the cache
+    :param dummy_max_token_id: maximum token id to use for the random dummy input ids
     :param batch_size: batch size
     :param sequence_length: sequence length
     :param sequence_length2: new sequence length
@@ -235,7 +230,7 @@ def get_inputs_for_text_generation(
         ],
     }
     inputs = dict(
-        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
             torch.int64
         ),
         attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
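For context, the renamed `dummy_max_token_id` only bounds the random token ids drawn for the dummy prompt. A self-contained sketch of the same construction with made-up sizes (in the patched caller it is `config.vocab_size - 1`):

import torch

batch_size, sequence_length, sequence_length2 = 2, 30, 3  # illustrative sizes
dummy_max_token_id = 31999  # e.g. vocab_size - 1

input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(torch.int64)
# the mask covers both the cached (past) tokens and the new ones
attention_mask = torch.ones((batch_size, sequence_length + sequence_length2)).to(torch.int64)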
@@ -268,9 +263,9 @@ def get_inputs_for_text_generation(
 def get_inputs_for_image_classification(
     model: torch.nn.Module,
     config: Optional[Any],
-    width: int,
-    height: int,
-    channels: int,
+    input_width: int,
+    input_height: int,
+    input_channels: int,
     batch_size: int = 2,
     dynamic_rope: bool = False,
     **kwargs,
@@ -281,9 +276,18 @@ def get_inputs_for_image_classification(
     :param model: model to get the missing information
     :param config: configuration used to generate the model
     :param batch_size: batch size
+    :param input_channels: number of input channels
+    :param input_width: input width
+    :param input_height: input height
     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
     :return: dictionary
     """
+    assert isinstance(
+        input_width, int
+    ), f"Unexpected type for input_width {type(input_width)} {config}"
+    assert isinstance(
+        input_height, int
+    ), f"Unexpected type for input_height {type(input_height)} {config}"
 
     shapes = {
         "pixel_values": {
@@ -293,7 +297,9 @@ def get_inputs_for_image_classification(
         },
     }
     inputs = dict(
-        pixel_values=torch.randn(batch_size, channels, width, height).clamp(-1, 1),
+        pixel_values=torch.randn(batch_size, input_channels, input_width, input_height).clamp(
+            -1, 1
+        ),
     )
     sizes = compute_model_size(model)
     return dict(
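Vision configs store `image_size` either as a single int or as a pair, which is what the new `isinstance` branch and the added asserts account for. A small, hedged illustration of the same normalization; the helper name and the sizes are made up, not part of the patch:

import torch

def dummy_pixel_values(image_size, num_channels=3, batch_size=2):
    # Hypothetical helper mirroring the new branch: accept an int or a (width, height) pair.
    if isinstance(image_size, int):
        input_width = input_height = image_size
    else:
        input_width, input_height = image_size[0], image_size[1]
    return torch.randn(batch_size, num_channels, input_width, input_height).clamp(-1, 1)

print(dummy_pixel_values(224).shape)         # torch.Size([2, 3, 224, 224])
print(dummy_pixel_values((224, 160)).shape)  # torch.Size([2, 3, 224, 160])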