@@ -164,9 +164,6 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size

-    def get_max_model_len(self) -> int:
-        return self.max_model_len
-
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
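Note: with get_max_model_len() removed, callers presumably read the value straight off the config object as a plain attribute. A minimal, hypothetical sketch of that access pattern (the ModelConfig stub below is illustrative, not the real class):

# Hypothetical stub showing the access pattern after this change:
# max_model_len becomes a plain attribute rather than a getter method.
class ModelConfig:
    def __init__(self, max_model_len: int) -> None:
        self.max_model_len = max_model_len

config = ModelConfig(max_model_len=2048)
print(config.max_model_len)  # replaces config.get_max_model_len()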
@@ -378,10 +375,17 @@ def _get_and_verify_max_len(
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
     if derived_max_model_len == float("inf"):
-        raise ValueError(
-            "The model's config.json must contain one of the following keys "
-            "to determine the original maximum length of the model: "
-            f"{possible_keys}")
+        if max_model_len is not None:
+            # If max_model_len is specified, we use it.
+            return max_model_len
+
+        default_max_len = 2048
+        logger.warning(
+            "The model's config.json does not contain any of the following "
+            "keys to determine the original maximum length of the model: "
+            f"{possible_keys}. Assuming the model's maximum length is "
+            f"{default_max_len}.")
+        derived_max_model_len = default_max_len

     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
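The net effect of this hunk: a config.json with none of the recognized length keys no longer raises a ValueError. An explicit max_model_len wins; otherwise a 2048 default is assumed with a warning. A minimal, self-contained sketch of that fallback, where SimpleNamespace stands in for a HuggingFace config and possible_keys is an assumption about the list defined earlier in config.py:

# Minimal sketch of the fallback logic added above, not the full function
# (the real _get_and_verify_max_len also handles rope_scaling and validation).
import logging
from types import SimpleNamespace
from typing import Optional

logger = logging.getLogger(__name__)
possible_keys = ["max_position_embeddings", "n_positions", "seq_length"]

def derive_max_len(hf_config, max_model_len: Optional[int]) -> int:
    derived_max_model_len = float("inf")
    for key in possible_keys:
        max_len_key = getattr(hf_config, key, None)
        if max_len_key is not None:
            derived_max_model_len = min(derived_max_model_len, max_len_key)
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len
        default_max_len = 2048
        logger.warning("No length key found in config.json; assuming %d.",
                       default_max_len)
        derived_max_model_len = default_max_len
    return int(derived_max_model_len)

print(derive_max_len(SimpleNamespace(), None))                  # 2048 (default)
print(derive_max_len(SimpleNamespace(), 4096))                  # 4096 (explicit)
print(derive_max_len(SimpleNamespace(n_positions=1024), None))  # 1024 (from config)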