This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 0b8ca05

Update Flamingo Builder to use Llama3ScaledRoPE instead of RotaryPositionalEmbeddings (#1202)
* update rope class
* remove old rope class
1 parent 021fd32

1 file changed

1 file changed

+2
-2
lines changed

torchchat/cli/builder.py

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@
 
 from torchchat.model import Model, ModelArgs, ModelType
 
-from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
@@ -402,7 +402,7 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         max_seq_len = decoder_config['max_seq_len']
         rope_base = decoder_config['rope_base']
         for submodule in model.modules():
-            if isinstance(submodule, RotaryPositionalEmbeddings):
+            if isinstance(submodule, Llama3ScaledRoPE):
                 submodule.__init__(head_dim, max_seq_len, rope_base)
         state_dict = flamingo_meta_to_tune(checkpoint)
         model.model.load_state_dict(state_dict, assign=True, strict=False)
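For context, the second hunk only works if the isinstance check names the RoPE class the Flamingo builder actually instantiates, which is why the check changes alongside the import. Below is a minimal sketch of that re-initialization pattern; the helper name reinit_rope is hypothetical, while the import path and the positional __init__(head_dim, max_seq_len, rope_base) call mirror the diff.

# Sketch of the pattern in the hunk above, not torchchat's actual code;
# the helper name reinit_rope is hypothetical.
import torch.nn as nn
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

def reinit_rope(model: nn.Module, head_dim: int, max_seq_len: int, rope_base: int) -> None:
    """Rebuild every Llama3ScaledRoPE submodule with the checkpoint's hyperparameters."""
    for submodule in model.modules():
        if isinstance(submodule, Llama3ScaledRoPE):
            # Re-running __init__ in place recomputes the module's rotary
            # frequency cache, matching the positional call in the diff.
            submodule.__init__(head_dim, max_seq_len, rope_base)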
