This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8e18e7f

Update generate.py

Push backend manager into caller

1 parent f4ae60f · commit 8e18e7f
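
In effect, this commit hoists the SDPA backend context manager out of the per-token decode loop and into the call site. A minimal sketch of that pattern, using hypothetical names (`decode_step`, `decode_loop_*`, and `caller` are illustrations, not torchchat functions):

```python
from torch.nn.attention import SDPBackend, sdpa_kernel

# Before: the callee re-enters the backend context on every iteration,
# so the backend choice must be threaded down as a parameter.
def decode_loop_before(decode_step, n, attention_backend=SDPBackend.MATH):
    tokens = []
    for _ in range(n):
        with sdpa_kernel([attention_backend]):  # entered n times
            tokens.append(decode_step())
    return tokens

# After: the callee knows nothing about attention backends...
def decode_loop_after(decode_step, n):
    return [decode_step() for _ in range(n)]

# ...because the caller wraps the whole loop in a single context.
def caller(decode_step, n, attention_backend=SDPBackend.MATH):
    with sdpa_kernel([attention_backend]):  # entered once
        return decode_loop_after(decode_step, n)
```

The callee's signature shrinks, and the context-manager enter/exit cost is paid once per generation rather than once per decoded token, which is what the new `with torch.nn.attention.sdpa_kernel([...]), prof:` line in the diff below does.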

File tree: 1 file changed, +3 −6 lines

torchchat/generate.py
Lines changed: 3 additions & 6 deletions

```diff
@@ -532,7 +532,6 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
-        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []
@@ -541,7 +540,8 @@ def decode_n_tokens(
             num_new_tokens - 1
         ):  # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([attention_backend]):
+            # with torch.nn.attention.sdpa_kernel([attention_backend]):
+            if True:  # preserve indentation while testing

                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(
@@ -685,7 +685,6 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
-        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
@@ -802,7 +801,6 @@ def generate(
                 if self.is_llama3_model
                 else None
             ),
-            attention_backend=attention_backend,
             **sampling_kwargs,
         ):
             generated_tokens.append(generated_token.view(-1))
@@ -1174,7 +1172,7 @@ def callback(x, *, done_generating=False):
                 prof = torch.profiler.profile()
             t0 = time.perf_counter()
             num_tokens_generated = 0
-            with prof:
+            with torch.nn.attention.sdpa_kernel([self.builder_args.attention_backend]), prof:
                 generator_func = self.generate(
                     self.model,
                     encoded,
@@ -1190,7 +1188,6 @@ def callback(x, *, done_generating=False):
                     start_pos=start_pos,
                     skip_cache_setup=not is_first_sample,
                     max_seq_length=max_seq_length,
-                    attention_backend=self.builder_args.attention_backend,
                 )
             if generator_args.chat_mode:
                 start_pos += encoded.size(0)
```
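
Net effect: `decode_n_tokens` and `generate` lose their `attention_backend` parameters, and the backend restriction now covers the entire `self.generate(...)` call. The `if True:  # preserve indentation while testing` line is a placeholder that keeps the loop body's indentation valid while the per-token context manager is commented out. For reference, a minimal standalone sketch of the `torch.nn.attention.sdpa_kernel` API the new call site relies on (tensor shapes here are made up, not from this repo):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical attention inputs: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Every scaled_dot_product_attention call inside the block may only
# dispatch to the listed backends; the context is entered once.
with sdpa_kernel([SDPBackend.MATH]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 16, 64])
```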
