[llava 2/n] Support Llava Model Construction #1155
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1155
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs
There is 1 currently active SEV. If your PR is affected, please view it below:
✅ You can merge normally! (4 Unrelated Failures)
As of commit 672915a with merge base f730056:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchchat/model.py
Outdated
| from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder | ||
| from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder | ||
| from torchtune.modules.model_fusion import DeepFusionModel | ||
| from torchtune.models.clip import clip_vision_encoder |
Rebase is fun and makes clones
torchchat/model.py
Outdated
| from enum import Enum | ||
| from pathlib import Path | ||
| from PIL import Image | ||
| import requests |
Make sure this is included in the install requirements (it probably already is)
They can be removed for now; they belong in the 3/n PR haha
torchchat/model.py
Outdated
| @dataclass | ||
| class ProjectorArgs: |
Let's add a docstring for each of these dataclasses, since they are not the usual Llama classes
I think I can remove it for now. We can bring it back later when we design arg classes for the different modules.
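For reference, a docstring of the kind requested above could look like the minimal sketch below. Only the class name `ProjectorArgs` comes from the diff; every field name and default value is a hypothetical placeholder, not the PR's actual definition.

```python
from dataclasses import dataclass


@dataclass
class ProjectorArgs:
    """Configuration for the projector between the vision encoder and the decoder.

    NOTE: every field below is a hypothetical placeholder for illustration;
    the fields in the actual PR may differ.

    Attributes:
        in_channels: Width of the encoder features fed to the projector.
        out_channels: Width of the decoder embeddings the projector emits.
        activation: Name of the activation applied inside the projector.
    """

    in_channels: int = 1024
    out_channels: int = 4096
    activation: str = "gelu"


# The docstring is available at runtime, e.g. for help() or generated docs.
args = ProjectorArgs(out_channels=2048)
summary_line = ProjectorArgs.__doc__.splitlines()[0]
```

Keeping the per-field notes in an `Attributes:` section means the explanation travels with the class even if the arg classes are redesigned later.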
torchchat/model.py
Outdated
| encoder_output = self.encoder( | ||
| encoder_input, | ||
| ) |
| encoder_output = self.encoder( | |
| encoder_input, | |
| ) | |
| encoder_output = self.encoder(encoder_input) |
torchchat/model.py
Outdated
| def setup_caches(self, batch_size, max_seq_len): | ||
| self.decoder.setup_caches(batch_size, max_seq_len) | ||
| def _encoder_feature_select(self, encoder_output): |
Return type
torchchat/model.py
Outdated
| *, | ||
| encoder_output: Optional[Tensor], | ||
| post_tokens: Optional[Tensor], | ||
| ): |
return type
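As a rough illustration of the two "return type" requests, the sketch below annotates a keyword-only signature shaped like the one in the diff. The function name, its body, and the `Tensor` alias are assumptions made so the example runs without PyTorch; in the real code `Tensor` would presumably be `torch.Tensor`.

```python
from typing import List, Optional, get_type_hints

# Stand-in for torch.Tensor so the sketch runs without PyTorch (assumption).
Tensor = List[int]


def get_decoder_input_sketch(
    tokens: Tensor,
    *,
    encoder_output: Optional[Tensor],
    post_tokens: Optional[Tensor],
) -> Tensor:
    """Hypothetical helper: join text tokens, encoder output, and trailing tokens."""
    parts = [tokens]
    if encoder_output is not None:
        parts.append(encoder_output)
    if post_tokens is not None:
        parts.append(post_tokens)
    # Flatten the parts into one sequence, standing in for a concatenation.
    return [item for part in parts for item in part]


# The annotation is introspectable, which is what type checkers and IDEs rely on.
hints = get_type_hints(get_decoder_input_sketch)
joined = get_decoder_input_sketch([1, 2], encoder_output=[9], post_tokens=None)
```

With the `-> Tensor` annotation in place, a caller can tell at a glance that the method never returns `None`, even though two of its inputs are optional.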
| modules={ | ||
| 'encoder': clip_vision_encoder, | ||
| 'decoder': Transformer | ||
| }, | ||
| fusion_class=ConcateFusion, |
It's really cool to see them working together!
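The recipe pattern in this diff (a `modules` mapping plus a `fusion_class`, later combined via `recipe.fusion_class(**modules)`) can be sketched roughly as follows. All class and function names ending in `Stub` or `Sketch`, and the config keys, are hypothetical stand-ins; the real code uses `clip_vision_encoder`, `Transformer`, and `ConcateFusion`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class EncoderStub:
    # Hypothetical stand-in for a vision encoder builder's output.
    image_size: int = 336


@dataclass
class DecoderStub:
    # Hypothetical stand-in for the text decoder.
    n_layers: int = 32


class FusionStub:
    """Hypothetical fusion class that receives the built modules by keyword."""

    def __init__(self, encoder: EncoderStub, decoder: DecoderStub) -> None:
        self.encoder = encoder
        self.decoder = decoder


@dataclass
class RecipeSketch:
    modules: Dict[str, Callable[..., Any]]  # module name -> builder callable
    fusion_class: Callable[..., Any]  # combines the built modules


def build_model(recipe: RecipeSketch, configs: Dict[str, Dict[str, Any]]) -> Any:
    # Build each module from its own config, then hand them all to the fusion class.
    modules = {
        name: builder(**configs.get(name, {}))
        for name, builder in recipe.modules.items()
    }
    return recipe.fusion_class(**modules)


recipe = RecipeSketch(
    modules={"encoder": EncoderStub, "decoder": DecoderStub},
    fusion_class=FusionStub,
)
model = build_model(recipe, {"decoder": {"n_layers": 2}})
```

Because the keys of `modules` double as keyword arguments to the fusion class, swapping in a different encoder or decoder builder only requires changing the recipe, not the construction logic.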
torchchat/model.py
Outdated
| elif model_type == ModelType.Llama3_1: | ||
| return cls._llama3_1() | ||
| elif model_type == ModelType.Llava: | ||
| return cls._llava() |
match model_type:
    case ModelType.TextOnly:
        return cls._text_only()
    case ModelType.Flamingo:
        return cls.flamingo()
    ...
Oh YEAH we are on 3.10, so it is a good time to switch to the match-case statement!
torchchat/model.py
Outdated
| return recipe.fusion_class(**modules) | ||
| def _replace_know_params(self, params): |
| def _replace_know_params(self, params): | |
| def _replace_known_params(self, params): |
stupid grammar issue
| from enum import Enum | ||
| from pathlib import Path | ||
| import torchvision |
This torchvision import seems unused
This PR supports Llava Model Construction.
E2E model integration will be in the following PRs.