@@ -28,7 +28,7 @@
 
 
 """
-Export for Server
+Export for Server
 """
 
 
@@ -78,7 +78,7 @@ def export_for_server(
 """
 Export for ExecuTorch
 
-TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
+TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
 replace_attention_with_custom_sdpa_attention with ET's implementation
 """
 
@@ -94,6 +94,9 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -194,7 +197,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
             return self.wo(output)
 
     def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-        from executorch.examples.models.llama2.custom_ops import (  # noqa
+        from executorch.extension.llm.custom_ops import (  # noqa
             sdpa_with_kv_cache,
         )
 
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
             _skip_type_promotion=bool(target_precision == torch.float16),
         )
 
-        if target_precision == torch.float16 or target_precision == torch.bfloat16:
-            if state_dict_dtype != torch.float16:
-                print("model.to torch.float16")
-                model = model.to(dtype=torch.float16)
-                state_dict_dtype = torch.float16
-        elif target_precision == torch.float32:
-            if state_dict_dtype != torch.float32:
-                print("model.to torch.float32")
-                model = model.to(dtype=torch.float32)
-        elif target_precision == torch.bfloat16:
-            print("model.to torch.bfloat16")
-            model = model.to(dtype=torch.bfloat16)
-        else:
+        if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
             raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-        replace_attention_with_custom_sdpa_attention(model)
+        if state_dict_dtype != target_precision:
+            print(f"model.to {target_precision}")
+            model = model.to(dtype=target_precision)
+            state_dict_dtype = target_precision
+
+        # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+        # support anything but bfloat32, and our attempt to use it anyway by converting
+        # to and from float causes other errors.)
+        if target_precision != torch.bfloat16:
+            replace_attention_with_custom_sdpa_attention(model)
+
         with torch.nn.attention.sdpa_kernel(
             [torch.nn.attention.SDPBackend.MATH]
         ), torch.no_grad():
@@ -304,9 +305,9 @@ def export_for_et(model, device, output_path) -> str:
         edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
         export_program = edge_manager.to_executorch(
             ExecutorchBackendConfig(
-                extract_constant_segment=True,
                 extract_delegate_segments=True,
                 passes=[
+                    ConvertToLinearPass(),
                     QuantFusionPass(),
                 ],
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
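For readers following the dtype change in the export_for_et hunk above: the commit replaces the per-dtype branch ladder with one membership check plus a single conversion, and only swaps in the custom SDPA op when the target precision is not bfloat16. Below is a minimal standalone sketch of that pattern; the helper name normalize_for_et_export and the SUPPORTED_ET_DTYPES constant are illustrative and not part of torchchat.

import torch
import torch.nn as nn

# Assumed set of supported export dtypes, mirroring the guard clause in the diff.
SUPPORTED_ET_DTYPES = (torch.float16, torch.float32, torch.bfloat16)

def normalize_for_et_export(model: nn.Module, target_precision: torch.dtype) -> nn.Module:
    # Reject dtypes the export path does not handle, as in the new guard clause.
    if target_precision not in SUPPORTED_ET_DTYPES:
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")

    # Convert once, and only if the weights are not already in the target precision.
    state_dict_dtype = next(iter(model.state_dict().values())).dtype
    if state_dict_dtype != target_precision:
        model = model.to(dtype=target_precision)
    return model

Whether replace_attention_with_custom_sdpa_attention is applied then hinges on a single target_precision != torch.bfloat16 check rather than being unconditional, which is the behavioral change the comment in the hunk explains.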