|
24 | 24 |
|
25 | 25 | from PIL import Image |
26 | 26 |
|
27 | | -# torchtune model definition dependencies |
28 | | -from torchtune.data import Message, padded_collate_tiled_images_and_mask |
29 | | - |
30 | | -from torchtune.generation import sample as tune_sample |
31 | | -from torchtune.models.llama3 import llama3_tokenizer |
32 | | - |
33 | | -from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
34 | | -from torchtune.training import set_default_dtype |
35 | | - |
36 | 27 | from torchchat.cli.builder import ( |
37 | 28 | _initialize_model, |
38 | 29 | _initialize_tokenizer, |
|
43 | 34 | from torchchat.utils.build_utils import device_sync, set_precision |
44 | 35 | from torchchat.utils.device_info import get_device_info |
45 | 36 |
|
| 37 | +# torchtune model definition dependencies |
| 38 | +from torchtune.data import Message, padded_collate_tiled_images_and_mask |
| 39 | + |
| 40 | +from torchtune.generation import sample as tune_sample |
| 41 | +from torchtune.models.llama3 import llama3_tokenizer |
| 42 | + |
| 43 | +from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
| 44 | +from torchtune.training import set_default_dtype |
| 45 | + |
46 | 46 |
|
47 | 47 | class _ChatFormatter(ABC): |
48 | 48 | def __init__(self, tokenizer): |
@@ -179,8 +179,15 @@ def from_args(cls, args): |
179 | 179 |
|
180 | 180 | # Validate that all image prompts exist before expensive model load |
181 | 181 | if image_prompts := getattr(args, "image_prompts", None): |
182 | | - if not all(os.path.exists(image_prompt) for image_prompt in image_prompts): |
183 | | - raise RuntimeError(f"Image prompt {image_prompt} does not exist") |
| 182 | + non_existent_image_prompts = [
| 183 | + image_prompt
| 184 | + for image_prompt in image_prompts
| 185 | + if not os.path.exists(image_prompt)
| 186 | + ]
| 187 | + if non_existent_image_prompts:
| 188 | + raise RuntimeError(
| 189 | + f"Image prompt(s) {non_existent_image_prompts} do not exist"
| 190 | + )
184 | 191 |
|
185 | 192 | return cls( |
186 | 193 | prompt=getattr(args, "prompt", ""), |
@@ -938,7 +945,8 @@ def chat( |
938 | 945 | TransformerCrossAttentionLayer, |
939 | 946 | TransformerSelfAttentionLayer, |
940 | 947 | ) |
941 | | - decoder = self.model.model.decoder |
| 948 | + |
| 949 | + decoder = self.model.model.decoder |
942 | 950 | for m in reversed(list(decoder.modules())): |
943 | 951 | if isinstance(m, TransformerSelfAttentionLayer) or isinstance( |
944 | 952 | m, TransformerCrossAttentionLayer |
@@ -984,7 +992,10 @@ def chat( |
984 | 992 | # `is_torchtune_model` is a misnomer since it doesn't capture all |
985 | 993 | # torchtune models (e.g., Flamingo)
986 | 994 | # See Issue: https://github.com/pytorch/torchchat/issues/1273 |
987 | | - elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo: |
| 995 | + elif ( |
| 996 | + not generator_args.is_torchtune_model |
| 997 | + and self.model.config.model_type != ModelType.Flamingo |
| 998 | + ): |
988 | 999 | max_seq_length = min( |
989 | 1000 | encoded.size(0) + generator_args.max_new_tokens, |
990 | 1001 | ( |
|