diff --git a/torchchat/model.py b/torchchat/model.py index edb0ce3d5..844aaf977 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -31,7 +31,7 @@ ) from torch.nn import functional as F -from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder +from torchtune.models.llama3_2_vision import llama3_2_vision_decoder, llama3_2_vision_encoder from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder from torchtune.modules.model_fusion import DeepFusionModel from torchtune.models.clip import clip_vision_encoder @@ -213,7 +213,7 @@ def _llama3_1(cls): def _flamingo(cls): return cls( model_type=ModelType.Flamingo, - modules={"encoder": flamingo_vision_encoder, "decoder": flamingo_decoder}, + modules={"encoder": llama3_2_vision_encoder, "decoder": llama3_2_vision_decoder}, fusion_class=DeepFusionModel, )