This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Changes from 1 commit
torchchat/cli/builder.py (3 changes: 2 additions & 1 deletion)
@@ -32,6 +32,7 @@
from torchchat.model import Model, ModelArgs, ModelType

from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
Contributor:

Let's get rid of the old one

Contributor (Author):

done.

+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
@@ -402,7 +403,7 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
max_seq_len = decoder_config['max_seq_len']
rope_base = decoder_config['rope_base']
for submodule in model.modules():
-    if isinstance(submodule, RotaryPositionalEmbeddings):
+    if isinstance(submodule, Llama3ScaledRoPE):
        submodule.__init__(head_dim, max_seq_len, rope_base)
@Jack-Khuu (Contributor), Sep 25, 2024:

nit: We should be populating the rest of the RoPE fields based on the config. This can be a separate PR

From tune:

scale_factor: int = 8,
low_freq_factor: int = 1,
high_freq_factor: int = 4,
old_context_len: int = 8192,

state_dict = flamingo_meta_to_tune(checkpoint)
model.model.load_state_dict(state_dict, assign=True, strict=False)
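
As a follow-up to the nit above, a minimal sketch of what forwarding the remaining Llama3ScaledRoPE fields from the config could look like. The reinit_rope_from_config helper and the rope_* config key names are hypothetical (they do not appear in this PR), decoder_config is assumed to behave like a plain dict, and the fallback values are the torchtune defaults quoted in the review comment.

# Sketch only: re-initialize Llama3ScaledRoPE with every field the config
# provides, falling back to the torchtune defaults quoted in the comment above.
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

def reinit_rope_from_config(model, head_dim, decoder_config):
    max_seq_len = decoder_config["max_seq_len"]
    rope_base = decoder_config["rope_base"]
    for submodule in model.modules():
        if isinstance(submodule, Llama3ScaledRoPE):
            submodule.__init__(
                head_dim,
                max_seq_len,
                rope_base,
                scale_factor=decoder_config.get("rope_scale_factor", 8),          # hypothetical key
                low_freq_factor=decoder_config.get("rope_low_freq_factor", 1),    # hypothetical key
                high_freq_factor=decoder_config.get("rope_high_freq_factor", 4),  # hypothetical key
                old_context_len=decoder_config.get("rope_old_context_len", 8192), # hypothetical key
            )

The positional call (head_dim, max_seq_len, rope_base) mirrors the existing call in the diff; only the keyword arguments are new.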