@@ -91,6 +91,7 @@ class ModelArgs:
     norm_eps: float = 1e-5
     max_batch_size: int = 32
     max_seq_len: int = 2048
+    max_context_len: int = 2048
     moe: bool = False  # True to enable the MoE (Mixture of Experts)
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
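A minimal sketch of the semantics assumed for the new field (illustrative only, not code from this patch): `max_seq_len` bounds the tokens handled in one forward pass, while `max_context_len` bounds the total window (prompt plus generated tokens) that the KV cache, RoPE tables, and causal mask must cover, so it should be at least as large as `max_seq_len`.

```python
from dataclasses import dataclass

@dataclass
class _ArgsSketch:  # hypothetical stand-in for ModelArgs
    max_seq_len: int = 2048      # tokens per forward call (assumption)
    max_context_len: int = 2048  # total KV-cache / attention window (assumption)

args = _ArgsSketch(max_seq_len=512, max_context_len=2048)
assert args.max_context_len >= args.max_seq_len
```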
@@ -163,9 +164,9 @@ def __init__(self, params: ModelArgs):
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
             (
-                self.params.max_seq_len  # Normal llama2.
+                self.params.max_context_len  # Normal llama2.
                 if self.params.ffn_dim_multiplier is None
-                else self.params.max_seq_len * 2  # Sharded checkpoint.
+                else self.params.max_context_len * 2  # Sharded checkpoint.
             ),
             self.params.rope_freq_base,
         )
@@ -205,7 +206,7 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
             # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
             input_pos_item = input_pos[-1].item()
             torch._check_is_size(input_pos_item)
-            torch._check(input_pos_item < self.params.max_seq_len)
+            torch._check(input_pos_item < self.params.max_context_len)
             # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
             freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
             # pyre-ignore: Incompatible parameter type [6]
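The same guard pattern recurs in the KV cache and SDPA hunks below. A rough eager-mode sketch of what it does (illustrative names, standalone `freqs` table assumed): `torch._check_is_size` and `torch._check` record the bound `position < max_context_len` so the `narrow()` on a symbolic position stays traceable during export.

```python
import torch

max_context_len = 2048
freqs = torch.randn(max_context_len, 64)  # stand-in for a precomputed table

def slice_freqs(input_pos: torch.Tensor, seq_len: int) -> torch.Tensor:
    pos = input_pos[-1].item()
    torch._check_is_size(pos)                  # pos is a non-negative size
    torch._check(pos < max_context_len)        # pos stays inside the table
    return freqs.narrow(0, pos, seq_len)

out = slice_freqs(torch.tensor([5]), 1)        # shape: (1, 64)
```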
@@ -229,15 +230,15 @@ class KVCache(nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         head_dim: int,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
-        self.max_seq_length = max_seq_length
-        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.max_context_length = max_context_length
+        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
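A minimal sketch (not the repo's `KVCache` class; names are illustrative) of why the cache's length dimension is the context length: the buffers are statically preallocated for the whole window, and each decode step writes its new keys/values at the running start position.

```python
import torch

class TinyKVCache(torch.nn.Module):
    def __init__(self, max_batch_size, max_context_length, n_heads, head_dim, dtype=torch.float32):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, start_pos, k_val, v_val):
        # Write the incoming tokens at [start_pos : start_pos + seq_len) along dim 2.
        seq_len = k_val.size(2)
        self.k_cache[:, :, start_pos : start_pos + seq_len] = k_val
        self.v_cache[:, :, start_pos : start_pos + seq_len] = v_val
        return self.k_cache, self.v_cache

cache = TinyKVCache(1, 16, n_heads=2, head_dim=4)
k, v = cache.update(0, torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4))
```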
@@ -257,7 +258,7 @@ def update(
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_length)
+            torch._check(start_pos < self.max_context_length)
             dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
@@ -289,14 +290,14 @@ def __init__(
         dim: int,
         head_dim: int,
         n_rep: int,
-        max_seq_len: int,
+        max_context_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
-        self.max_seq_len = max_seq_len
+        self.max_context_len = max_context_len
         self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
@@ -312,7 +313,7 @@ def forward(
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_len)
+            torch._check(start_pos < self.max_context_len)
             seq_length = q.size(2)
             # pyre-ignore: Incompatible parameter type [6]
             attn_mask = mask.narrow(0, start_pos, seq_length)
@@ -341,7 +342,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
-        self.max_seq_len = args.max_seq_len
+        self.max_context_len = args.max_context_len
         self.dim = args.dim
         self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
@@ -354,8 +355,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
 
         causal_mask = torch.tril(
             torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
+                self.max_context_len,
+                self.max_context_len,
                 dtype=torch.bool,
                 device="cpu",
             )
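A hedged sketch of how a mask sized by `max_context_len` is consumed at decode time (values are illustrative): `narrow()` on dim 0 takes the rows for the current positions, each row spanning the full context, which is why the earlier `torch._check(start_pos < max_context_len)` guard is the matching bound.

```python
import torch

max_context_len = 8
mask = torch.tril(torch.ones(max_context_len, max_context_len, dtype=torch.bool))

start_pos, seq_len = 5, 1                        # decoding one token at position 5
attn_mask = mask.narrow(0, start_pos, seq_len)   # shape: (1, max_context_len)
print(attn_mask)
```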
@@ -365,7 +366,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         if self.use_kv_cache:
             self.kv_cache = KVCache(
                 args.max_batch_size,
-                args.max_seq_len,
+                args.max_context_len,
                 self.n_kv_heads,
                 self.head_dim,
                 args.enable_dynamic_shape,
@@ -374,7 +375,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
-                max_seq_len=self.max_seq_len,
+                max_context_len=self.max_context_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
 
@@ -528,6 +529,7 @@ def __init__(self, params: ModelArgs):
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
+        self.max_context_len = params.max_context_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
 