     XNNPACKQuantizer,
 )
 from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
-from executorch.exir import to_edge
+from executorch.exir import to_edge_transform_and_lower
+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes import MemoryPlanningPass
+from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from torch.export import export_for_training
+from torch.nn.attention import SDPBackend
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

 from transformers import Phi3ForCausalLM
+from transformers.cache_utils import StaticCacheConfig

-from .phi_3_mini import Phi3Mini
+from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+
+def _prepare_export_inputs(max_seq_len: int, sliding_window: int):
+    """
+    Prepare example inputs and configurations for export.
+
+    Returns:
+        example_input_ids (torch.Tensor): Example input IDs tensor.
+        example_cache_position (torch.Tensor): Example cache position tensor.
+        dynamic_shapes (dict or None): Dynamic shape specifications for export.
+    """
+    # Prepare inputs with dynamic shapes
+    seq_length = 3  # Sequence length > 1 to avoid specialization issues
+    example_input_ids = torch.zeros((1, seq_length), dtype=torch.long)
+    example_cache_position = torch.arange(seq_length, dtype=torch.long)
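+    # Cap the dynamic sequence dimension below both the max context length and the sliding-window size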
+    max_dim = min(max_seq_len, sliding_window) - 1
+    seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
+    dynamic_shapes = {
+        "input_ids": {1: seq_len_dim},
+        "cache_position": {0: seq_len_dim},
+    }
+
+    return example_input_ids, example_cache_position, dynamic_shapes
 
 
 def export(args) -> None:
@@ -40,51 +69,70 @@ def export(args) -> None:
4069 f"Invalid context length { args .context_length } . Should be either 4k or 128k"
4170 )
4271
43- with torch .no_grad ():
44- model = Phi3Mini (
45- # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
46- model = Phi3ForCausalLM .from_pretrained (model_name ),
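+    # Force the math SDPA backend so attention is traced with decomposed, export-friendly ops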
+    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+        model = Phi3ForCausalLM.from_pretrained(model_name)
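+        # Configure generation to use a static (fixed-size) KV cache sized to the model's full context window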
+        model.generation_config.cache_implementation = "static"
+        model.generation_config.cache_config = StaticCacheConfig(
+            batch_size=1, max_cache_len=model.config.max_position_embeddings
+        )
+
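+        # Wrap the model in Hugging Face's export-friendly decoder wrapper, which manages the static KV cache during export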
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model,
             max_batch_size=1,
-            max_seq_len=args.seq_len,
+            max_cache_len=model.config.max_position_embeddings,
         )
-        example_inputs = (
-            torch.tensor(
-                [[1048, 263, 931, 746]], dtype=torch.long, requires_grad=False
-            ),
+        input_ids, cache_position, dynamic_shapes = _prepare_export_inputs(
+            model.config.max_position_embeddings, model.config.sliding_window
+        )
+        example_inputs = (input_ids, cache_position)
+        exported_program = exportable_module.export(
+            input_ids, cache_position, dynamic_shapes, strict=False
+        )
+        # Apply the RemoveRedundantTransposes pass to remove back-to-back
+        # transpose ops that are not needed, e.g. the output of update_cache
+        # is transposed and the input to custom_sdpa is transposed.
+        from executorch.extension.llm.export.export_passes import (
+            RemoveRedundantTransposes,
         )
-        dynamic_shapes = {
-            "input_ids": {
-                1: torch.export.Dim("sequence_length", min=1, max=args.seq_len)
-            }
-        }
+
+        mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]
 
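+        # Set up dynamic, per-channel symmetric quantization for the XNNPACK backend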
         xnnpack_quant_config = get_symmetric_quantization_config(
             is_per_channel=True, is_dynamic=True
         )
         xnnpack_quantizer = XNNPACKQuantizer()
         xnnpack_quantizer.set_global(xnnpack_quant_config)

-        model = export_for_training(
-            model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
-        ).module()
-        model = prepare_pt2e(model, xnnpack_quantizer)  # pyre-fixme[6]
-        model(*example_inputs)
-        model = convert_pt2e(model)
-        DuplicateDynamicQuantChainPass()(model)
-        # TODO(lunwenh): update it to use export once
-        # https://github.com/pytorch/pytorch/issues/128394 is resolved.
-        model = torch.export._trace._export(
-            model,
-            example_inputs,
-            dynamic_shapes=dynamic_shapes,
-            strict=False,
-            pre_dispatch=False,
+        gm = prepare_pt2e(mutated_gm, xnnpack_quantizer)  # pyre-fixme[6]
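+        # Run an example forward pass to calibrate the prepared graph before conversion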
+        gm(*example_inputs)
+        gm = convert_pt2e(gm)
+        DuplicateDynamicQuantChainPass()(gm)
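+        # Re-export the quantized GraphModule so the edge lowering below receives an ExportedProgram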
+        exported_program = export_for_training(
+            gm, example_inputs, dynamic_shapes=dynamic_shapes, strict=False
         )

         edge_config = get_xnnpack_edge_compile_config()
-        edge_manager = to_edge(model, compile_config=edge_config)
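+        # Lower to the edge dialect and delegate supported subgraphs to XNNPACK;
+        # constant_methods expose runner metadata such as the EOS id and max sequence length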
+        edge_manager = to_edge_transform_and_lower(
+            exported_program,
+            partitioner=[XnnpackPartitioner()],
+            compile_config=edge_config,
+            constant_methods={
+                "get_eos_ids": [32000],
+                "use_kv_cache": True,
+                "enable_dynamic_shape": True,
+                "get_max_seq_len": model.config.max_position_embeddings - 1,
+            },
+        )
         edge_manager = edge_manager.to_backend(XnnpackPartitioner())
-        et_program = edge_manager.to_executorch()
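+        # extract_delegate_segments stores XNNPACK payloads as separate segments in the .pte,
+        # alloc_graph_input=False leaves input buffers to the caller, and the constraint-based
+        # shape pass sizes dynamic tensors from the export-time upper bounds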
+        et_program = edge_manager.to_executorch(
+            ExecutorchBackendConfig(
+                extract_delegate_segments=True,
+                do_quant_fusion_and_const_prop=True,
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
+            )
+        )

     with open(args.output_name, "wb") as file:
         file.write(et_program.buffer)