Move imports to prepare for making torchtune an optional dependency #1539
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1539
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 2 Cancelled Jobs as of commit 8b9d9df with merge base 0299a37. NEW FAILURES: the following jobs have failed. CANCELLED JOBS: the following jobs were cancelled; please retry.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Oops, it seems there are some problems with linting... I've run it according to the
Thanks for the PR, just kicked off the CI. Hmm, we'll take a gander at the linter (not blocking this PR on lint).
Thanks for the initial pass (lint changes are also looking good).
You can ignore the failing et/executorch CI; that's a separate issue.
Thanks for the review! Sure, let's put them deeper then.
@Jack-Khuu Hey! Fixed according to your comments. Moved imports deeper where it was reasonable.
Thanks again, I'm planning to take a look today (and hopefully merge it in).
Things are looking solid, I'll add another commit on top of this to fix a few nits, but looks ready to merge
torchchat/cli/builder.py (Outdated)

    @@ -416,6 +414,8 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:

    def _load_checkpoint(builder_args: BuilderArgs):
        from torchtune.models.convert_weights import meta_to_tune
Hmm, let's push the import further down, into the check on line 419. This function would error if we don't have torchtune installed.
torchchat/cli/builder.py (Outdated)

    @@ -458,6 +458,12 @@ def _load_checkpoint(builder_args: BuilderArgs):

    def _load_model_default(builder_args: BuilderArgs) -> Model:
        from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
Ditto, we can drop this into the Flamingo conditional
torchchat/generate.py (Outdated)

    @@ -450,6 +446,8 @@ def prefill(
        sequential_prefill=True,
        **sampling_kwargs,
    ) -> torch.Tensor:
        from torchtune.generation import sample as tune_sample
Ditto: move it into the Flamingo check.
    @@ -870,6 +870,13 @@ def _gen_model_input(
        max_new_tokens: Optional[int] = None,
        max_seq_len: Optional[int] = 2048,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        # torchtune model definition dependencies
        from torchtune.data import Message, padded_collate_tiled_images_and_mask
Let's move these 3 imports down to line 913
        modules["encoder_trainable"] = False
        modules["decoder_trainable"] = False
        modules["fusion_trainable"] = False
    except ModuleNotFoundError:
nice
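The pattern being approved here can be sketched generically. Everything below (the function, the `dep` parameter, the error message) is illustrative, not the real torchchat API; the flag names are the ones visible in the diff:

```python
import importlib


def optional_feature_flags(multimodal: bool, dep: str = "torchtune") -> dict:
    """Sketch: touch the optional dependency only inside the branch that
    needs it, and turn ModuleNotFoundError into an actionable message."""
    flags: dict = {}
    if not multimodal:
        return flags  # text-only models never import the optional dep
    try:
        importlib.import_module(dep)
    except ModuleNotFoundError as e:
        raise RuntimeError(
            f"Multimodal models require {dep}; pip install {dep}"
        ) from e
    flags["encoder_trainable"] = False
    flags["decoder_trainable"] = False
    flags["fusion_trainable"] = False
    return flags
```

The try/except keeps the failure local to the feature that needs the dependency, instead of breaking the module at import time.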
    @@ -1011,21 +1054,23 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    try:
        # For llama::sdpa_with_kv_cache.out, preprocess ops
        from executorch.extension.llm.custom_ops import custom_ops  # no-qa
Just to be safe, let's undo the lint reorder here; I think ET is sensitive to import order in this case.
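One common way to keep a formatter from reordering a side-effect-sensitive import is an isort-style skip comment. This is a sketch under assumptions: `# isort: skip` is isort's directive, and the exact marker depends on which linter torchchat's lintrunner config actually runs (usort/ufmt use their own equivalents):

```python
# ExecuTorch registers custom ops at import time, so this import must not be
# moved relative to the code that uses those ops. The skip comment tells the
# import sorter to leave this line where it is.
try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401  # isort: skip
except ImportError:
    custom_ops = None  # ExecuTorch not installed; non-ET paths still work
```

Pinning the order explicitly documents the constraint, so a future lint pass doesn't silently reintroduce the breakage.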
Weird, looks like some of my old comments either got rearranged or not sent. I'll fix 'em.
torchtune can now be an optional dependency. In most cases, the import was simply moved into the specific function or class that needs it. In the case of openai_api.py, the better solution was to remove the unused import and the Message type hint, since an unused import is not a good reason to pull an optional library into the class.