diff --git a/torchchat/generate.py b/torchchat/generate.py
index 5eb946f7d..a74344724 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -42,7 +42,9 @@
 from torchchat.model import Model, ModelType
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
-
+# torch._dynamo.config.capture_scalar_outputs = True
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+# torch._dynamo.config.suppress_errors = True
 
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
@@ -415,7 +417,7 @@ def decode_one_token(
     x = x.view(1, -1)
     if model.config.model_type == ModelType.Flamingo:
         assert batch is not None, "Flamingo requires batch"
-        mask = batch["causal_mask"][None, input_pos.item(), None, :]
+        mask = batch["causal_mask"][None, input_pos, None, :].view(1, 1, -1)
         encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
         logits = model(
             x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos