|
24 | 24 |
|
25 | 25 | from PIL import Image |
26 | 26 |
|
27 | | -# torchtune model definition dependencies |
28 | | -from torchtune.data import Message, padded_collate_tiled_images_and_mask |
29 | | - |
30 | | -from torchtune.generation import sample as tune_sample |
31 | | -from torchtune.models.llama3 import llama3_tokenizer |
32 | | - |
33 | | -from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
34 | | -from torchtune.training import set_default_dtype |
35 | | - |
36 | 27 | from torchchat.cli.builder import ( |
37 | 28 | _initialize_model, |
38 | 29 | _initialize_tokenizer, |
|
43 | 34 | from torchchat.utils.build_utils import device_sync, set_precision |
44 | 35 | from torchchat.utils.device_info import get_device_info |
45 | 36 |
|
| 37 | +# torchtune model definition dependencies |
| 38 | +from torchtune.data import Message, padded_collate_tiled_images_and_mask |
| 39 | + |
| 40 | +from torchtune.generation import sample as tune_sample |
| 41 | +from torchtune.models.llama3 import llama3_tokenizer |
| 42 | + |
| 43 | +from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
| 44 | +from torchtune.training import set_default_dtype |
| 45 | + |
46 | 46 |
|
47 | 47 | class _ChatFormatter(ABC): |
48 | 48 | def __init__(self, tokenizer): |
@@ -180,10 +180,14 @@ def from_args(cls, args): |
180 | 180 | # Validate that all image prompts exist before expensive model load |
181 | 181 | if image_prompts := getattr(args, "image_prompts", None): |
182 | 182 | non_existent_image_prompts = [ |
183 | | - image_prompt if (not os.path.exists(image_prompt)) for image_prompt in image_prompts |
| 183 | + image_prompt |
| 184 | + for image_prompt in image_prompts |
| 185 | + if (not os.path.exists(image_prompt)) |
184 | 186 | ] |
185 | 187 | if len(non_existent_image_prompts): |
186 | | - raise RuntimeError(f"Image prompt {non_existent_image_prompts} does not exist") |
| 188 | + raise RuntimeError( |
| 189 | + f"Image prompt {non_existent_image_prompts} does not exist" |
| 190 | + ) |
187 | 191 |
|
188 | 192 | return cls( |
189 | 193 | prompt=getattr(args, "prompt", ""), |
@@ -941,6 +945,7 @@ def chat( |
941 | 945 | TransformerCrossAttentionLayer, |
942 | 946 | TransformerSelfAttentionLayer, |
943 | 947 | ) |
| 948 | + |
944 | 949 | decoder = self.model.model.decoder |
945 | 950 | for m in reversed(list(decoder.modules())): |
946 | 951 | if isinstance(m, TransformerSelfAttentionLayer) or isinstance( |
@@ -987,7 +992,10 @@ def chat( |
987 | 992 | # `is_torchtune_model` is a misnomer since it doesn't capture all |
988 | 993 | # torchtune models (i.e. Flamingo) |
989 | 994 | # See Issue: https://github.com/pytorch/torchchat/issues/1273 |
990 | | - elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo: |
| 995 | + elif ( |
| 996 | + not generator_args.is_torchtune_model |
| 997 | + and self.model.config.model_type != ModelType.Flamingo |
| 998 | + ): |
991 | 999 | max_seq_length = min( |
992 | 1000 | encoded.size(0) + generator_args.max_new_tokens, |
993 | 1001 | ( |
|
0 commit comments