@@ -40,6 +40,12 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
4040 return getattr (transformers , cls_name )
4141
4242
43+ def _update_config (config : Any , kwargs : Dict [str , Any ]):
44+ for k , v in kwargs .items ():
45+ if hasattr (config , k ):
46+ setattr (config , k , v )
47+
48+
4349def get_untrained_model_with_inputs (
4450 model_id : str ,
4551 config : Optional [Any ] = None ,
@@ -48,6 +54,7 @@ def get_untrained_model_with_inputs(
4854 model_kwargs : Optional [Dict [str , Any ]] = None ,
4955 verbose : int = 0 ,
5056 dynamic_rope : Optional [bool ] = None ,
57+ same_as_pretrained : bool = False ,
5158) -> Dict [str , Any ]:
5259 """
5360 Gets a non initialized model similar to the original model
@@ -62,6 +69,8 @@ def get_untrained_model_with_inputs(
6269 :param model_kwargs: to change the model generation
6370 :param verbose: display found information
6471 :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
72+ :param same_as_pretrained: if True, do not change the default values
73+ to get a smaller model
6574 :return: dictionary with a model, inputs, dynamic shapes, and the configuration
6675
6776 Example:
@@ -115,8 +124,6 @@ def get_untrained_model_with_inputs(
115124 if model_kwargs :
116125 kwargs .update (model_kwargs )
117126 config = cls (** kwargs )
118- model = getattr (transformers , arch )(config )
119-
120127 task = task_from_arch (arch )
121128 if verbose :
122129 print (f"[get_untrained_model_with_inputs] task={ task !r} " )
@@ -130,18 +137,46 @@ def get_untrained_model_with_inputs(
130137 config , "head_dim" , config .hidden_size // config .num_attention_heads
131138 ),
132139 max_token_id = config .vocab_size - 1 ,
133- num_hidden_layers = config .num_hidden_layers ,
140+ num_hidden_layers = min ( config .num_hidden_layers , 2 ) ,
134141 num_key_value_heads = (
135142 config .num_key_value_heads
136143 if hasattr (config , "num_key_value_heads" )
137144 else config .num_attention_heads
138145 ),
146+ intermediate_size = (
147+ min (config .intermediate_size , 24576 // 4 )
148+ if config .intermediate_size % 4 == 0
149+ else config .intermediate_size
150+ ),
151+ hidden_size = (
152+ min (config .hidden_size , 3072 // 4 )
153+ if config .hidden_size % 4 == 0
154+ else config .hidden_size
155+ ),
139156 )
140157 if inputs_kwargs :
141158 kwargs .update (inputs_kwargs )
142159
143- return get_inputs_for_text_generation (model , config , ** kwargs )
144- raise NotImplementedError (f"Input generation for task { task !r} not implemented yet." )
160+ _update_config (config , kwargs )
161+ model = getattr (transformers , arch )(config )
162+ fct = get_inputs_for_text_generation
163+ elif task == "image-classification" :
164+ kwargs = dict (
165+ batch_size = 2 ,
166+ width = config .image_size ,
167+ height = config .image_size ,
168+ channels = config .num_channels ,
169+ )
170+ if inputs_kwargs :
171+ kwargs .update (inputs_kwargs )
172+ fct = get_inputs_for_image_classification
173+ else :
174+ raise NotImplementedError (f"Input generation for task { task !r} not implemented yet." )
175+
176+ true_kwargs = (inputs_kwargs or {}) if same_as_pretrained else kwargs
177+ _update_config (config , true_kwargs )
178+ model = getattr (transformers , arch )(config )
179+ return fct (model , config , ** true_kwargs )
145180
146181
147182def compute_model_size (model : torch .nn .Module ) -> Tuple [int , int ]:
@@ -168,6 +203,8 @@ def get_inputs_for_text_generation(
168203 ** kwargs ,
169204):
170205 """
206+ Generates input for task ``text-generation``.
207+
171208 :param model: model to get the missing information
172209 :param config: configuration used to generate the model
173210 :param head_dim: last dimension of the cache
@@ -226,3 +263,44 @@ def get_inputs_for_text_generation(
226263 n_weights = sizes [1 ],
227264 configuration = config ,
228265 )
266+
267+
def get_inputs_for_image_classification(
    model: torch.nn.Module,
    config: Optional[Any],
    width: int,
    height: int,
    channels: int,
    batch_size: int = 2,
    dynamic_rope: bool = False,
    **kwargs,
):
    """
    Generates inputs for task ``image-classification``.

    :param model: model to get the missing information
    :param config: configuration used to generate the model,
        returned as is under key ``configuration``
    :param width: width of the generated images
    :param height: height of the generated images
    :param channels: number of channels of the generated images
    :param batch_size: batch size
    :param dynamic_rope: not used by this function
    :param kwargs: not used by this function, extra values are ignored
    :return: dictionary with keys ``model``, ``inputs``, ``dynamic_shapes``,
        ``size``, ``n_weights`` and ``configuration``; ``size`` and
        ``n_weights`` are the first and second values returned by
        ``compute_model_size(model)``
    """
    # Image models in transformers take ``pixel_values`` laid out as
    # (batch, channels, height, width): height is axis 2 and width is axis 3.
    # The channel axis (1) is kept static.
    shapes = {
        "pixel_values": {
            0: torch.export.Dim("batch", min=1, max=1024),
            2: torch.export.Dim("height", min=1, max=4096),
            3: torch.export.Dim("width", min=1, max=4096),
        },
    }
    inputs = dict(
        # Random pixels, clamped into [-1, 1].
        pixel_values=torch.randn(batch_size, channels, height, width).clamp(-1, 1),
    )
    sizes = compute_model_size(model)
    return dict(
        model=model,
        inputs=inputs,
        dynamic_shapes=shapes,
        size=sizes[0],
        n_weights=sizes[1],
        configuration=config,
    )
0 commit comments