 
 
 """
-Export for Server   
+Export for Server
 """
 
 
@@ -78,7 +78,7 @@ def export_for_server(
7878""" 
7979Export for ExecuTorch 
8080
81- TODO (https://github.com/pytorch/torchchat/issues/1058): Replace   
81+ TODO (https://github.com/pytorch/torchchat/issues/1058): Replace 
8282replace_attention_with_custom_sdpa_attention with ET's implementation 
8383""" 
8484
@@ -94,6 +94,9 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -194,7 +197,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
             return self.wo(output)
 
     def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-        from executorch.examples.models.llama2.custom_ops import (  # noqa
+        from executorch.extension.llm.custom_ops import (  # noqa
             sdpa_with_kv_cache,
         )
 
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
             _skip_type_promotion=bool(target_precision == torch.float16),
         )
 
-        if target_precision == torch.float16 or target_precision == torch.bfloat16:
-            if state_dict_dtype != torch.float16:
-                print("model.to torch.float16")
-                model = model.to(dtype=torch.float16)
-                state_dict_dtype = torch.float16
-        elif target_precision == torch.float32:
-            if state_dict_dtype != torch.float32:
-                print("model.to torch.float32")
-                model = model.to(dtype=torch.float32)
-        elif target_precision == torch.bfloat16:
-            print("model.to torch.bfloat16")
-            model = model.to(dtype=torch.bfloat16)
-        else:
+        if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
             raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-        replace_attention_with_custom_sdpa_attention(model)
+        if state_dict_dtype != target_precision:
+            print(f"model.to {target_precision}")
+            model = model.to(dtype=target_precision)
+            state_dict_dtype = target_precision
+
+        # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+        # support anything but float32, and our attempt to use it anyway by converting
+        # to and from float causes other errors.)
+        if target_precision != torch.bfloat16:
+            replace_attention_with_custom_sdpa_attention(model)
+
         with torch.nn.attention.sdpa_kernel(
             [torch.nn.attention.SDPBackend.MATH]
         ), torch.no_grad():
@@ -304,9 +305,9 @@ def export_for_et(model, device, output_path) -> str:
         edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
         export_program = edge_manager.to_executorch(
             ExecutorchBackendConfig(
-                extract_constant_segment=True,
                 extract_delegate_segments=True,
                 passes=[
+                    ConvertToLinearPass(),
                     QuantFusionPass(),
                 ],
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
@@ -363,13 +364,17 @@ def main(args):
         except:
             tokenizer = None
 
-        if (
-            output_dso_path is not None
-            and builder_args.max_seq_length is None
-            and not builder_args.dynamic_shapes
-        ):
-            print("Setting max_seq_length to 300 for DSO export.")
-            builder_args.max_seq_length = 300
+        if builder_args.max_seq_length is None:
+            if (
+                output_dso_path is not None
+                and not builder_args.dynamic_shapes
+            ):
+                print("Setting max_seq_length to 300 for DSO export.")
+                builder_args.max_seq_length = 300
+            elif output_pte_path is not None:
+                # The value of 128 was chosen to match the ExecuTorch Llama example setup.
+                print("Setting max_seq_length to 128 for ExecuTorch export.")
+                builder_args.max_seq_length = 128
 
         model = _initialize_model(
             builder_args,