@@ -1070,8 +1070,8 @@ def forward(
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = inputs_embeds
 
-        for decoder_layer in self.layers:
-            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
+        for idx, decoder_layer in enumerate(self.layers):
+            if idx in self.cross_attention_layers:
                 if not skip_cross_attention:
                     hidden_states = decoder_layer(
                         hidden_states=hidden_states,
@@ -1081,16 +1081,13 @@ def forward(
                         full_text_row_masked_out_mask=
                         full_text_row_masked_out_mask,
                     )
-            elif isinstance(decoder_layer, LlamaDecoderLayer):
+            else:
                 hidden_states, residual = decoder_layer(
                     positions=positions,
                     hidden_states=hidden_states,
                     residual=None,
                 )
                 hidden_states = hidden_states + residual
-            else:
-                raise ValueError(
-                    f"Unknown decoder layer type {type(decoder_layer)}")
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
@@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
     full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
     mask *= full_text_mask
     # (num_prompt_tokens, num_encoder_tokens)
-    return mask
+    return mask
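
For context, the first hunk swaps the per-layer `isinstance` checks for a lookup of the layer index in `self.cross_attention_layers`, so the loop no longer needs to know the concrete decoder-layer classes. Below is a minimal sketch of that index-based dispatch, assuming `cross_attention_layers` is a list of layer indices taken from the model config; the `ToyDecoder` class and its `nn.Linear` stand-in layers are illustrative, not the real Mllama modules.

```python
# Sketch only: index-based dispatch between cross-attention and
# self-attention decoder layers. Names are illustrative, not vLLM's.
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):

    def __init__(self, num_layers: int, cross_attention_layers: list[int],
                 hidden_size: int = 16):
        super().__init__()
        # Indices of layers that attend to encoder (image) states.
        self.cross_attention_layers = set(cross_attention_layers)
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor,
                skip_cross_attention: bool = False) -> torch.Tensor:
        for idx, layer in enumerate(self.layers):
            if idx in self.cross_attention_layers:
                # Cross-attention position: skipped when the request has
                # no encoder states to attend to.
                if not skip_cross_attention:
                    hidden_states = layer(hidden_states)
            else:
                # Ordinary self-attention decoder layer.
                hidden_states = layer(hidden_states)
        return self.norm(hidden_states)


x = torch.randn(2, 16)
model = ToyDecoder(num_layers=4, cross_attention_layers=[1, 3])
print(model(x, skip_cross_attention=True).shape)  # torch.Size([2, 16])
```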