77 changes: 68 additions & 9 deletions src/speculators/models/eagle3.py
@@ -16,6 +16,7 @@
from typing import Any, ClassVar, Literal

import torch
from django.contrib.gis.gdal.prototypes.srs import from_user_input
Collaborator:
Can we remove this import if it's not used?

from pydantic import Field, field_serializer, field_validator
from torch import nn
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
@@ -28,6 +29,8 @@
)

from speculators import SpeculatorModel, SpeculatorModelConfig

__all__ = [
"Eagle3Attention",
@@ -337,6 +340,10 @@ def __init__(
| None = None,
reduce_vocab_size: bool = True,
        has_drafter_embedding: bool = True,
        # New optional vocabulary-mapping tensors (registered as buffers below).
        t2d: torch.Tensor | None = None,
        d2t: torch.Tensor | None = None,
):
"""
Initialize Eagle3 speculator.
@@ -351,7 +358,10 @@ def __init__(
raise ValueError(
f"config must be Eagle3SpeculatorConfig, got {type(config)}"
)

        # t2d / d2t are registered as buffers after the base model is set up below.
self.config: Eagle3SpeculatorConfig = config

self.hidden_size = config.transformer_layer_config.hidden_size
Expand All @@ -364,6 +374,11 @@ def __init__(
if config.target_hidden_size is not None
else self.hidden_size
        )

        # t2d and d2t describe the same vocabulary mapping, so they must be
        # provided together.
        if (t2d is None) != (d2t is None):
            raise ValueError("t2d and d2t must be provided together.")

super().__init__(
config=config,
@@ -406,21 +421,24 @@ def __init__(
self.draft_vocab_size,
bias=False,
)
        if t2d is not None and d2t is not None:
            # Register the caller-provided vocabulary-mapping tensors as buffers.
            self.register_buffer("d2t", d2t)  # type: ignore[attr-defined]
            self.register_buffer("t2d", t2d)  # type: ignore[attr-defined]
        elif reduce_vocab_size:
            self.register_buffer(  # type: ignore[attr-defined]
                "d2t",
                torch.zeros(self.draft_vocab_size, dtype=torch.long),
            )
            self.register_buffer(  # type: ignore[attr-defined]
                "t2d",
                torch.zeros(self.target_vocab_size, dtype=torch.bool),
            )

        # Type hints for buffers
        if hasattr(self, "d2t"):
            self.d2t: torch.Tensor
            self.t2d: torch.Tensor

        self.post_init()  # type: ignore[attr-defined]

def tie_weights(self):
"""
Override tie_weights to prevent vocabulary corruption in transformers 4.54.1+
@@ -463,4 +481,45 @@ def forward(
:param return_dict: Return dict output
:return: Model outputs with draft vocabulary logits
"""
        # 1 - embed the draft input tokens
        inputs_embeds = self.embed_tokens(input_ids)

        # 2 - fuse the token embeddings with the target model's hidden states
        fused_input = torch.cat([inputs_embeds, hidden_states], dim=-1)  # [B, L, H + 3*H]

        # 3 - project the fused features back to the draft hidden size
        hidden_states = self.fc(fused_input)

        # 4 - run the draft decoder layer(s)
        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values[0] if past_key_values else None,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

        # 5 - apply the final layer norm if this module defines one
        if hasattr(self, "norm"):
            hidden_states = self.norm(hidden_states)

        # 6 - compute logits over the reduced draft vocabulary with the LM head
        draft_logits = self.lm_head(hidden_states)

        # 7 - return draft-vocabulary logits; d2t is an index mapping from draft
        # token ids to target token ids, not a projection matrix, so it is applied
        # when sampled draft ids are translated back to the target vocabulary.
        return draft_logits
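
Since forward now returns logits over the reduced draft vocabulary, here is a minimal sketch of how a sampled draft token id could be mapped back into the target vocabulary with the d2t buffer. It assumes d2t holds offsets (target_id = draft_id + d2t[draft_id]); the function name and its arguments are hypothetical placeholders, not part of this PR.

def pick_next_target_token(speculator, input_ids, verifier_hidden_states):
    # Draft-vocabulary logits from the forward pass implemented above.
    draft_logits = speculator(input_ids=input_ids, hidden_states=verifier_hidden_states)
    # Greedy pick in the draft vocabulary.
    next_draft_id = draft_logits[:, -1, :].argmax(dim=-1)
    # Map back to the target vocabulary: target_id = draft_id + d2t[draft_id].
    return next_draft_id + speculator.d2t[next_draft_id]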