@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
 import torch._inductor
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,8 +68,10 @@ def export_for_server(
         dynamic_shapes = None
 
     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.metadata": metadata}
+        options = {
+            "aot_inductor.package": package,
+            "aot_inductor.metadata": metadata or {},
+        }
         if not package:
             options = {"aot_inductor.output_path": output_path}
 
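For context on how the new parameter travels: whatever dict the caller passes lands in the `"aot_inductor.metadata"` inductor config and is baked into the packaged artifact. A minimal caller-side sketch; the model, device, and metadata values here are illustrative placeholders, not part of this change:

```python
# Hypothetical call site; only the `metadata` keyword is new in this diff.
exported = export_for_server(
    model,                   # an nn.Module prepared for export
    "cuda",                  # device
    output_path="model.pt2",
    dynamic_shapes=False,
    package=True,            # produce a .pt2 package
    metadata={"tokenizer_type": "3"},  # embedded as a str -> str map
)
```

Note that when `package=False` the options dict is rebuilt with only `"aot_inductor.output_path"`, so any metadata is effectively dropped on the legacy .so path.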
@@ -106,13 +109,13 @@ def export_for_server(
     from typing import Any, Dict, Tuple, Union
 
     import executorch.exir as exir
+    from executorch.backends.xnnpack._passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
 
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
-    from executorch.backends.xnnpack._passes.convert_to_linear import (
-        ConvertToLinearPass,
-    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -170,18 +173,22 @@ def __init__(self, attention: Attention):
 
             self.wo = attention.wo
 
-            max_batch_size, n_heads, max_seq_length, head_dim = (
-                attention.kv_cache[0].k_cache.shape
-            )
+            max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+                0
+            ].k_cache.shape
             cache_dtype = attention.kv_cache[0].k_cache.dtype
             # The `Attention` module being replaced can have multiple KV caches
             # (denoted by `cache_lanes`). Thus we follow the same setup format
             # as in `Attention.setup_cache`.
             cache_lanes = len(attention.kv_cache)
-            self.kv_cache = nn.ModuleList([
-                CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-                for _ in range(cache_lanes)
-            ])
+            self.kv_cache = nn.ModuleList(
+                [
+                    CustomKVCache(
+                        max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                    )
+                    for _ in range(cache_lanes)
+                ]
+            )
 
             self.n_heads = attention.n_heads
             self.head_dim = attention.head_dim
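The comment retained in this hunk is the key detail: the wrapped `Attention` may own several KV caches ("lanes"), and the replacement module mirrors that one-to-one. Worth noticing, too, that the unpacked source shape is `[batch, heads, seq, head_dim]` while `CustomKVCache` is constructed as `(batch, seq, heads, head_dim)`; the middle dimensions are swapped, which matches the transposed cache layout the ExecuTorch custom SDPA op works on, as far as I can tell. A small sketch with illustrative sizes (the real values come from `attention.kv_cache[0].k_cache.shape`):

```python
import torch

# Illustrative sizes; the source cache is laid out [batch, heads, seq, dim].
max_batch_size, n_heads, max_seq_length, head_dim = 1, 32, 2048, 128
cache_lanes = 2

# One buffer per lane, in the transposed [batch, seq, heads, dim] layout;
# forward() later selects a lane via `self.kv_cache[cache_lane]`.
k_caches = [
    torch.zeros(max_batch_size, max_seq_length, n_heads, head_dim)
    for _ in range(cache_lanes)
]
```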
@@ -219,9 +226,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             return self.wo(output)
 
     def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-        from executorch.extension.llm.custom_ops import (  # noqa
-            sdpa_with_kv_cache,
-        )
+        from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa
 
         for name, child in module.named_children():
             if isinstance(child, Attention):
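The collapsed import is cosmetic, but the `# noqa` deserves a note: the import exists for its side effect. Loading the module registers ExecuTorch's custom SDPA-with-KV-cache operator with the PyTorch dispatcher, which is what the swapped-in attention module calls. A hedged sketch; the `torch.ops.llama` path reflects the usual ET custom-op namespace, not something this diff itself states:

```python
import torch

# Side-effect import: registers the custom op, hence the unused-import noqa.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

# Assumed registration path under the `llama` namespace:
op = torch.ops.llama.sdpa_with_kv_cache
```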
@@ -242,7 +247,9 @@ def _to_core_aten(
             raise ValueError(
                 f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
             )
-        core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
+        core_aten_ep = export_for_training(
+            model, example_inputs, dynamic_shapes=dynamic_shapes
+        )
         if verbose:
             logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
         return core_aten_ep
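Only the line wrapping changes here; `_to_core_aten` still lowers the model to a core ATen exported program before edge lowering. A self-contained sketch of the same call, assuming `export_for_training` comes from `torch.export` as in the rest of the file (the toy module and inputs are placeholders):

```python
import torch
from torch.export import export_for_training


class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


# fx.GraphModule input mirrors the isinstance check guarding this call.
gm = torch.fx.symbolic_trace(Tiny())
ep = export_for_training(gm, (torch.randn(2, 4),), dynamic_shapes=None)
print(ep.graph)  # the "Core ATen graph" logged when verbose=True
```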
@@ -354,7 +361,11 @@ def main(args):
 
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )
 
     builder_args.dso_path = None
     builder_args.pte_path = None
@@ -376,6 +387,7 @@ def main(args):
 
     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
@@ -386,9 +398,8 @@ def main(args):
 
     if builder_args.max_seq_length is None:
         if (
-            (output_dso_path is not None or output_aoti_package_path is not None)
-            and not builder_args.dynamic_shapes
-        ):
+            output_dso_path is not None or output_aoti_package_path is not None
+        ) and not builder_args.dynamic_shapes:
             print("Setting max_seq_length to 300 for DSO export.")
             builder_args.max_seq_length = 300
         elif output_pte_path is not None:
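The flattened condition keeps the original semantics: a server export (DSO or AOTI package) without dynamic shapes pins `max_seq_length` to a static default of 300. A tiny sketch with hypothetical values:

```python
# Placeholder values exercising the branch above:
output_dso_path, output_aoti_package_path = None, "model.pt2"
dynamic_shapes = False

if (
    output_dso_path is not None or output_aoti_package_path is not None
) and not dynamic_shapes:
    max_seq_length = 300  # static-shape default for server export
```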
@@ -401,7 +412,8 @@ def main(args):
         quantize,
         tokenizer,
         max_seq_length=builder_args.max_seq_length,
-        support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+        support_tensor_subclass=output_dso_path is None
+        and output_aoti_package_path is None,
     )
     model_to_pte = model
     model_to_dso = model
@@ -439,7 +451,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,
@@ -450,11 +464,23 @@ def main(args):
 
     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
-        print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
+        print(
+            "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
+        )
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
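Net effect of this final hunk, together with the `tokenizer_args = None` initialization earlier in `main`: the .pt2 package now records which tokenizer family the model was built with, so a server-side runner can choose the right tokenizer without guessing. The values are strings because the AOTI metadata map is `str -> str`. My reading of the three codes, hedged since only the llama2/llama3 comments are stated in the diff:

```python
# Assumed meaning of the tokenizer_type codes emitted above:
TOKENIZER_TYPE_CODES = {
    "0": "unknown",        # tokenizer_args never populated (e.g. gguf path)
    "2": "sentencepiece",  # llama2-style models
    "3": "tiktoken",       # llama3-style models (the non-SentencePiece branch)
}
```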