This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 5a7b137 (1 parent: 8bb9743)

Commit message:
Update
[ghstack-poisoned]

File tree: 1 file changed (+19 / -18 lines)

torchchat/export.py (19 additions, 18 deletions)
@@ -28,7 +28,7 @@
 
 
 """
-Export for Server
+Export for Server
 """
 
 
@@ -78,7 +78,7 @@ def export_for_server(
 """
 Export for ExecuTorch
 
-TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
+TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
 replace_attention_with_custom_sdpa_attention with ET's implementation
 """
 
@@ -94,6 +94,9 @@ def export_for_server(
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
 )
+from executorch.backends.xnnpack.passes.convert_to_linear import (
+    ConvertToLinearPass,
+)
 from executorch.exir import EdgeProgramManager, to_edge
 
 from executorch.exir.capture._config import (
@@ -194,7 +197,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.examples.models.llama2.custom_ops import (  # noqa
+    from executorch.extension.llm.custom_ops import (  # noqa
         sdpa_with_kv_cache,
     )
 
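The net effect of this hunk is that the custom SDPA op is now registered from ExecuTorch's extension tree rather than the llama2 example directory. A minimal sketch of the updated import as it is typically written; the explicit F401 code is an assumption here, the commit itself uses a bare # noqa:

# Importing the module registers the sdpa_with_kv_cache custom op with
# PyTorch; the imported name is otherwise unused, hence the noqa.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401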
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
         _skip_type_promotion=bool(target_precision == torch.float16),
     )
 
-    if target_precision == torch.float16 or target_precision == torch.bfloat16:
-        if state_dict_dtype != torch.float16:
-            print("model.to torch.float16")
-            model = model.to(dtype=torch.float16)
-            state_dict_dtype = torch.float16
-    elif target_precision == torch.float32:
-        if state_dict_dtype != torch.float32:
-            print("model.to torch.float32")
-            model = model.to(dtype=torch.float32)
-    elif target_precision == torch.bfloat16:
-        print("model.to torch.bfloat16")
-        model = model.to(dtype=torch.bfloat16)
-    else:
+    if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
         raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-    replace_attention_with_custom_sdpa_attention(model)
+    if state_dict_dtype != target_precision:
+        print(f"model.to {target_precision}")
+        model = model.to(dtype=target_precision)
+        state_dict_dtype = target_precision
+
+    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+    # support anything but float32, and our attempt to use it anyway by converting
+    # to and from float causes other errors.)
+    if target_precision != torch.bfloat16:
+        replace_attention_with_custom_sdpa_attention(model)
+
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
     ), torch.no_grad():
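This hunk collapses the per-dtype if/elif ladder into one validate-then-convert step and gates the custom SDPA swap on the target dtype. A minimal standalone sketch of the new control flow; the helper name prepare_for_et is hypothetical, and replace_attention_with_custom_sdpa_attention is assumed to be the function defined earlier in this same file:

import torch
import torch.nn as nn


def prepare_for_et(model: nn.Module, state_dict_dtype: torch.dtype,
                   target_precision: torch.dtype):
    # Reject unsupported dtypes once, up front.
    if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")

    # Convert only when the checkpoint dtype differs from the target.
    if state_dict_dtype != target_precision:
        model = model.to(dtype=target_precision)
        state_dict_dtype = target_precision

    # bfloat16 skips the custom SDPA op, which only supports float32 on CPU.
    if target_precision != torch.bfloat16:
        replace_attention_with_custom_sdpa_attention(model)  # from export.py

    return model, state_dict_dtype

Beyond being shorter, this also fixes a latent bug in the removed ladder: its first branch caught bfloat16 and coerced the model to float16, leaving the elif torch.bfloat16 branch unreachable, whereas the unified version converts to the requested dtype directly.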
@@ -304,9 +305,9 @@ def export_for_et(model, device, output_path) -> str:
     edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
     export_program = edge_manager.to_executorch(
         ExecutorchBackendConfig(
-            extract_constant_segment=True,
             extract_delegate_segments=True,
             passes=[
+                ConvertToLinearPass(),
                 QuantFusionPass(),
             ],
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
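Together with the import added in the earlier hunk, lowering now runs ConvertToLinearPass ahead of QuantFusionPass (judging by its name, the XNNPACK pass that rebuilds aten linear ops from decomposed mm/addmm patterns), and the extract_constant_segment flag is dropped from the config. A sketch of the resulting call, assuming edge_manager came from to_edge(...) and the imports shown in the hunks above:

export_program = edge_manager.to_executorch(
    ExecutorchBackendConfig(
        extract_delegate_segments=True,
        passes=[
            ConvertToLinearPass(),  # added by this commit, runs before quant fusion
            QuantFusionPass(),
        ],
        sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
    )
)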
