This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@Gasoonjia (Contributor):

This PR adds support for Llava model construction.
End-to-end (E2E) model integration will follow in subsequent PRs.
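For context, a minimal sketch of the construction pattern this targets, assuming the usual Llava-style split into a vision encoder, a projector, and a text decoder; the class, method, and keyword names below are illustrative, not the ones added in this PR.

```python
import torch
from torch import nn


class LlavaLikeModel(nn.Module):
    """Illustration only: Llava-style construction fuses a vision
    encoder, a projector, and a text decoder into a single module."""

    def __init__(self, encoder: nn.Module, projector: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder      # e.g. a CLIP-style vision tower
        self.projector = projector  # maps image features into the decoder's embedding dim
        self.decoder = decoder      # autoregressive text decoder

    def forward(self, tokens: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
        image_embeds = self.projector(self.encoder(images))  # (B, num_patches, D)
        # Assumes the decoder exposes its token embedding table and
        # accepts precomputed input embeddings (hypothetical API).
        token_embeds = self.decoder.tok_embeddings(tokens)    # (B, seq_len, D)
        fused = torch.cat([image_embeds, token_embeds], dim=1)
        return self.decoder(input_embeds=fused)
```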

pytorch-bot (bot) commented Sep 17, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1155

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (4 Unrelated Failures)

As of commit 672915a with merge base f730056:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Sep 17, 2024
@Gasoonjia changed the title from "Support Llava Model Construction" to "[llava 2/n] Support Llava Model Construction" on Sep 17, 2024
Comment on lines 41 to 44
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.models.clip import clip_vision_encoder
Contributor:

Rebase is fun and makes clones

from enum import Enum
from pathlib import Path
from PIL import Image
import requests
Contributor:

Make sure this is included in the install requirements (it probably already is).

Contributor Author:

They can be removed right now; they should be something in the 3/n PR, haha.



@dataclass
class ProjectorArgs:
Contributor:

Let's add a docstring for each of these dataclasses, since they are not the usual Llama classes.
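For illustration, a docstring on such a dataclass might look roughly like this; the field names and defaults are assumptions standing in for whatever the projector actually needs, not the fields in this PR.

```python
from dataclasses import dataclass


@dataclass
class ProjectorArgs:
    """Configuration for the projector that maps vision-encoder features
    into the text decoder's embedding space.

    Attributes:
        in_channels: feature dimension produced by the vision encoder.
        out_channels: embedding dimension expected by the text decoder.
        activation: name of the activation used between projection layers.
    """

    in_channels: int = 1024
    out_channels: int = 4096
    activation: str = "gelu"
```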

Contributor Author:

I think I can remove it right now. We can bring it back later when we design arg classes for the different modules.

Comment on lines 126 to 128
encoder_output = self.encoder(
encoder_input,
)
Contributor:

Suggested change
encoder_output = self.encoder(
encoder_input,
)
encoder_output = self.encoder(encoder_input)

def setup_caches(self, batch_size, max_seq_len):
self.decoder.setup_caches(batch_size, max_seq_len)

def _encoder_feature_select(self, encoder_output):
Contributor:

Return type?
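As a sketch of what the requested annotation could look like (the CLS-token slicing is the common Llava-style feature selection and is only an assumption about this method's body):

```python
from torch import Tensor


def _encoder_feature_select(self, encoder_output: Tensor) -> Tensor:
    # Keep the patch features and drop the leading CLS token; the
    # slicing is illustrative, the explicit return type is the point.
    return encoder_output[:, 1:]
```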

*,
encoder_output: Optional[Tensor],
post_tokens: Optional[Tensor],
):
Contributor:

Return type?
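Similarly, the keyword-only signature could carry explicit annotations end to end; the `tokens` parameter, the defaults, and the `Tensor` return shown here are assumptions about the surrounding definition:

```python
from typing import Optional

from torch import Tensor


def forward(
    self,
    tokens: Tensor,
    *,
    encoder_output: Optional[Tensor] = None,
    post_tokens: Optional[Tensor] = None,
) -> Tensor:
    ...
```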

Comment on lines +239 to +243
modules={
'encoder': clip_vision_encoder,
'decoder': Transformer
},
fusion_class=ConcateFusion,
Contributor:

It's really cool to see them working together!
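For context, a rough sketch of how such a recipe entry could be consumed, mirroring the `return recipe.fusion_class(**modules)` line quoted further down; the `ModelRecipe` stand-in, the builder signatures, and the per-module configs here are assumptions, not torchchat's actual implementation.

```python
from dataclasses import dataclass
from typing import Callable, Dict

from torch import nn


@dataclass
class ModelRecipe:  # toy stand-in for the real recipe class
    modules: Dict[str, Callable[..., nn.Module]]  # e.g. {'encoder': clip_vision_encoder, 'decoder': Transformer}
    fusion_class: type                            # e.g. ConcateFusion


def build_model(recipe: ModelRecipe, configs: Dict[str, dict]) -> nn.Module:
    # Build each named submodule from its config, then let the fusion
    # class stitch them together: recipe.fusion_class(**modules).
    modules = {name: builder(**configs[name]) for name, builder in recipe.modules.items()}
    return recipe.fusion_class(**modules)
```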

elif model_type == ModelType.Llama3_1:
return cls._llama3_1()
elif model_type == ModelType.Llava:
return cls._llava()
Contributor:

match model_type:
    case ModelType.TextOnly:
        return cls._text_only()
    case ModelType.Flamingo:
        return cls._flamingo()
...

Contributor Author (@Gasoonjia), Sep 17, 2024:

Oh YEAH, we are on Python 3.10; it is a good time to switch to the match-case statement!
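Fleshed out against the elif chain above, the switch might read something like this (the Flamingo branch name and the fallback are assumptions):

```python
match model_type:
    case ModelType.TextOnly:
        return cls._text_only()
    case ModelType.Flamingo:
        return cls._flamingo()
    case ModelType.Llama3_1:
        return cls._llama3_1()
    case ModelType.Llava:
        return cls._llava()
    case _:
        raise ValueError(f"Unsupported model type: {model_type}")
```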

return recipe.fusion_class(**modules)


def _replace_know_params(self, params):
Contributor:

Suggested change
def _replace_know_params(self, params):
def _replace_known_params(self, params):

Contributor Author:

stupid grammar issue

@Gasoonjia Gasoonjia merged commit 3b162e2 into main Sep 18, 2024
47 of 51 checks passed
from enum import Enum
from pathlib import Path

import torchvision
Contributor:

This torchvision import seems unused.
