@@ -28,7 +28,7 @@
 
 
 """
-Export for Server
+Export for Server
 """
 
 
@@ -78,7 +78,7 @@ def export_for_server(
 """
 Export for ExecuTorch
 
-TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
+TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
 replace_attention_with_custom_sdpa_attention with ET's implementation
 """
 
@@ -94,6 +94,9 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
         _skip_type_promotion=bool(target_precision == torch.float16),
     )
 
-    if target_precision == torch.float16 or target_precision == torch.bfloat16:
-        if state_dict_dtype != torch.float16:
-            print("model.to torch.float16")
-            model = model.to(dtype=torch.float16)
-            state_dict_dtype = torch.float16
-    elif target_precision == torch.float32:
-        if state_dict_dtype != torch.float32:
-            print("model.to torch.float32")
-            model = model.to(dtype=torch.float32)
-    elif target_precision == torch.bfloat16:
-        print("model.to torch.bfloat16")
-        model = model.to(dtype=torch.bfloat16)
-    else:
+    if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
         raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-    replace_attention_with_custom_sdpa_attention(model)
+    if state_dict_dtype != target_precision:
+        print(f"model.to {target_precision}")
+        model = model.to(dtype=target_precision)
+        state_dict_dtype = target_precision
+
+    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+    # support anything but float32, and our attempt to use it anyway by converting
+    # to and from float causes other errors.)
+    if target_precision != torch.bfloat16:
+        replace_attention_with_custom_sdpa_attention(model)
+
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
     ), torch.no_grad():
@@ -306,6 +307,7 @@ def export_for_et(model, device, output_path) -> str:
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
                 passes=[
+                    ConvertToLinearPass(),
                     QuantFusionPass(),
                 ],
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
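
Taken together, the two `export_for_et` hunks normalize the model's dtype in one place and then decide whether the custom SDPA swap applies. Below is a minimal, self-contained sketch of that control flow, assuming stand-in names (`normalize_precision`, a toy `nn.Linear` model) that are not part of torchchat; only plain `torch` calls are used.

```python
import torch

# Precisions the ET export path accepts, mirroring the membership check in the diff.
SUPPORTED_PRECISIONS = (torch.float16, torch.float32, torch.bfloat16)


def normalize_precision(model, state_dict_dtype, target_precision):
    """Validate the requested dtype and convert the model at most once."""
    if target_precision not in SUPPORTED_PRECISIONS:
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
    if state_dict_dtype != target_precision:
        model = model.to(dtype=target_precision)
        state_dict_dtype = target_precision
    return model, state_dict_dtype


if __name__ == "__main__":
    toy_model = torch.nn.Linear(4, 4)  # stand-in for the exported transformer
    toy_model, dtype = normalize_precision(toy_model, torch.float32, torch.bfloat16)

    # Per the comment in the diff, the custom SDPA op does not support bfloat16
    # on CPU, so a bfloat16 target keeps the stock attention implementation.
    use_custom_sdpa = dtype != torch.bfloat16
    print(dtype, use_custom_sdpa)  # -> torch.bfloat16 False
```

Compared with the removed branch chain, the single membership check also closes the hole where a bfloat16 target fell into the float16 branch and silently converted the model to the wrong dtype.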