diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 38ffa9174..1049b346f 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -31,7 +31,7 @@ from torchchat.model import Model, ModelArgs, ModelType -from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE from torchchat.model_config.model_config import resolve_model_config from torchchat.utils.build_utils import ( @@ -402,7 +402,7 @@ def _load_model_default(builder_args: BuilderArgs) -> Model: max_seq_len = decoder_config['max_seq_len'] rope_base = decoder_config['rope_base'] for submodule in model.modules(): - if isinstance(submodule, RotaryPositionalEmbeddings): + if isinstance(submodule, Llama3ScaledRoPE): submodule.__init__(head_dim, max_seq_len, rope_base) state_dict = flamingo_meta_to_tune(checkpoint) model.model.load_state_dict(state_dict, assign=True, strict=False)