Skip to content

Commit e4bfba0

Browse files
committed
rm redundant class
1 parent 95ab95a commit e4bfba0

File tree

1 file changed

+0
-67
lines changed

1 file changed

+0
-67
lines changed

specforge/modeling/draft/llama3_eagle.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,73 +1244,6 @@ def forward(self, hidden_states):
12441244
return self.weight * hidden_states.to(input_dtype)
12451245

12461246

1247-
class BasicDecoderLayer(nn.Module):
    """A traditional pre-norm transformer decoder layer.

    Structure: RMSNorm -> self-attention -> residual add, then
    RMSNorm -> MLP -> residual add. The attention implementation is
    selected by *attention_backend* at construction time.
    """

    def __init__(self, config, attention_backend: str = "sdpa"):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Pick the attention implementation. "usp" shares the plain
        # LlamaAttention implementation with "sdpa".
        if attention_backend == "sdpa":
            self.self_attn = LlamaAttention(config=config, fused_input=False)
        elif attention_backend == "flex_attention":
            print_with_rank("Using flex attention on draft model training!")
            self.self_attn = LlamaFlexAttention(config=config, fused_input=False)
        elif attention_backend == "fa":
            self.self_attn = LlamaFlashAttention(config=config, fused_input=False)
        elif attention_backend == "usp":
            self.self_attn = LlamaAttention(config=config, fused_input=False)
        else:
            raise ValueError(f"Unknown attention backend {attention_backend}")

        self.attention_backend = attention_backend
        self.mlp = LlamaMLP(config)

        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Run one decoder step: pre-norm self-attention then pre-norm MLP,
        each with a residual (shortcut) connection around it.
        """
        # --- self-attention sub-block ---
        shortcut = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = shortcut + attn_out

        # --- feed-forward sub-block ---
        shortcut = hidden_states
        hidden_states = shortcut + self.mlp(
            self.post_attention_layernorm(hidden_states)
        )

        return hidden_states
1312-
1313-
13141247
class LlamaDecoderLayer(nn.Module):
13151248
def __init__(self, config, attention_backend: str = "sdpa", fused_input=True):
13161249
super().__init__()

0 commit comments

Comments
 (0)