@@ -69,10 +69,16 @@ class BuilderArgs:
     prefill_possible: bool = False
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
+    attention_backend: str = "math"

     def __post_init__(self):
         if self.device is None:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.xpu.is_available():
+                self.device = "xpu"
+            else:
+                self.device = "cpu"

         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
@@ -178,6 +184,17 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         pp = getattr(args, "pp", 1)
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
+        sdp_backend_dict = {
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        }
+        attention_backend = sdp_backend_dict[args.attention_backend]
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+                                     or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU; falling back to the math backend")
+            attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -202,6 +219,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
+            attention_backend=attention_backend,
         )

     @classmethod
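The hunks above only map the `--attention-backend` string to a `torch.nn.attention.SDPBackend` value and thread it through `BuilderArgs`; the diff does not show where the backend is applied. As a minimal sketch (not part of this commit; the function and argument names are hypothetical), such a value is typically enforced with the `torch.nn.attention.sdpa_kernel` context manager:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical call site: restrict scaled_dot_product_attention to the
# backend resolved from --attention-backend for this forward pass.
def run_with_attention_backend(model, tokens, attention_backend=SDPBackend.MATH):
    with sdpa_kernel(attention_backend):
        return model(tokens)
```

The CPU fallback in the hunk above reflects that `efficient_attention` and `cudnn_attention` are CUDA-oriented backends, so `MATH` is substituted when the device is `cpu`.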
@@ -571,9 +589,8 @@ def do_nothing(max_batch_size, max_seq_length):
             # attributes will NOT be seen on by AOTI-compiled forward
             # function, e.g. calling model.setup_cache will NOT touch
             # AOTI compiled and maintained model buffers such as kv_cache.
-            from torch._inductor.package import load_package

-            aoti_compiled_model = load_package(
+            aoti_compiled_model = torch._inductor.aoti_load_package(
                 str(builder_args.aoti_package_path.absolute())
             )

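The final hunk replaces the removed `from torch._inductor.package import load_package` import with a direct call to `torch._inductor.aoti_load_package`. A hedged usage sketch of that entry point (the package path and example input below are placeholders, not values from this repo):

```python
import torch

# Load an ahead-of-time Inductor (AOTI) compiled package; "model.pt2"
# is a placeholder for builder_args.aoti_package_path.
aoti_compiled_model = torch._inductor.aoti_load_package("model.pt2")

# The returned object is callable like the original model's forward;
# example_input is a placeholder shaped for the exported model.
example_input = torch.randn(1, 16)
output = aoti_compiled_model(example_input)
```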