@@ -532,7 +532,6 @@ def decode_n_tokens(
532532 callback = lambda _ : _ ,
533533 eos_token_id : int = 2 ,
534534 eot_id : Optional [int ] = None ,
535- attention_backend : SDPBackend = torch .nn .attention .SDPBackend .MATH ,
536535 ** sampling_kwargs ,
537536 ):
538537 new_tokens , new_probs = [], []
@@ -541,7 +540,8 @@ def decode_n_tokens(
541540 num_new_tokens - 1
542541 ): # -1 to save space to run an EoS if don't generate it naturally
543542 # Actually better for Inductor to codegen attention here
544- with torch .nn .attention .sdpa_kernel ([attention_backend ]):
543+ # with torch.nn.attention.sdpa_kernel([attention_backend]):
544+ if True : # preserve indentation while testing
545545
546546 out_token = cur_token .clone ()
547547 next_token , next_prob = self .decode_one_token (
@@ -685,7 +685,6 @@ def generate(
685685 sequential_prefill = True ,
686686 callback = lambda x : x ,
687687 max_seq_length : int ,
688- attention_backend : SDPBackend = torch .nn .attention .SDPBackend .MATH ,
689688 seed : Optional [int ] = None ,
690689 ** sampling_kwargs ,
691690 ) -> torch .Tensor :
@@ -802,7 +801,6 @@ def generate(
802801 if self .is_llama3_model
803802 else None
804803 ),
805- attention_backend = attention_backend ,
806804 ** sampling_kwargs ,
807805 ):
808806 generated_tokens .append (generated_token .view (- 1 ))
@@ -1174,7 +1172,7 @@ def callback(x, *, done_generating=False):
11741172 prof = torch .profiler .profile ()
11751173 t0 = time .perf_counter ()
11761174 num_tokens_generated = 0
1177- with prof :
1175+ with torch . nn . attention . sdpa_kernel ([ self . builder_args . attention_backend ]), prof :
11781176 generator_func = self .generate (
11791177 self .model ,
11801178 encoded ,
@@ -1190,7 +1188,6 @@ def callback(x, *, done_generating=False):
11901188 start_pos = start_pos ,
11911189 skip_cache_setup = not is_first_sample ,
11921190 max_seq_length = max_seq_length ,
1193- attention_backend = self .builder_args .attention_backend ,
11941191 )
11951192 if generator_args .chat_mode :
11961193 start_pos += encoded .size (0 )
0 commit comments