diff --git a/torchchat/generate.py b/torchchat/generate.py index 4d2439d2f..987fb3e44 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -949,7 +949,7 @@ def _gen_model_input( if image_found: batch = padded_collate_tiled_images_and_mask( - [data], pad_direction="left", pad_max_images=1 + [data], pad_direction="left", pad_max_images=1, pad_max_tiles=transform.max_num_tiles ) encoded = batch.pop("tokens").to(device).view(-1) seq_len = encoded.size(0)