This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e4b36f9: Support text-only input with llama3.2-11b (#1216)

1 parent ec7b510

File tree

1 file changed (+47, -25 lines)

torchchat/generate.py

Lines changed: 47 additions & 25 deletions
@@ -353,27 +353,34 @@ def prefill(
         width = x.size(1)
         assert input_pos.size(0) == width
 
-        if batch is not None:
+        if self.model.config.model_type == ModelType.Flamingo:
+            assert batch is not None, "Flamingo requires batch"
+
             # TODO: Verify sequential prefill works with multimodal models
-            tokens = batch["tokens"]
+            is_multimodal = True
             if 'encoder_input' in batch:
                 encoder_input = batch['encoder_input']
+                encoder_mask = batch["encoder_mask"]
+                is_multimodal = True
             else:
                 encoder_input = None
+                encoder_mask = None
+                is_multimodal = False
 
-            seq_len = tokens.size(1)
+            seq_len = x.size(1)
             mask = batch["causal_mask"][None, :seq_len]
-            encoder_mask = batch["encoder_mask"]
             input_pos = input_pos.view(1, -1)
-            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+            logits = model(tokens=x, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+
+            if is_multimodal:
+                batch["encoder_mask"] = batch["encoder_mask"][:, -1:]
+
             return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
-                logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
-        elif self.model.config.model_type == ModelType.Flamingo:
-            assert False, "Flamingo requires batch"
+                logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])da
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
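The prefill path now keys off the model type rather than the presence of a batch, and falls back to encoder_input=None / encoder_mask=None when the collated batch carries no image data. Below is a minimal sketch of that selection logic, using a hypothetical standalone helper rather than torchchat's actual API:

from typing import Any, Dict, Optional, Tuple

def select_encoder_args(
    batch: Dict[str, Any],
) -> Tuple[Optional[Any], Optional[Any], bool]:
    """Pick encoder arguments for a Flamingo-style prefill.

    Multimodal batches carry an 'encoder_input' entry (with a matching
    'encoder_mask'); text-only batches omit both, so the cross-attention
    inputs become None and only the causal mask is used.
    """
    if "encoder_input" in batch:
        return batch["encoder_input"], batch["encoder_mask"], True
    return None, None, False

# Example: a text-only batch yields (None, None, False)
print(select_encoder_args({"tokens": [1, 2, 3]}))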
@@ -397,7 +404,7 @@ def decode_one_token(
         if model.config.model_type == ModelType.Flamingo:
             assert batch is not None, "Flamingo requires batch"
             mask = batch["causal_mask"][None, input_pos.item(), None, :]
-            encoder_mask = batch["encoder_mask"][:, -1:]
+            encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
             logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
         else:
             logits = model(x, input_pos)
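During incremental decode, the encoder mask is now looked up only if it exists, so text-only generation passes encoder_mask=None and the model runs as a plain decoder. A hypothetical, equivalent spelling of the same guard using dict.get:

def maybe_encoder_mask(batch: dict):
    # Equivalent to: batch["encoder_mask"] if "encoder_mask" in batch else None
    return batch.get("encoder_mask")

assert maybe_encoder_mask({}) is None  # text-only decode: no cross-attention mask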
@@ -733,41 +740,56 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        if generator_args.image_prompts is not None:
-            print("Image prompts", generator_args.image_prompts)
+        if self.model.config.model_type == ModelType.Flamingo:
+
+            is_multimodal = generator_args.image_prompts is not None
+            content = [{"type": "text", "content": generator_args.prompt}]
+
+            if is_multimodal:
+                print("Image prompts", generator_args.image_prompts)
+
+                # Support for just the first image prompt for now
+                images = [Image.open(generator_args.image_prompts[0])]
+                content = [{"type": "image", "content": images[0]}] + content
 
-            # Support for just the first image prompt for now
-            images = [Image.open(generator_args.image_prompts[0])]
             messages = [
                 Message(
                     role="user",
-                    content=[
-                        {"type": "image", "content": images[0]},
-                        {"type": "text", "content": generator_args.prompt},
-                    ],
+                    content=content,
                     eot=True,
                 ),
                 Message(role="assistant", content=""),
             ]
 
             transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
+            device = torch.device(device=self.builder_args.device)
+
+            with device, set_default_dtype(self.dtype):
                 data = transform({"messages": messages}, inference=True)
-                batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
-                # set_default_dtype can not handle the dtype of the image tensor inside the batch; need to manually cast it
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-                seq_len = len(data["tokens"])
+
+                if is_multimodal:
+                    batch = padded_collate_tiled_images_and_mask(
+                        [data], pad_direction="left", pad_max_images=1
+                    )
+                    encoded = batch.pop("tokens").to(device).view(-1)
+                    seq_len = encoded.size(0)
+                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+                else:
+                    encoded = torch.tensor(
+                        data["tokens"], device=device
+                    ).view(-1)
+                    seq_len = encoded.size(0)
+                    batch = {}
+
                 total_response_length = seq_len + generator_args.max_new_tokens
                 batch["causal_mask"] = torch.tril(
                     torch.ones(
                         size=(total_response_length, total_response_length),
                         dtype=torch.bool,
                     )
                 )
-                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                encoded = batch["tokens"].view(-1)
-
         else:
             encoded = self.encode_tokens(
                 generator_args.prompt, bos=True, device=self.builder_args.device
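In chat(), a text-only prompt for the vision model now skips image collation entirely: the token ids come straight from the transform output and the batch holds only the causal mask. A self-contained sketch of that text-only branch, with a hypothetical helper name and plain token ids standing in for the tokenizer output:

import torch

def build_text_only_inputs(token_ids, max_new_tokens, device="cpu"):
    # Tokens for the prompt; no images, so no encoder_input / encoder_mask.
    encoded = torch.tensor(token_ids, device=device).view(-1)
    seq_len = encoded.size(0)

    # Causal mask sized for prompt + generated tokens, as in the new chat() path.
    total_response_length = seq_len + max_new_tokens
    batch = {
        "causal_mask": torch.tril(
            torch.ones(
                size=(total_response_length, total_response_length),
                dtype=torch.bool,
            )
        )
    }
    return encoded, batch

encoded, batch = build_text_only_inputs([1, 15, 42], max_new_tokens=8)
print(encoded.shape, batch["causal_mask"].shape)  # torch.Size([3]) torch.Size([11, 11])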
