import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torch._C import _SDPBackend as SDPBackend

from PIL import Image

@@ -531,6 +532,7 @@ def decode_n_tokens(
        callback=lambda _: _,
        eos_token_id: int = 2,
        eot_id: Optional[int] = None,
+        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
        **sampling_kwargs,
    ):
        new_tokens, new_probs = [], []
@@ -539,7 +541,7 @@ def decode_n_tokens(
            num_new_tokens - 1
        ):  # -1 to save space to run an EoS if dont generate it naturally
            # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([attention_backend]):

                out_token = cur_token.clone()
                next_token, next_prob = self.decode_one_token(
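For context: torch.nn.attention.sdpa_kernel() restricts which scaled_dot_product_attention implementation may run inside the block, so threading attention_backend through here replaces the hard-coded MATH fallback with a caller-selected backend. A minimal standalone sketch of the mechanism (not part of this patch; assumes PyTorch 2.3+, where sdpa_kernel and SDPBackend are public in torch.nn.attention):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64)

# Only the listed backends are eligible inside the context manager; MATH is the
# portable fallback, FLASH_ATTENTION / EFFICIENT_ATTENTION need supported hardware.
with sdpa_kernel([SDPBackend.MATH]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)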
@@ -683,6 +685,7 @@ def generate(
        sequential_prefill=True,
        callback=lambda x: x,
        max_seq_length: int,
+        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
        seed: Optional[int] = None,
        **sampling_kwargs,
    ) -> torch.Tensor:
@@ -799,6 +802,7 @@ def generate(
                    if self.is_llama3_model
                    else None
                ),
+                attention_backend=attention_backend,
                **sampling_kwargs,
            ):
                generated_tokens.append(generated_token.view(-1))
@@ -1122,7 +1126,7 @@ def chat(
                messages_to_encode.append(
                    {"role": "system", "content": self.system_prompt}
                )
-                messages_to_encode.append({"role": "system", "content": prompt})
+                messages_to_encode.append({"role": "user", "content": prompt})
                encoded = self.chat_formatter.encode_dialog_prompt(
                    messages_to_encode, add_generation_prompt=True,
                )
@@ -1186,6 +1190,7 @@ def callback(x, *, done_generating=False):
                start_pos=start_pos,
                skip_cache_setup=not is_first_sample,
                max_seq_length=max_seq_length,
+                attention_backend=self.builder_args.attention_backend,
            )
            for token_tensor, metrics in generator_func:
                if token_tensor is not None:
@@ -1203,8 +1208,10 @@ def callback(x, *, done_generating=False):
            if hasattr(prof, "export_chrome_trace"):
                if self.builder_args.device == "cpu":
                    print(prof.key_averages().table(sort_by="self_cpu_time_total"))
-                else:
+                elif self.builder_args.device == "cuda":
                    print(prof.key_averages().table(sort_by="self_cuda_time_total"))
+                else:
+                    print(prof.key_averages().table(sort_by="self_xpu_time_total"))
                prof.export_chrome_trace(f"{self.profile}.json")

            if start_pos >= max_seq_length:
@@ -1289,6 +1296,9 @@ def callback(x, *, done_generating=False):
            )
        if torch.cuda.is_available():
            print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        if torch.xpu.is_available():
+            print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+


class DistributedGenerator(LocalGenerator):
@@ -1615,6 +1625,8 @@ def run_generator(
    )
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
+    if torch.xpu.is_available():
+        torch.xpu.reset_peak_memory_stats()

    for _ in gen.chat(generator_args):
        pass
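The hunks above assume builder_args.attention_backend already holds an SDPBackend value; the CLI plumbing that produces it is outside this section. A hypothetical helper for mapping a string flag to the enum might look like the following (parse_attention_backend and the accepted spellings are assumptions, not code from this PR):

from torch.nn.attention import SDPBackend

# Hypothetical flag-to-enum mapping; the accepted names are assumed, not taken from the PR.
_SDP_BACKENDS = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
}

def parse_attention_backend(name: str) -> SDPBackend:
    try:
        return _SDP_BACKENDS[name.lower()]
    except KeyError:
        raise ValueError(
            f"Unknown attention backend '{name}'; expected one of {sorted(_SDP_BACKENDS)}"
        ) from None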