55# LICENSE file in the root directory of this source tree.
66
77import os
8- from typing import Optional
8+ from typing import Dict , Optional
99
1010import torch
11+ import torch ._inductor
1112import torch .nn as nn
1213
1314from torch .export import Dim
14- import torch ._inductor
1515
1616from torchchat .cli .builder import (
1717 _initialize_model ,
@@ -39,6 +39,7 @@ def export_for_server(
3939 output_path : str = "model.pt2" ,
4040 dynamic_shapes : bool = False ,
4141 package : bool = True ,
42+ metadata : Dict [str , str ] = {},
4243) -> str :
4344 """
4445 Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,7 +68,6 @@ def export_for_server(
6768 dynamic_shapes = None
6869
6970 with torch .nn .attention .sdpa_kernel ([torch .nn .attention .SDPBackend .MATH ]):
70- metadata = {} # TODO: put more metadata here
7171 options = {"aot_inductor.package" : package , "aot_inductor.metadata" : metadata }
7272 if not package :
7373 options = {"aot_inductor.output_path" : output_path }
@@ -81,6 +81,7 @@ def export_for_server(
8181
8282 if package :
8383 from torch ._inductor .package import package_aoti
84+
8485 path = package_aoti (output_path , path )
8586
8687 print (f"The generated packaged model can be found at: { path } " )
@@ -102,13 +103,13 @@ def export_for_server(
102103 from typing import Any , Dict , Tuple , Union
103104
104105 import executorch .exir as exir
106+ from executorch .backends .xnnpack ._passes .convert_to_linear import (
107+ ConvertToLinearPass ,
108+ )
105109
106110 from executorch .backends .xnnpack .partition .xnnpack_partitioner import (
107111 XnnpackDynamicallyQuantizedPartitioner ,
108112 )
109- from executorch .backends .xnnpack ._passes .convert_to_linear import (
110- ConvertToLinearPass ,
111- )
112113 from executorch .exir import EdgeProgramManager , to_edge
113114
114115 from executorch .exir .capture ._config import (
@@ -166,18 +167,22 @@ def __init__(self, attention: Attention):
166167
167168 self .wo = attention .wo
168169
169- max_batch_size , n_heads , max_seq_length , head_dim = (
170- attention . kv_cache [ 0 ]. k_cache . shape
171- )
170+ max_batch_size , n_heads , max_seq_length , head_dim = attention . kv_cache [
171+ 0
172+ ]. k_cache . shape
172173 cache_dtype = attention .kv_cache [0 ].k_cache .dtype
173174 # The `Attention` module being replaced can have multiple KV caches
174175 # (denoted by `cache_lanes`). Thus we follow the same setup format
175176 # as in `Attention.setup_cache`.
176177 cache_lanes = len (attention .kv_cache )
177- self .kv_cache = nn .ModuleList ([
178- CustomKVCache (max_batch_size , max_seq_length , n_heads , head_dim , cache_dtype )
179- for _ in range (cache_lanes )
180- ])
178+ self .kv_cache = nn .ModuleList (
179+ [
180+ CustomKVCache (
181+ max_batch_size , max_seq_length , n_heads , head_dim , cache_dtype
182+ )
183+ for _ in range (cache_lanes )
184+ ]
185+ )
181186
182187 self .n_heads = attention .n_heads
183188 self .head_dim = attention .head_dim
@@ -215,9 +220,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
215220 return self .wo (output )
216221
217222 def replace_attention_with_custom_sdpa_attention (module : nn .Module ):
218- from executorch .extension .llm .custom_ops import ( # noqa
219- sdpa_with_kv_cache ,
220- )
223+ from executorch .extension .llm .custom_ops import sdpa_with_kv_cache # noqa
221224
222225 for name , child in module .named_children ():
223226 if isinstance (child , Attention ):
@@ -350,7 +353,11 @@ def main(args):
350353
351354 print (f"Using device={ builder_args .device } " )
352355 set_precision (builder_args .precision )
353- set_backend (dso = args .output_dso_path , pte = args .output_pte_path , aoti_package = args .output_aoti_package_path )
356+ set_backend (
357+ dso = args .output_dso_path ,
358+ pte = args .output_pte_path ,
359+ aoti_package = args .output_aoti_package_path ,
360+ )
354361
355362 builder_args .dso_path = None
356363 builder_args .pte_path = None
@@ -372,6 +379,7 @@ def main(args):
372379
373380 # TODO: clean this up
374381 # This workaround exists because ExecuTorch (ET) does not support _weight_int4pack_mm right now
382+ tokenizer_args = None
375383 if not builder_args .gguf_path :
376384 # The tokenizer is needed for quantization, so get it here.
377385 try :
@@ -382,9 +390,8 @@ def main(args):
382390
383391 if builder_args .max_seq_length is None :
384392 if (
385- (output_dso_path is not None or output_aoti_package_path is not None )
386- and not builder_args .dynamic_shapes
387- ):
393+ output_dso_path is not None or output_aoti_package_path is not None
394+ ) and not builder_args .dynamic_shapes :
388395 print ("Setting max_seq_length to 300 for DSO export." )
389396 builder_args .max_seq_length = 300
390397 elif output_pte_path is not None :
@@ -397,7 +404,8 @@ def main(args):
397404 quantize ,
398405 tokenizer ,
399406 max_seq_length = builder_args .max_seq_length ,
400- support_tensor_subclass = output_dso_path is None and output_aoti_package_path is None ,
407+ support_tensor_subclass = output_dso_path is None
408+ and output_aoti_package_path is None ,
401409 )
402410 model_to_pte = model
403411 model_to_dso = model
@@ -435,7 +443,9 @@ def main(args):
435443 if output_dso_path :
436444 output_dso_path = str (os .path .abspath (output_dso_path ))
437445 print (f"Exporting model using AOT Inductor to { output_dso_path } " )
438- print ("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead." )
446+ print (
447+ "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
448+ )
439449 export_for_server (
440450 model_to_dso ,
441451 builder_args .device ,
@@ -446,11 +456,21 @@ def main(args):
446456
447457 if output_aoti_package_path :
448458 output_aoti_package_path = str (os .path .abspath (output_aoti_package_path ))
459+
460+ if tokenizer_args is None :
461+ tokenizer_type = "0"
462+ elif tokenizer_args .is_sentencepiece :
463+ tokenizer_type = "2" # Corresponding to llama2
464+ else :
465+ tokenizer_type = "3" # Corresponding to llama3
466+
467+ metadata = {"tokenizer_type" : tokenizer_type }
449468 print (f"Exporting model using AOT Inductor to { output_aoti_package_path } " )
450469 export_for_server (
451470 model_to_aoti_package ,
452471 builder_args .device ,
453472 output_aoti_package_path ,
454473 builder_args .dynamic_shapes ,
455474 package = True ,
475+ metadata = metadata ,
456476 )
0 commit comments