@@ -1244,73 +1244,6 @@ def forward(self, hidden_states):
12441244 return self .weight * hidden_states .to (input_dtype )
12451245
12461246
class BasicDecoderLayer(nn.Module):
    """A single pre-norm transformer decoder layer.

    Structure: LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add. The attention implementation is
    chosen at construction time via ``attention_backend``.

    Args:
        config: Model configuration; must provide ``hidden_size`` and
            ``rms_norm_eps``.
        attention_backend: One of ``"sdpa"``, ``"flex_attention"``,
            ``"fa"`` (flash attention), or ``"usp"``.

    Raises:
        ValueError: If ``attention_backend`` is not a recognized name.
    """

    def __init__(self, config, attention_backend: str = "sdpa"):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Pick the attention module for the requested backend.
        # NOTE(review): "sdpa" and "usp" intentionally map to the same
        # LlamaAttention module in the original code.
        if attention_backend == "flex_attention":
            print_with_rank("Using flex attention on draft model training!")
            self.self_attn = LlamaFlexAttention(config=config, fused_input=False)
        elif attention_backend == "fa":
            self.self_attn = LlamaFlashAttention(config=config, fused_input=False)
        elif attention_backend in ("sdpa", "usp"):
            self.self_attn = LlamaAttention(config=config, fused_input=False)
        else:
            raise ValueError(f"Unknown attention backend {attention_backend}")

        self.attention_backend = attention_backend
        self.mlp = LlamaMLP(config)

        # Pre-norms for the attention and MLP sub-blocks respectively.
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Apply self-attention and the MLP, each wrapped in a
        pre-layernorm and a residual connection, and return the
        updated hidden states."""
        # --- Self-attention sub-block ---
        attn_residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = attn_residual + attn_out

        # --- Feed-forward sub-block ---
        mlp_residual = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        hidden_states = mlp_residual + self.mlp(normed)

        return hidden_states
1313-
13141247class LlamaDecoderLayer (nn .Module ):
13151248 def __init__ (self , config , attention_backend : str = "sdpa" , fused_input = True ):
13161249 super ().__init__ ()
0 commit comments