Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 2f52a9f

Browse files
committed
reformat
1 parent 52f5030 commit 2f52a9f

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

torchchat/generate.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,6 @@
2424

2525
from PIL import Image
2626

27-
# torchtune model definition dependencies
28-
from torchtune.data import Message, padded_collate_tiled_images_and_mask
29-
30-
from torchtune.generation import sample as tune_sample
31-
from torchtune.models.llama3 import llama3_tokenizer
32-
33-
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
34-
from torchtune.training import set_default_dtype
35-
3627
from torchchat.cli.builder import (
3728
_initialize_model,
3829
_initialize_tokenizer,
@@ -43,6 +34,15 @@
4334
from torchchat.utils.build_utils import device_sync, set_precision
4435
from torchchat.utils.device_info import get_device_info
4536

37+
# torchtune model definition dependencies
38+
from torchtune.data import Message, padded_collate_tiled_images_and_mask
39+
40+
from torchtune.generation import sample as tune_sample
41+
from torchtune.models.llama3 import llama3_tokenizer
42+
43+
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
44+
from torchtune.training import set_default_dtype
45+
4646

4747
class _ChatFormatter(ABC):
4848
def __init__(self, tokenizer):
@@ -180,10 +180,14 @@ def from_args(cls, args):
180180
# Validate that all image prompts exist before expensive model load
181181
if image_prompts := getattr(args, "image_prompts", None):
182182
non_existent_image_prompts = [
183-
image_prompt if (not os.path.exists(image_prompt)) for image_prompt in image_prompts
183+
image_prompt
184+
for image_prompt in image_prompts
185+
if (not os.path.exists(image_prompt))
184186
]
185187
if len(non_existent_image_prompts):
186-
raise RuntimeError(f"Image prompt {non_existent_image_prompts} does not exist")
188+
raise RuntimeError(
189+
f"Image prompt {non_existent_image_prompts} does not exist"
190+
)
187191

188192
return cls(
189193
prompt=getattr(args, "prompt", ""),
@@ -941,6 +945,7 @@ def chat(
941945
TransformerCrossAttentionLayer,
942946
TransformerSelfAttentionLayer,
943947
)
948+
944949
decoder = self.model.model.decoder
945950
for m in reversed(list(decoder.modules())):
946951
if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
@@ -987,7 +992,10 @@ def chat(
987992
# `is_torchtune_model` is a misnomer since it doesn't capture all
988993
# torchtune models (i.e. Flamingo)
989994
# See Issue: https://github.com/pytorch/torchchat/issues/1273
990-
elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
995+
elif (
996+
not generator_args.is_torchtune_model
997+
and self.model.config.model_type != ModelType.Flamingo
998+
):
991999
max_seq_length = min(
9921000
encoded.size(0) + generator_args.max_new_tokens,
9931001
(

0 commit comments

Comments
 (0)