88from typing import Optional
99
1010import torch
11+ import torch._inductor
1112import torch.nn as nn
1213
1314from torch.export import Dim
14- import torch._inductor
1515
1616from torchchat.cli.builder import (
1717 _initialize_model,
@@ -68,20 +68,24 @@ def export_for_server(
6868
6969 with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
7070 metadata = {}  # TODO: put more metadata here
71- options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
71+ options = {"aot_inductor.metadata": metadata}
7272 if not package:
7373 options = {"aot_inductor.output_path": output_path}
7474
75- path = torch._export.aot_compile(
75+ ep = torch.export.export(
7676 model,
7777 example_inputs,
7878 dynamic_shapes=dynamic_shapes,
79- options=options,
8079 )
8180
8281 if package:
83- from torch._inductor.package import package_aoti
84- path = package_aoti(output_path, path)
82+ path = torch._inductor.aoti_compile_and_package(
83+ ep, package_path=output_path, inductor_configs=options
84+ )
85+ else:
86+ path = torch._inductor.aot_compile(
87+ ep.module(), example_inputs, options=options
88+ )
8589
8690 print(f"The generated packaged model can be found at: {path}")
8791 return path
@@ -106,9 +110,6 @@ def export_for_server(
106110 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
107111 XnnpackDynamicallyQuantizedPartitioner,
108112 )
109- from executorch.backends.xnnpack._passes.convert_to_linear import (
110- ConvertToLinearPass,
111- )
112113 from executorch.exir import EdgeProgramManager, to_edge
113114
114115 from executorch.exir.capture._config import (
0 commit comments