This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 2d77fa5

Update (base update)
[ghstack-poisoned]
1 parent 0f2849b commit 2d77fa5

File tree

torchchat/export.py
torchchat/utils/scripts/install_utils.sh

2 files changed: +20 -19 lines changed

torchchat/export.py

Lines changed: 19 additions & 18 deletions
@@ -28,7 +28,7 @@
 
 
 """
-Export for Server
+Export for Server
 """
 
 
@@ -78,7 +78,7 @@ def export_for_server(
     """
     Export for ExecuTorch
 
-    TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
+    TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
     replace_attention_with_custom_sdpa_attention with ET's implementation
     """
 
@@ -94,6 +94,9 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -194,7 +197,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.examples.models.llama2.custom_ops import (  # noqa
+    from executorch.extension.llm.custom_ops import (  # noqa
         sdpa_with_kv_cache,
     )
 
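Downstream code that pulls in the custom op must follow the moved module path. A minimal usage sketch of the new location (the import registers sdpa_with_kv_cache with PyTorch as a side effect, which is why the diff marks it noqa):

# New module path after this commit; the import's side effect registers the
# custom sdpa_with_kv_cache op, so the imported name itself may go unused.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401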
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
         _skip_type_promotion=bool(target_precision == torch.float16),
     )
 
-    if target_precision == torch.float16 or target_precision == torch.bfloat16:
-        if state_dict_dtype != torch.float16:
-            print("model.to torch.float16")
-            model = model.to(dtype=torch.float16)
-            state_dict_dtype = torch.float16
-    elif target_precision == torch.float32:
-        if state_dict_dtype != torch.float32:
-            print("model.to torch.float32")
-            model = model.to(dtype=torch.float32)
-    elif target_precision == torch.bfloat16:
-        print("model.to torch.bfloat16")
-        model = model.to(dtype=torch.bfloat16)
-    else:
+    if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
         raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-    replace_attention_with_custom_sdpa_attention(model)
+    if state_dict_dtype != target_precision:
+        print(f"model.to {target_precision}")
+        model = model.to(dtype=target_precision)
+        state_dict_dtype = target_precision
+
+    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+    # support anything but float32, and our attempt to use it anyway by converting
+    # to and from float causes other errors.)
+    if target_precision != torch.bfloat16:
+        replace_attention_with_custom_sdpa_attention(model)
+
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
     ), torch.no_grad():
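The hunk above collapses three near-identical dtype branches into one generic conversion behind an up-front validity check, and gates the custom SDPA swap on the dtype. A standalone sketch of the resulting control flow (names simplified from the diff, not the module's actual API):

import torch

# Dtypes the ET export path accepts; anything else fails fast.
SUPPORTED_DTYPES = (torch.float16, torch.float32, torch.bfloat16)

def convert_for_et(model: torch.nn.Module,
                   target_precision: torch.dtype,
                   state_dict_dtype: torch.dtype) -> torch.nn.Module:
    if target_precision not in SUPPORTED_DTYPES:
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
    # A single model.to(...) replaces the old per-dtype if/elif chain.
    if state_dict_dtype != target_precision:
        model = model.to(dtype=target_precision)
    return model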
@@ -304,9 +305,9 @@ def export_for_et(model, device, output_path) -> str:
     edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
     export_program = edge_manager.to_executorch(
         ExecutorchBackendConfig(
-            extract_constant_segment=True,
             extract_delegate_segments=True,
             passes=[
+                ConvertToLinearPass(),
                 QuantFusionPass(),
             ],
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),

torchchat/utils/scripts/install_utils.sh

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ install_executorch() {
     -DEXECUTORCH_BUILD_XNNPACK=ON \
     ${CROSS_COMPILE_ARGS} \
     -S . -B ${CMAKE_OUT_DIR} -G Ninja
-  cmake --build ${CMAKE_OUT_DIR}
+  cmake --build ${CMAKE_OUT_DIR} -j16
   cmake --install ${CMAKE_OUT_DIR} --prefix ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install
   popd
 }
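One caveat: -j16 hard-codes the parallel job count. A host-derived alternative (a sketch, not part of this commit; nproc is GNU coreutils, with a macOS fallback):

# Derive the job count from the host instead of hard-coding 16.
JOBS="$(nproc 2>/dev/null || sysctl -n hw.ncpu)"
cmake --build "${CMAKE_OUT_DIR}" -j"${JOBS}"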
