diff --git a/torchchat/generate.py b/torchchat/generate.py index a596187f5..a06e215f4 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -685,7 +685,7 @@ def generate( sequential_prefill=True, callback=lambda x: x, max_seq_length: int, - attention_backend: str = "math", + attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH, seed: Optional[int] = None, **sampling_kwargs, ) -> torch.Tensor: