From 7072dd65071a69437839c5e5466054a39b9f34c4 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 21 Jan 2025 18:12:12 -0800 Subject: [PATCH] Typo: Fix generate signature type hint for attention_backend `attention_backend` is a SDPBackend, not a string --- torchchat/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index a596187f5..a06e215f4 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -685,7 +685,7 @@ def generate( sequential_prefill=True, callback=lambda x: x, max_seq_length: int, - attention_backend: str = "math", + attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH, seed: Optional[int] = None, **sampling_kwargs, ) -> torch.Tensor: