77from ...helpers .config_helper import update_config
88from ...tasks import reduce_model_config , random_input_kwargs
99from .hub_api import task_from_arch , task_from_id , get_pretrained_config , download_code_modelid
10+ import diffusers
1011
1112
1213def _code_needing_rewriting (model : Any ) -> Any :
@@ -18,7 +19,7 @@ def _code_needing_rewriting(model: Any) -> Any:
1819def get_untrained_model_with_inputs (
1920 model_id : str ,
2021 config : Optional [Any ] = None ,
21- task : Optional [str ] = "" ,
22+ task : Optional [str ] = None ,
2223 inputs_kwargs : Optional [Dict [str , Any ]] = None ,
2324 model_kwargs : Optional [Dict [str , Any ]] = None ,
2425 verbose : int = 0 ,
@@ -88,14 +89,20 @@ def get_untrained_model_with_inputs(
8889 ** (model_kwargs or {}),
8990 )
9091
91- if hasattr (config , "architecture" ) and config .architecture :
92- archs = [config .architecture ]
93- if type (config ) is dict :
94- assert "_class_name" in config , f"Unable to get the architecture from config={ config } "
95- archs = [config ["_class_name" ]]
92+ # Extract architecture information from config
93+ archs = None
94+ if isinstance (config , dict ):
95+ if "_class_name" in config :
96+ archs = [config ["_class_name" ]]
97+ else :
98+ raise ValueError (f"Unable to get the architecture from config={ config } " )
9699 else :
97- archs = config .architectures # type: ignore
98- task = None
100+ # Config is an object (e.g., transformers config)
101+ if hasattr (config , "architecture" ) and config .architecture :
102+ archs = [config .architecture ]
103+ elif hasattr (config , "architectures" ) and config .architectures :
104+ archs = config .architectures
105+
99106 if archs is None :
100107 task = task_from_id (model_id )
101108 assert task is not None or (archs is not None and len (archs ) == 1 ), (
@@ -112,9 +119,9 @@ def get_untrained_model_with_inputs(
112119
113120 # model kwagrs
114121 if dynamic_rope is not None :
115- assert (
116- type ( config ) is not dict
117- ), f"Unable to set dynamic_rope if the configuration is a dictionary \n { config } "
122+ assert type ( config ) is not dict , (
123+ f"Unable to set dynamic_rope if the configuration is a dictionary \n { config } "
124+ )
118125 assert hasattr (config , "rope_scaling" ), f"Missing 'rope_scaling' in\n { config } "
119126 config .rope_scaling = (
120127 {"rope_type" : "dynamic" , "factor" : 10.0 } if dynamic_rope else None
@@ -150,9 +157,7 @@ def get_untrained_model_with_inputs(
150157 f"{ getattr (config , '_attn_implementation' , '?' )!r} " # type: ignore[union-attr]
151158 )
152159
153- if type (config ) is dict and "_diffusers_version" in config :
154- import diffusers
155-
160+ if isinstance (config , dict ) and "_diffusers_version" in config :
156161 package_source = diffusers
157162 else :
158163 package_source = transformers
@@ -206,7 +211,7 @@ def get_untrained_model_with_inputs(
206211 )
207212
208213 try :
209- if type (config ) is dict :
214+ if isinstance (config , dict ) :
210215 model = cls_model (** config )
211216 else :
212217 model = cls_model (config )
0 commit comments