
Commit 254978d

Fix torch.compile on the Llama3.2 vision model

Parent: 286527c


torchchat/generate.py

Lines changed: 6 additions & 2 deletions
@@ -42,7 +42,9 @@
 from torchchat.model import Model, ModelType
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
-
+# torch._dynamo.config.capture_scalar_outputs = True
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+# torch._dynamo.config.suppress_errors = True
 
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
@@ -415,7 +417,9 @@ def decode_one_token(
     x = x.view(1, -1)
     if model.config.model_type == ModelType.Flamingo:
         assert batch is not None, "Flamingo requires batch"
-        mask = batch["causal_mask"][None, input_pos.item(), None, :]
+        # breakpoint()
+        # start_pos = input_pos.item()
+        mask = batch["causal_mask"][None, input_pos, None, :].view(1, 1, -1)
         encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
         logits = model(
             x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos
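
For context, the fix has two parts: enabling `torch._dynamo.config.capture_dynamic_output_shape_ops` so Dynamo can trace ops whose output shapes depend on data, and indexing `batch["causal_mask"]` with the `input_pos` tensor directly instead of `input_pos.item()`, which would otherwise force a graph break (or require `capture_scalar_outputs`). The sketch below is not part of the commit; it is a minimal, self-contained illustration of the same indexing pattern with made-up shapes and names (`causal_mask`, `make_mask`).

```python
# Minimal sketch (hypothetical names/shapes, not from torchchat) showing why
# indexing with the position tensor compiles cleanly where .item() would not.
import torch

# Mirrors the config flag enabled in this commit.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

causal_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))  # [seq, seq]
input_pos = torch.tensor([3])  # current decode position kept as a tensor

@torch.compile(fullgraph=True)
def make_mask(mask: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    # Tensor indexing stays inside the traced graph; calling pos.item() here
    # would break fullgraph tracing (or need capture_scalar_outputs).
    return mask[None, pos, None, :].view(1, 1, -1)

print(make_mask(causal_mask, input_pos).shape)  # torch.Size([1, 1, 8])
```

The trailing `.view(1, 1, -1)` restores the `[1, 1, seq_len]` mask shape that the previous scalar index produced directly.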
