Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 437fd3e

Browse files
committed
manually cast dtype
1 parent 21ffafe commit 437fd3e

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

torchchat/generate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,8 @@ def chat(
758758
with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
759759
data = transform({"messages": messages}, inference=True)
760760
batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
761+
# set_default_dtype can not handle the dtype of the image tensor inside the batch; need to manually cast it
762+
batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
761763
seq_len = len(data["tokens"])
762764
total_response_length = seq_len + generator_args.max_new_tokens
763765
batch["causal_mask"] = torch.tril(

0 commit comments

Comments
 (0)