
Commit ad84f51 (parent: 91a68ab)
Author: vmpuri

Fix control bug for image inputs

File tree: 1 file changed (+13, -21 lines)


torchchat/generate.py

Lines changed: 13 additions & 21 deletions
@@ -655,7 +655,9 @@ def generate(
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
-        input_pos = torch.tensor([start_pos + prompt_length], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -736,12 +738,6 @@ def _callback(self, x, *, buffer, done_generating):
                 buffer.clear()
             # print(, end='', flush=True)
 
-    def print_m(self, message):
-        print(
-            message.role,
-            [t["type"] if t["type"] != "text" else t for t in message.content],
-        )
-
     def _gen_model_input(
         self,
         prompt: Union[str | List[Any]],
@@ -764,7 +760,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -819,7 +815,7 @@ def _gen_model_input(
 
         is_multimodal = images is not None
         content = [{"type": "text", "content": prompt_arg}]
-
+        []
         if is_multimodal:
             content = [{"type": "image", "content": images[0]}] + content
 
@@ -830,18 +826,14 @@ def _gen_model_input(
                 )
             )
 
-        print("MESSAGE CONTENTS:")
-        messages.append(Message(role="assistant", content=""))
-        [self.print_m(m) for m in messages]
-
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
         device = torch.device(device=self.builder_args.device)
 
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
@@ -851,6 +843,7 @@ def _gen_model_input(
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
+
         else:
             encoded = torch.tensor(data["tokens"], device=device).view(-1)
             seq_len = encoded.size(0)
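
The gate on the vision branch changes from is_multimodal (derived from the images argument) to image_found. A hedged, standalone sketch of the intent, assuming image_found is set earlier in _gen_model_input while scanning the parsed prompt entries (that code is outside this diff); prompt_entries below is purely illustrative:

# Illustrative sketch only: image_found is assumed to come from the parsed
# prompt itself, so a list-form prompt with no image entry stays on the
# text-only path even if an images argument happens to be present.
prompt_entries = [
    {"type": "text", "content": "Describe this picture."},
    {"type": "image", "content": "tiger.jpg"},
]

image_found = any(entry["type"] == "image" for entry in prompt_entries)

if image_found:
    # Vision path: collate tiled images and masks into a padded batch.
    print("collating padded image batch")
else:
    # Text path: encode the tokens into a plain tensor.
    print("encoding text-only tokens")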
@@ -883,13 +876,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-            generator_args.max_seq_length,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -935,6 +921,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            generator_args.prompt,
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(
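
The last two hunks are the control fix named in the commit message: chat() previously called _gen_model_input before max_seq_length was resolved, passing the raw generator_args.max_seq_length, and now calls it after the limit is derived from the model config. A minimal, self-contained sketch of that ordering, where resolve_max_seq_length and encode_prompt are hypothetical stand-ins for the real torchchat helpers:

from typing import Optional

def resolve_max_seq_length(configured: Optional[int]) -> int:
    # Mirrors the diff: fall back to 2048 when the model config has no limit.
    return configured if configured is not None else 2048

def encode_prompt(prompt: str, max_seq_length: int) -> list[int]:
    # Hypothetical stand-in for Generator._gen_model_input: the encoded
    # prompt must respect the resolved sequence limit.
    tokens = list(prompt.encode("utf-8"))
    return tokens[:max_seq_length]

# Fixed ordering (this commit): resolve the limit first, then encode.
limit = resolve_max_seq_length(None)  # no configured value -> 2048
encoded = encode_prompt("What is in this image?", limit)
assert len(encoded) <= limit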
