This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 0ac5f50

remove hacky cache size, add comment for magic number
1 parent f15957e commit 0ac5f50

1 file changed

torchchat/generate.py

Lines changed: 5 additions & 3 deletions
@@ -604,7 +604,8 @@ def generate(
             self.is_torchtune_model
             or self.model.config.model_type == ModelType.Flamingo
         ):
-            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=max_seq_length-1)
+            # 6404 is one-gpu affordable max_seq_length for single image input
+            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
         else:
             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
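
A minimal sketch of the sizing change in the hunk above, assuming (as in gpt-fast-style generation loops) that T is the prompt length and T_new = T + max_new_tokens; all values are hypothetical and only illustrate why the decoder cache is sized to T_new rather than max_seq_length - 1:

    # Sketch only, not part of the commit; values are hypothetical.
    T = 32                   # prompt length (number of encoded prompt tokens)
    max_new_tokens = 200     # generation budget
    T_new = T + max_new_tokens

    # The decoder KV cache must hold every position the model will attend to:
    # the prompt plus all generated tokens.
    decoder_max_seq_len = T_new
    # The encoder cache keeps the fixed single-image budget noted in the new comment.
    encoder_max_seq_len = 6404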
@@ -758,14 +759,15 @@ def chat(
             data = transform({"messages": messages}, inference=True)
             batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
             seq_len = len(data["tokens"])
+            total_response_length = seq_len + generator_args.max_new_tokens
             batch["causal_mask"] = torch.tril(
                 torch.ones(
-                    size=(generator_args.max_new_tokens, generator_args.max_new_tokens),
+                    size=(total_response_length, total_response_length),
                     dtype=torch.bool,
                 )
             )
             batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-            encoded = batch["tokens"]
+            encoded = batch["tokens"].view(-1)

         else:
             encoded = self.encode_tokens(
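
A minimal, runnable sketch of the mask change in the second hunk, using only standard PyTorch (torch.ones, torch.tril); the numbers are hypothetical and show why the causal mask must span the prompt plus the generation budget rather than max_new_tokens alone:

    import torch

    seq_len = 10             # hypothetical prompt length, i.e. len(data["tokens"])
    max_new_tokens = 5       # hypothetical generator_args.max_new_tokens
    total_response_length = seq_len + max_new_tokens

    # Lower-triangular boolean mask: position i may attend to positions <= i,
    # covering the prompt tokens and every token that will be generated.
    causal_mask = torch.tril(
        torch.ones(
            size=(total_response_length, total_response_length),
            dtype=torch.bool,
        )
    )

    assert causal_mask.shape == (15, 15)
    # Before this change the mask was only (max_new_tokens, max_new_tokens),
    # too small to cover the prompt plus the generated tokens.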
