@@ -28,7 +28,7 @@
 
 
 """
-Export for Server   
+Export for Server
 """
 
 
@@ -78,7 +78,7 @@ def export_for_server(
7878""" 
7979Export for ExecuTorch 
8080
81- TODO (https://github.com/pytorch/torchchat/issues/1058): Replace   
81+ TODO (https://github.com/pytorch/torchchat/issues/1058): Replace 
8282replace_attention_with_custom_sdpa_attention with ET's implementation 
8383""" 
8484
@@ -94,6 +94,9 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
             _skip_type_promotion=bool(target_precision == torch.float16),
         )
 
-        if target_precision == torch.float16 or target_precision == torch.bfloat16:
-            if state_dict_dtype != torch.float16:
-                print("model.to torch.float16")
-                model = model.to(dtype=torch.float16)
-                state_dict_dtype = torch.float16
-        elif target_precision == torch.float32:
-            if state_dict_dtype != torch.float32:
-                print("model.to torch.float32")
-                model = model.to(dtype=torch.float32)
-        elif target_precision == torch.bfloat16:
-            print("model.to torch.bfloat16")
-            model = model.to(dtype=torch.bfloat16)
-        else:
+        if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
             raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-        replace_attention_with_custom_sdpa_attention(model)
+        if state_dict_dtype != target_precision:
+            print(f"model.to {target_precision}")
+            model = model.to(dtype=target_precision)
+            state_dict_dtype = target_precision
+
+        # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+        # support anything but float32, and our attempt to use it anyway by converting
+        # to and from float causes other errors.)
+        if target_precision != torch.bfloat16:
+            replace_attention_with_custom_sdpa_attention(model)
+
         with torch.nn.attention.sdpa_kernel(
             [torch.nn.attention.SDPBackend.MATH]
         ), torch.no_grad():
@@ -306,6 +307,7 @@ def export_for_et(model, device, output_path) -> str:
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
                 passes=[
+                    ConvertToLinearPass(),
                     QuantFusionPass(),
                 ],
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
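
For reference, the dtype change above collapses the old per-precision if/elif ladder into one membership check plus a single conversion. The following is a minimal standalone sketch of that pattern; the normalize_model_dtype helper and the way it infers the current dtype from model parameters are illustrative assumptions, not torchchat code.

import torch

# Precisions the ExecuTorch export path accepts, per the diff above.
_SUPPORTED_ET_DTYPES = (torch.float16, torch.float32, torch.bfloat16)


def normalize_model_dtype(model: torch.nn.Module, target_precision: torch.dtype) -> torch.nn.Module:
    """Hypothetical helper: reject unsupported dtypes, then convert at most once."""
    if target_precision not in _SUPPORTED_ET_DTYPES:
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
    # Simplification: take the current dtype from the first parameter
    # (torchchat tracks this separately as state_dict_dtype).
    current_dtype = next(model.parameters()).dtype
    if current_dtype != target_precision:
        print(f"model.to {target_precision}")
        model = model.to(dtype=target_precision)
    return model


# Usage: convert a tiny module to bfloat16 before export.
m = normalize_model_dtype(torch.nn.Linear(4, 4), torch.bfloat16)
assert next(m.parameters()).dtype == torch.bfloat16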