@@ -10,20 +10,31 @@
 
 
 @functools.cache
-def config_class_from_architecture(arch: str) -> type:
+def config_class_from_architecture(arch: str, exc: bool = False) -> type:
     """
     Retrieves the configuration class for a given architecture.
+
+    :param arch: architecture (class name)
+    :param exc: raise an exception if not found
+    :return: configuration class (None if not found and exc is False)
     """
     cls = getattr(transformers, arch)
     mod_name = cls.__module__
     mod = importlib.import_module(mod_name)
     source = inspect.getsource(mod)
     reg = re.compile("config: ([A-Za-z0-9]+)")
     fall = reg.findall(source)
+    if len(fall) == 0:
+        assert not exc, (
+            f"Unable to guess Configuration class name for arch={arch!r}, "
+            f"module={mod_name!r}, no candidate, source is\n{source}"
+        )
+        return None
     unique = set(fall)
     assert len(unique) == 1, (
         f"Unable to guess Configuration class name for arch={arch!r}, "
-        f"module={mod_name!r}, source is\n{source}"
+        f"module={mod_name!r}, found={unique} (#{len(unique)}), "
+        f"source is\n{source}"
     )
     cls_name = unique.pop()
     return getattr(transformers, cls_name)
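
For readers skimming the diff, here is a standalone sketch of the resolution logic the patched function implements (a hypothetical mirror, not part of the patch; it assumes transformers is installed, and the printed class depends on the installed version):

    import functools
    import importlib
    import inspect
    import re

    import transformers

    @functools.cache
    def resolve_config_class(arch: str):
        # Mirrors config_class_from_architecture(arch, exc=False): scan the
        # modeling module's source for "config: SomeConfig" annotations.
        mod = importlib.import_module(getattr(transformers, arch).__module__)
        found = set(re.findall("config: ([A-Za-z0-9]+)", inspect.getsource(mod)))
        # No unique candidate -> None, matching the new exc=False behavior.
        return getattr(transformers, found.pop()) if len(found) == 1 else None

    print(resolve_config_class("LlamaForCausalLM"))  # e.g. LlamaConfig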
@@ -81,7 +92,14 @@ def get_untrained_model_with_inputs(
     arch = archs[0]
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
-    cls = config_class_from_architecture(arch)
+    cls = config_class_from_architecture(arch, exc=False)
+    if cls is None:
+        if verbose:
+            print(
99+ "[get_untrained_model_with_inputs] no found config name in the code, loads it"
+            )
+        config = get_pretrained_config(model_id)
+        cls = config.__class__
     if verbose:
         print(f"[get_untrained_model_with_inputs] cls={cls.__name__!r}")
 
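
The fallback relies on get_pretrained_config, whose body is not shown in this diff; presumably it amounts to something like the sketch below (an assumption: AutoConfig-based loading, with "gpt2" as an arbitrary example id):

    from transformers import AutoConfig

    model_id = "gpt2"  # arbitrary example id
    config = AutoConfig.from_pretrained(model_id)  # fetch or read the config
    cls = config.__class__  # e.g. GPT2Config, recovered without scanning source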
@@ -107,12 +125,16 @@ def get_untrained_model_with_inputs(
         batch_size=2,
         sequence_length=30,
         sequence_length2=3,
-        num_hidden_layers=config.num_hidden_layers,
-        num_key_value_heads=config.num_key_value_heads,
         head_dim=getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
         ),
         max_token_id=config.vocab_size - 1,
+        num_hidden_layers=config.num_hidden_layers,
+        num_key_value_heads=(
+            config.num_key_value_heads
+            if hasattr(config, "num_key_value_heads")
+            else config.num_attention_heads
+        ),
     )
     if inputs_kwargs:
         kwargs.update(inputs_kwargs)
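
The new hasattr guard matters for configs that predate grouped-query attention and so never define num_key_value_heads; a quick illustration (GPT2Config and LlamaConfig are just convenient examples, behavior checked on recent transformers releases):

    from transformers import GPT2Config, LlamaConfig

    for config in (GPT2Config(), LlamaConfig()):
        num_kv = (
            config.num_key_value_heads
            if hasattr(config, "num_key_value_heads")
            else config.num_attention_heads  # pre-GQA: one KV head per attention head
        )
        print(type(config).__name__, num_kv)
    # GPT2Config lacks num_key_value_heads and falls back; LlamaConfig defines it.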