This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 148d4ff

flamingo e2e enable
1 parent c1a8ff4 commit 148d4ff

1 file changed (+19, -18 lines)

torchchat/generate.py

Lines changed: 19 additions & 18 deletions
@@ -359,16 +359,16 @@ def prefill(
         if batch is not None:
             # TODO: Verify sequential prefill works with multimodal models
             tokens = batch["tokens"]
-            if 'encoder_input' in tokens:
-                encoder_input = tokens['encoder_input']
+            if 'encoder_input' in batch:
+                encoder_input = batch['encoder_input']
             else:
                 encoder_input = None
-
+
+            seq_len = tokens.size(1)
             mask = batch["causal_mask"][None, :seq_len]
-            input_pos = batch["input_pos"][None, :seq_len]
             encoder_mask = batch["encoder_mask"]
-
-            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_post, encoder_mask=encoder_mask)[:, -1]
+            input_pos = input_pos.view(1, -1)
+            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
             return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
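
For context on the prefill hunk above: the collated multimodal batch carries a full [total_len, total_len] causal mask plus per-token positions, and the decoder expects both tokens and positions with a leading batch dimension. A minimal standalone sketch of that slicing and reshaping (shapes and variable names here are illustrative, not taken from generate.py):

    import torch

    # Illustrative shapes only: a 5-token prompt inside a 16-token generation budget.
    seq_len, total_len = 5, 16
    tokens = torch.randint(0, 32000, (1, seq_len))                     # [1, seq_len]
    causal_mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    input_pos = torch.arange(seq_len)

    mask = causal_mask[None, :seq_len]   # -> [1, seq_len, total_len], mask rows for the prompt
    positions = input_pos.view(1, -1)    # -> [1, seq_len], mirroring input_pos.view(1, -1) above

    print(tokens.shape, mask.shape, positions.shape)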
@@ -604,7 +604,7 @@ def generate(
             self.is_torchtune_model
             or self.model.config.model_type == ModelType.Flamingo
         ):
-            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
+            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=max_seq_length-1)
         else:
             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
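
On the setup_caches hunk above: the Flamingo/torchtune path sizes encoder and decoder caches separately, while the native path takes a single max_seq_length. A hypothetical helper, not part of torchchat, showing how such a budget might be derived and passed through; the min() clamp and every name other than the keyword arguments visible in the diff are assumptions:

    import torch

    def setup_generation_caches(model, *, is_flamingo, prompt_len, max_new_tokens,
                                context_len, dtype=torch.bfloat16):
        # Cover the prompt plus every token we plan to generate, clamped to the context window.
        max_seq_length = min(prompt_len + max_new_tokens, context_len)
        if is_flamingo:
            # torchtune multimodal models take separate encoder/decoder cache budgets.
            model.setup_caches(batch_size=1, dtype=dtype,
                               encoder_max_seq_len=6404,
                               decoder_max_seq_len=max_seq_length - 1)
        else:
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
        return max_seq_length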
@@ -753,18 +753,19 @@ def chat(
             ]

             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
-            data = transform({"messages": messages}, inference=True)
-            batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
-            seq_len = len(data["tokens"])
-            total_response_length = seq_len + generator_args.max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
+
+            with torch.device(device=self.builder_args.device):
+                data = transform({"messages": messages}, inference=True)
+                batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                seq_len = len(data["tokens"])
+                batch["causal_mask"] = torch.tril(
+                    torch.ones(
+                        size=(generator_args.max_new_tokens, generator_args.max_new_tokens),
+                        dtype=torch.bool,
+                    )
                 )
-            )
-            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-            encoded = batch["tokens"]
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                encoded = batch["tokens"]

         else:
             encoded = self.encode_tokens(
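
On the chat-path hunk above: PyTorch 2.0+ accepts torch.device as a context manager, so factory calls such as torch.ones inside the with block allocate directly on self.builder_args.device rather than being created on CPU and moved afterwards. A minimal standalone sketch of the same mask construction (device string and sizes chosen for illustration):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_new_tokens, seq_len = 8, 3

    with torch.device(device):
        # Tensors created by factory functions inside this block land on `device`.
        causal_mask = torch.tril(
            torch.ones(size=(max_new_tokens, max_new_tokens), dtype=torch.bool)
        )
        encoder_mask = torch.ones(1, max_new_tokens, 10, dtype=torch.bool)[:, :seq_len]

    print(causal_mask.device, causal_mask.shape, encoder_mask.shape)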
