Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 2e455d0

Browse files
Author: vmpuri (committed)
Fix multimodal input when no image prompt is present
1 parent c25e24f commit 2e455d0

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

torchchat/generate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,8 @@ def prefill(
364364
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
365365
# logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
366366
logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])
367+
elif self.model.config.model_type == ModelType.Flamingo:
368+
logits = model(x)
367369
else:
368370
# input_pos: [B, S]
369371
logits = model(x, input_pos)
@@ -383,11 +385,14 @@ def decode_one_token(
383385
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
384386
# input_pos: [B, 1]
385387
assert input_pos.shape[-1] == 1
386-
if model.config.model_type == ModelType.Flamingo and batch is not None:
387-
x = x.view(1, -1)
388-
logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
388+
x = x.view(1, -1)
389+
if model.config.model_type == ModelType.Flamingo:
390+
if batch is not None:
391+
logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
392+
else:
393+
logits = model(x)
389394
else:
390-
logits = model(x.view(1, -1), input_pos)
395+
logits = model(x, input_pos)
391396
# print(f"x: {x},\n input_pos: {input_pos}\n")
392397
return self.sample(logits, need_probs=need_probs, **sampling_kwargs)
393398

0 commit comments

Comments (0)