@@ -69,10 +69,16 @@ class BuilderArgs:
6969 prefill_possible : bool = False
7070 dynamic_shapes : bool = False
7171 max_seq_length : Optional [int ] = None
72+ attention_backend : str = "math"
7273
7374 def __post_init__ (self ):
7475 if self .device is None :
75- self .device = "cuda" if torch .cuda .is_available () else "cpu"
76+ if torch .cuda .is_available ():
77+ self .device = "cuda"
78+ elif torch .xpu .is_available ():
79+ self .device = "xpu"
80+ else :
81+ self .device = "cpu"
7682
7783 if not (
7884 (self .checkpoint_path and self .checkpoint_path .is_file ())
@@ -178,6 +184,17 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
178184 pp = getattr (args , "pp" , 1 )
179185 tp = getattr (args , "tp" , 1 )
180186 chpt_from = getattr (args , "chpt_from" , "hf" )
187+ sdp_backend_dict = {
188+ 'math' : torch .nn .attention .SDPBackend .MATH ,
189+ 'flash_attention' : torch .nn .attention .SDPBackend .FLASH_ATTENTION ,
190+ 'efficient_attention' : torch .nn .attention .SDPBackend .EFFICIENT_ATTENTION ,
191+ 'cudnn_attention' : torch .nn .attention .SDPBackend .CUDNN_ATTENTION ,
192+ }
193+ attention_backend = sdp_backend_dict [args .attention_backend ]
194+ if args .device == "cpu" and (args .attention_backend == "efficient_attention"
195+ or args .attention_backend == "cudnn_attention" ):
196+ print (f"Warning: { args .attention_backend } is not supported on CPU. Using math instead." )
197+ attention_backend = torch .nn .attention .SDPBackend .MATH
181198 return cls (
182199 checkpoint_dir = checkpoint_dir ,
183200 checkpoint_path = checkpoint_path ,
@@ -202,6 +219,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
202219 is_chat_model = is_chat_model ,
203220 dynamic_shapes = getattr (args , "dynamic_shapes" , False ),
204221 max_seq_length = getattr (args , "max_seq_length" , None ),
222+ attention_backend = attention_backend ,
205223 )
206224
207225 @classmethod
@@ -571,9 +589,8 @@ def do_nothing(max_batch_size, max_seq_length):
571589 # attributes will NOT be seen on by AOTI-compiled forward
572590 # function, e.g. calling model.setup_cache will NOT touch
573591 # AOTI compiled and maintained model buffers such as kv_cache.
574- from torch ._inductor .package import load_package
575592
576- aoti_compiled_model = load_package (
593+ aoti_compiled_model = torch . _inductor . aoti_load_package (
577594 str (builder_args .aoti_package_path .absolute ())
578595 )
579596
0 commit comments