This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7191b57

[Hackability Refactor] Collapse export_util into export.py (#1057)
* executorch_portable_utils: Delete unused functions
* Collapse util files
* Remove unused materialize_broadcast function
* Collapse all export-related utils into export.py
* Remove redundant exports
* Move export.py under torchchat/torchchat
* Undo the commit moving the files; will do that in a separate PR
1 parent 19a47e7 commit 7191b57

File tree

5 files changed: +261, -344 lines


export.py

Lines changed: 260 additions & 8 deletions
@@ -25,17 +25,15 @@
 
 from torch.export import Dim
 
-try:
-    executorch_export_available = True
-    from export_util.export_et import export_model as export_model_et
-except Exception as e:
-    executorch_exception = f"ET EXPORT EXCEPTION: {e}"
-    executorch_export_available = False
-
 
 default_device = "cpu"
 
 
+"""
+Export for Server
+"""
+
+
 def export_for_server(
     model: nn.Module,
     device: Optional[str] = "cpu",
@@ -79,6 +77,260 @@ def export_for_server(
     return so
 
 
+"""
+Export for ExecuTorch
+
+TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
+replace_attention_with_custom_sdpa_attention with ET's implementation
+"""
+
+try:
+    executorch_export_available = True
+
+    import logging
+
+    from typing import Any, Dict, Tuple, Union
+
+    import executorch.exir as exir
+
+    from build.model import apply_rotary_emb, Attention
+    from build.utils import get_precision
+
+    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+        XnnpackDynamicallyQuantizedPartitioner,
+    )
+    from executorch.exir import EdgeProgramManager, to_edge
+
+    from executorch.exir.capture._config import (
+        EdgeCompileConfig,
+        ExecutorchBackendConfig,
+    )
+    from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
+    from executorch.exir.passes.sym_shape_eval_pass import (
+        ConstraintBasedSymShapeEvalPass,
+    )
+    from executorch.exir.tracer import Value
+
+    from torch._export import capture_pre_autograd_graph
+    from torch.export import export, ExportedProgram
+
+    default_device = "cpu"
+
+    _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
+        _check_ir_validity=True,
+    )
+
+    class CustomKVCache(nn.Module):
+        def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
+            super().__init__()
+
+            dtype = torch.float
+
+            # This is flipped around from what is in build.model's KVCache
+            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+            self.register_buffer(
+                "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+            )
+            self.register_buffer(
+                "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+            )
+
+        def update(self, input_pos, k_val, v_val):
+            k_out = self.k_cache
+            v_out = self.v_cache
+            k_out[:, :, input_pos] = k_val.float()
+            v_out[:, :, input_pos] = v_val.float()
+
+            return k_out, v_out
+
+    class CustomSDPAAttention(nn.Module):
+        def __init__(self, attention: Attention):
+            super().__init__()
+
+            self.wq = attention.wq
+            self.wk = attention.wk
+            self.wv = attention.wv
+
+            self.wo = attention.wo
+
+            max_batch_size, n_heads, max_seq_length, head_dim = (
+                attention.kv_cache.k_cache.shape
+            )
+            cache_dtype = attention.kv_cache.k_cache.dtype
+            self.kv_cache = CustomKVCache(
+                max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+            )
+
+            self.n_heads = attention.n_heads
+            self.head_dim = attention.head_dim
+            self.n_local_heads = attention.n_local_heads
+            self.dim = attention.dim
+
+        def forward(self, x, freqs_cis, mask, input_pos=None):
+            bsz, seqlen, _ = x.shape
+
+            q = self.wq(x)
+            k = self.wk(x)
+            v = self.wv(x)
+
+            q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
+            k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+            q = apply_rotary_emb(q, freqs_cis).to(dtype=torch.float)
+            k = apply_rotary_emb(k, freqs_cis).to(dtype=torch.float)
+            v = v.to(dtype=torch.float)
+
+            # KV cache should always be enabled
+            assert self.kv_cache is not None
+            output = torch.ops.llama.sdpa_with_kv_cache(
+                q,
+                k,
+                v,
+                self.kv_cache.k_cache,
+                self.kv_cache.v_cache,
+                input_pos[-1].item(),
+                seqlen,
+            )
+            output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
+            return self.wo(output)
+
+    def replace_attention_with_custom_sdpa_attention(module: nn.Module):
+        from executorch.examples.models.llama2.custom_ops import (  # noqa
+            sdpa_with_kv_cache,
+        )
+
+        for name, child in module.named_children():
+            if isinstance(child, Attention):
+                setattr(module, name, CustomSDPAAttention(child))
+            else:
+                replace_attention_with_custom_sdpa_attention(child)
+
+    def _to_core_aten(
+        model: Union[torch.fx.GraphModule, torch.nn.Module],
+        example_inputs: Tuple[Value, ...],
+        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+        verbose=True,
+    ) -> ExportedProgram:
+        # post autograd export. eventually this will become .to_core_aten
+        if not isinstance(model, torch.fx.GraphModule) and not isinstance(
+            model, torch.nn.Module
+        ):
+            raise ValueError(
+                f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
+            )
+        core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes)
+        if verbose:
+            logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
+        return core_aten_ep
+
+    def _core_aten_to_edge(
+        core_aten_exir_ep: ExportedProgram,
+        edge_constant_methods: Optional[Dict[str, Any]] = None,
+        edge_compile_config=None,
+        verbose=True,
+    ) -> EdgeProgramManager:
+        if not edge_compile_config:
+            edge_compile_config = exir.EdgeCompileConfig(
+                _check_ir_validity=False,  # quant ops currently break ir verification
+            )
+        edge_manager: EdgeProgramManager = to_edge(
+            core_aten_exir_ep,
+            constant_methods=edge_constant_methods,
+            compile_config=edge_compile_config,
+        )
+        if verbose:
+            logging.info(f"Exported graph:\n{edge_manager.exported_program().graph}")
+        return edge_manager
+
+    def export_to_edge(
+        model: Union[torch.fx.GraphModule, torch.nn.Module],
+        example_inputs: Tuple[Value, ...],
+        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+        edge_constant_methods: Optional[Dict[str, Any]] = None,
+        edge_compile_config=_EDGE_COMPILE_CONFIG,
+        verbose=True,
+    ) -> EdgeProgramManager:
+        core_aten_ep = _to_core_aten(
+            model, example_inputs, dynamic_shapes, verbose=verbose
+        )
+        return _core_aten_to_edge(
+            core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose
+        )
+
+    def export_for_et(model, device, output_path, args=None) -> str:  # noqa: C901
+
+        input = (
+            torch.tensor([[1]], dtype=torch.long, device=device),
+            torch.tensor([0], dtype=torch.long, device=device),
+        )
+
+        state_dict = model.state_dict()
+        state_dict_dtype = state_dict[next(iter(state_dict))].dtype
+        target_precision = get_precision()
+        dynamic_shapes = None
+
+        # TODO: need to use kv sdpa?
+        edge_config = EdgeCompileConfig(
+            _check_ir_validity=False,
+            _skip_type_promotion=bool(target_precision == torch.float16),
+        )
+
+        if target_precision == torch.float16 or target_precision == torch.bfloat16:
+            if state_dict_dtype != torch.float16:
+                print("model.to torch.float16")
+                model = model.to(dtype=torch.float16)
+                state_dict_dtype = torch.float16
+        elif target_precision == torch.float32:
+            if state_dict_dtype != torch.float32:
+                print("model.to torch.float32")
+                model = model.to(dtype=torch.float32)
+        elif target_precision == torch.bfloat16:
+            print("model.to torch.bfloat16")
+            model = model.to(dtype=torch.bfloat16)
+        else:
+            raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
+
+        replace_attention_with_custom_sdpa_attention(model)
+        with torch.nn.attention.sdpa_kernel(
+            [torch.nn.attention.SDPBackend.MATH]
+        ), torch.no_grad():
+            m = capture_pre_autograd_graph(model, input, dynamic_shapes=dynamic_shapes)
+
+            edge_manager = export_to_edge(
+                m,
+                input,
+                dynamic_shapes=dynamic_shapes,
+                edge_compile_config=edge_config,
+            )
+        edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
+        export_program = edge_manager.to_executorch(
+            ExecutorchBackendConfig(
+                extract_constant_segment=True,
+                extract_delegate_segments=True,
+                passes=[
+                    QuantFusionPass(),
+                ],
+                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
+            )
+        )
+
+        print("The methods are: ", export_program.methods)
+        with open(output_path, "wb") as f:
+            export_program.write_to_file(f)
+
+        return output_path
+
+except Exception as e:
+    executorch_exception = f"ET EXPORT EXCEPTION: {e}"
+    executorch_export_available = False
+
+
+"""
+Exporting Flow
+"""
+
+
 def main(args):
     builder_args = BuilderArgs.from_args(args)
     quantize = args.quantize
@@ -153,7 +405,7 @@ def main(args):
         output_pte_path = str(os.path.abspath(output_pte_path))
         if executorch_export_available:
             print(f"Exporting model using ExecuTorch to {output_pte_path}")
-            export_model_et(
+            export_for_et(
                 model_to_pte, builder_args.device, args.output_pte_path, args
             )
         else:
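As a quick orientation to the collapsed layout, here is a minimal driver sketch mirroring the gated call in main() above. It is an illustration, not code from the commit: export_to_pte is a hypothetical helper, while export_for_et, executorch_export_available, and executorch_exception are the names this diff defines in export.py (assumed importable here as the module export).

import export  # export.py, which this commit collapses the ET utils into

def export_to_pte(model, output_pte_path: str, device: str = "cpu") -> None:
    # Hypothetical helper mirroring the branch in main() above.
    if export.executorch_export_available:
        print(f"Exporting model using ExecuTorch to {output_pte_path}")
        export.export_for_et(model, device, output_pte_path)
    else:
        # executorch_exception only exists when the import-time try/except failed.
        print(getattr(export, "executorch_exception", "ExecuTorch unavailable"))

The module-level try/except keeps export.py importable on machines without ExecuTorch installed; callers branch on the availability flag instead of handling ImportError themselves.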

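The "flipped around" comment inside CustomKVCache is easy to miss, so the sketch below makes the two cache layouts concrete. Both layouts are read off the diff itself (the shape unpacking in CustomSDPAAttention.__init__ versus cache_shape in CustomKVCache); the sizes are arbitrary example values.

import torch

bsz, n_heads, max_seq, head_dim = 1, 8, 128, 64

# build.model's KVCache layout, per the unpacking in CustomSDPAAttention.__init__:
# (max_batch_size, n_heads, max_seq_length, head_dim)
eager_cache = torch.zeros(bsz, n_heads, max_seq, head_dim)

# CustomKVCache layout: (max_batch_size, max_seq_length, n_heads, head_dim)
et_cache = eager_cache.transpose(1, 2).contiguous()
assert et_cache.shape == (bsz, max_seq, n_heads, head_dim)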