This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 1e3dcae

vmpuri committed: Fix control bug for image inputs
1 parent 91a68ab

1 file changed (+6, −15 lines)

torchchat/generate.py: 6 additions & 15 deletions
@@ -734,13 +734,7 @@ def _callback(self, x, *, buffer, done_generating):
         if len(buffer) == 4 or done_generating:
             print("".join(buffer), end="", flush=True)
             buffer.clear()
-            # print(, end='', flush=True)
-
-    def print_m(self, message):
-        print(
-            message.role,
-            [t["type"] if t["type"] != "text" else t for t in message.content],
-        )
+        # print(, end='', flush=True)
 
     def _gen_model_input(
         self,
@@ -764,7 +758,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -819,7 +813,7 @@ def _gen_model_input(
 
                 is_multimodal = images is not None
                 content = [{"type": "text", "content": prompt_arg}]
-
+                []
                 if is_multimodal:
                     content = [{"type": "image", "content": images[0]}] + content
 
@@ -830,27 +824,24 @@ def _gen_model_input(
                     )
                 )
 
-        print("MESSAGE CONTENTS:")
-        messages.append(Message(role="assistant", content=""))
-        [self.print_m(m) for m in messages]
-
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
         device = torch.device(device=self.builder_args.device)
 
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
                 encoded = batch.pop("tokens").to(device).view(-1)
-                seq_len = encoded.size(0)
+                seq_len = encoded.size(0)
                 batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
+
             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
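
The control bug in question: inside the per-message loop shown above, is_multimodal is recomputed for every message, so the post-loop branch only ever saw the last message's value, while image_found stays set once any message carries an image. Below is a minimal standalone sketch of that difference; the dict-shaped messages and names are illustrative stand-ins, not torchchat's actual Message objects.

    # Sketch of the loop-flag bug this commit fixes (illustrative data only).
    messages = [
        {"content": [{"type": "image_url"}, {"type": "text"}]},  # turn with an image
        {"content": [{"type": "text"}]},                         # text-only follow-up
    ]

    image_found = False
    for message in messages:
        images = [c for c in message["content"] if c["type"] == "image_url"]
        is_multimodal = len(images) > 0  # overwritten on every iteration
        if is_multimodal:
            image_found = True           # sticky: set once, never cleared

    print(is_multimodal)  # False: reflects only the final text-only turn (old check)
    print(image_found)    # True: an image appeared somewhere in the chat (new check)

With the fix, the batch is routed through padded_collate_tiled_images_and_mask whenever any turn contained an image, rather than only when the final turn did.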
