Skip to content

Commit fd82bf2

Browse files
committed
rename eagle3 specific classes for clarity
1 parent e4bfba0 commit fd82bf2

File tree

1 file changed

+3
-3
lines changed

specforge/modeling/draft/llama3_eagle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,7 +1244,7 @@ def forward(self, hidden_states):
12441244
return self.weight * hidden_states.to(input_dtype)
12451245

12461246

1247-
class LlamaDecoderLayer(nn.Module):
1247+
class Eagle3LlamaDecoderLayer(nn.Module):
12481248
def __init__(self, config, attention_backend: str = "sdpa", fused_input=True):
12491249
super().__init__()
12501250
self.hidden_size = config.hidden_size
@@ -1357,15 +1357,15 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
13571357

13581358
# Multi-layer decoder for Eagle3 draft model
13591359
# First being the embeds + hidden_states fuse layer
1360-
self.fuse_layer = LlamaDecoderLayer(
1360+
self.fuse_layer = Eagle3LlamaDecoderLayer(
13611361
config, attention_backend=attention_backend, fused_input=True
13621362
)
13631363
# the rests are the traditional decoder layers with only hidden_states as inputs
13641364
self.additional_layers = None
13651365
if self.num_hidden_layers > 1:
13661366
self.additional_layers = nn.ModuleList(
13671367
[
1368-
LlamaDecoderLayer(
1368+
Eagle3LlamaDecoderLayer(
13691369
config, attention_backend=attention_backend, fused_input=False
13701370
)
13711371
for _ in range(self.num_hidden_layers - 1)

0 commit comments

Comments (0)