11import inspect
22import os
3+ import pprint
34from typing import Any , Dict , Optional , Tuple
45import torch
56import transformers
@@ -22,6 +23,7 @@ def get_untrained_model_with_inputs(
2223 model_kwargs : Optional [Dict [str , Any ]] = None ,
2324 verbose : int = 0 ,
2425 dynamic_rope : Optional [bool ] = None ,
26+ use_pretrained : bool = False ,
2527 same_as_pretrained : bool = False ,
2628 use_preinstalled : bool = True ,
2729 add_second_input : bool = False ,
@@ -43,6 +45,7 @@ def get_untrained_model_with_inputs(
4345 :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
4446 :param same_as_pretrained: if True, do not change the default values
4547 to get a smaller model
48+ :param use_pretrained: download the pretrained weights as well
4649 :param use_preinstalled: use preinstalled configurations
4750 :param add_second_input: provides a second inputs to check a model
4851 supports different shapes
@@ -68,6 +71,10 @@ def get_untrained_model_with_inputs(
6871 print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
6972 print("-- configuration:", pprint.pformat(data['configuration']))
7073 """
74+ assert not use_preinstalled or not use_only_preinstalled , (
75+ f"model_id={ model_id !r} , pretinstalled model is only available "
76+ f"if use_only_preinstalled is False."
77+ )
7178 if verbose :
7279 print (f"[get_untrained_model_with_inputs] model_id={ model_id !r} " )
7380 if use_preinstalled :
@@ -99,7 +106,7 @@ def get_untrained_model_with_inputs(
99106 print (f"[get_untrained_model_with_inputs] architectures={ archs !r} " )
100107 print (f"[get_untrained_model_with_inputs] cls={ config .__class__ .__name__ !r} " )
101108 if task is None :
102- task = task_from_arch (archs [0 ])
109+ task = task_from_arch (archs [0 ], model_id = model_id )
103110 if verbose :
104111 print (f"[get_untrained_model_with_inputs] task={ task !r} " )
105112
@@ -114,7 +121,6 @@ def get_untrained_model_with_inputs(
114121 )
115122
116123 # updating the configuration
117-
118124 mkwargs = reduce_model_config (config , task ) if not same_as_pretrained else {}
119125 if model_kwargs :
120126 for k , v in model_kwargs .items ():
@@ -139,27 +145,28 @@ def get_untrained_model_with_inputs(
139145 f"{ config ._attn_implementation !r} " # type: ignore[union-attr]
140146 )
141147
148+ if use_pretrained :
149+ model = transformers .AutoModel .from_pretrained (model_id , ** mkwargs )
150+ else :
151+ if archs is not None :
152+ model = getattr (transformers , archs [0 ])(config )
153+ else :
154+ assert same_as_pretrained and use_pretrained , (
155+ f"Model { model_id !r} cannot be built, the model cannot be built. "
156+ f"It must be downloaded. Use same_as_pretrained=True "
157+ f"and use_pretrained=True."
158+ )
159+
142160 # input kwargs
143161 kwargs , fct = random_input_kwargs (config , task )
144162 if verbose :
145163 print (f"[get_untrained_model_with_inputs] use fct={ fct } " )
146164 if os .environ .get ("PRINT_CONFIG" ) in (1 , "1" ):
147- import pprint
148-
149165 print (f"-- input kwargs for task { task !r} " )
150166 pprint .pprint (kwargs )
151167 if inputs_kwargs :
152168 kwargs .update (inputs_kwargs )
153169
154- if archs is not None :
155- model = getattr (transformers , archs [0 ])(config )
156- else :
157- assert same_as_pretrained , (
158- f"Model { model_id !r} cannot be built, the model cannot be built. "
159- f"It must be downloaded. Use same_as_pretrained=True."
160- )
161- model = None
162-
163170 # This line is important. Some models may produce different
164171 # outputs even with the same inputs in training mode.
165172 model .eval ()