@@ -504,7 +504,7 @@ def forward(
504504
505505 if past_key_value is not None :
506506 chunk_position_ids = position_ids
507- if self .use_rope :
507+ if self .use_rope and self . config . attention_chunk_size :
508508 chunk_position_ids = torch .where (
509509 chunk_position_ids != - 1 , chunk_position_ids % self .config .attention_chunk_size , chunk_position_ids
510510 )
@@ -663,10 +663,16 @@ def forward(
663663 causal_mask = _create_causal_mask (
664664 position_ids = position_ids , target_length = past_key_values .layers [3 ].keys .shape [- 2 ]
665665 )
666- chunk_position_ids = torch .where (
667- position_ids != - 1 , position_ids % self .config .attention_chunk_size , position_ids
668- )
669- target_length = min (past_key_values .layers [0 ].keys .shape [- 2 ], torch .tensor (self .config .attention_chunk_size ))
666+ if self .config .attention_chunk_size :
667+ chunk_position_ids = torch .where (
668+ position_ids != - 1 , position_ids % self .config .attention_chunk_size , position_ids
669+ )
670+ target_length = min (
671+ past_key_values .layers [0 ].keys .shape [- 2 ], torch .tensor (self .config .attention_chunk_size )
672+ )
673+ else :
674+ chunk_position_ids = position_ids
675+ target_length = past_key_values .layers [0 ].keys .shape [- 2 ]
670676 chunk_causal_mask = _create_causal_mask (position_ids = chunk_position_ids , target_length = target_length )
671677 causal_mask_mapping = {
672678 "full_attention" : causal_mask ,
@@ -798,7 +804,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
798804 is_chunked_attention = torch .tensor (
799805 [bool ((i + 1 ) % 4 ) for i in range (config .num_hidden_layers )], dtype = torch .bool
800806 )
801- attention_chunk_size = getattr (config , "attention_chunk_size" , seq_len )
807+ attention_chunk_size = getattr (config , "attention_chunk_size" , None ) or seq_len
802808 global_cache_shape = [batch_size , n_heads , seq_len , d_head ]
803809 chunked_cache_shape = [
804810 batch_size ,
@@ -967,13 +973,12 @@ def get_specializations(
967973
968974 prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
969975 ctx_len = ctx_len if ctx_len else constants .INTERN_CTX_LEN
976+ attention_chunk_size = getattr (
977+ getattr (getattr (self , "config" , None ), "text_config" , None ), "attention_chunk_size" , None
978+ )
970979 chunk_ctx_len = min (
971980 ctx_len ,
972- (
973- self .config .text_config .attention_chunk_size
974- if hasattr (self , "config" )
975- else constants .LLAMA4_ATTENTION_CHUNK_SIZE
976- ),
981+ (attention_chunk_size if attention_chunk_size is not None else constants .LLAMA4_ATTENTION_CHUNK_SIZE ),
977982 )
978983 if (
979984 prefill_seq_len > constants .LLAMA4_MAX_POSITION_EMBEDDINGS
@@ -1158,7 +1163,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
11581163 is_chunked_attention = torch .tensor (
11591164 [bool ((i + 1 ) % 4 ) for i in range (config .num_hidden_layers )], dtype = torch .bool
11601165 )
1161- attention_chunk_size = getattr (config , "attention_chunk_size" , seq_len )
1166+ attention_chunk_size = getattr (config , "attention_chunk_size" , None ) or seq_len
11621167 global_cache_shape = [batch_size , n_heads , seq_len , d_head ]
11631168 chunked_cache_shape = [
11641169 batch_size ,
0 commit comments