 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torch._C import _SDPBackend as SDPBackend
 
 from PIL import Image
 
@@ -531,6 +532,7 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
+        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []
@@ -539,7 +541,7 @@ def decode_n_tokens(
             num_new_tokens - 1
         ):  # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([attention_backend]):
 
                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(
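The decode loop previously pinned scaled-dot-product attention to the MATH backend; the hunk above threads an `attention_backend` argument into the `sdpa_kernel` context manager instead. A minimal standalone sketch of that pattern (tensor shapes and the MATH choice are illustrative, not from the PR):

```python
# Restrict SDPA dispatch to a caller-selected backend, as the hunk above does.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)

# MATH always works, including on CPU; FLASH_ATTENTION / EFFICIENT_ATTENTION
# require a supported accelerator.
backend = SDPBackend.MATH

with sdpa_kernel([backend]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 4, 8, 16])
```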
@@ -683,6 +685,7 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
+        attention_backend: str = "math",
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
@@ -799,6 +802,7 @@ def generate(
                 if self.is_llama3_model
                 else None
             ),
+            attention_backend=attention_backend,
             **sampling_kwargs,
         ):
             generated_tokens.append(generated_token.view(-1))
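`generate()` accepts the backend as a string (defaulting to "math"), while `decode_n_tokens()` annotates it as an `SDPBackend` enum member, so the string is presumably resolved to the enum elsewhere in the PR; that conversion is not visible in these hunks. A hypothetical mapping helper, purely to illustrate the idea (the name and the set of accepted strings are assumptions):

```python
# Hypothetical helper, not part of the PR: map a string option to SDPBackend.
from torch.nn.attention import SDPBackend

_SDP_BACKENDS = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

def resolve_attention_backend(name: str) -> SDPBackend:
    try:
        return _SDP_BACKENDS[name.lower()]
    except KeyError:
        raise ValueError(f"unknown attention backend: {name!r}") from None
```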
@@ -1186,6 +1190,7 @@ def callback(x, *, done_generating=False):
                 start_pos=start_pos,
                 skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
+                attention_backend=self.builder_args.attention_backend,
             )
             if generator_args.chat_mode:
                 start_pos += encoded.size(0)
@@ -1205,8 +1210,10 @@ def callback(x, *, done_generating=False):
             if hasattr(prof, "export_chrome_trace"):
                 if self.builder_args.device == "cpu":
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
-                else:
+                elif self.builder_args.device == "cuda":
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
+                else:
+                    print(prof.key_averages().table(sort_by="self_xpu_time_total"))
                 prof.export_chrome_trace(f"{self.profile}.json")
 
         if start_pos >= max_seq_length:
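The profiler summary now branches on the configured device to pick the matching self-time column: cpu, cuda, and (new) xpu as the remaining case. A standalone sketch of the same idea, with the workload and device string assumed for illustration:

```python
# Pick the profiler sort column from the device string, mirroring the
# cpu / cuda / xpu branches in the hunk above.
import torch
from torch.profiler import ProfilerActivity, profile

def print_profile(prof, device: str) -> None:
    sort_key = {
        "cpu": "self_cpu_time_total",
        "cuda": "self_cuda_time_total",
    }.get(device, "self_xpu_time_total")
    print(prof.key_averages().table(sort_by=sort_key))

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.mm(torch.randn(256, 256), torch.randn(256, 256))

print_profile(prof, "cpu")
```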
@@ -1291,6 +1298,9 @@ def callback(x, *, done_generating=False):
             )
             if torch.cuda.is_available():
                 print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+            if torch.xpu.is_available():
+                print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+
 
 
 class DistributedGenerator(LocalGenerator):
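The memory report above and the peak-stat reset in `run_generator` (next hunk) follow the same accelerator-guard pattern: check CUDA, then XPU. A standalone sketch of that pairing (the `hasattr` guard is an extra precaution for PyTorch builds without a `torch.xpu` module; the 1e9 divisor mirrors the diff):

```python
# Reset peak stats before a run and report reserved memory afterwards,
# on whichever accelerator is present (prints nothing on a CPU-only build).
import torch

def reset_peak_memory() -> None:
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.reset_peak_memory_stats()

def report_peak_memory() -> None:
    if torch.cuda.is_available():
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")

reset_peak_memory()
# ... run the workload ...
report_peak_memory()
```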
@@ -1617,6 +1627,8 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
+    if torch.xpu.is_available():
+        torch.xpu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass