Merged
2 changes: 1 addition & 1 deletion install/.pins/et-pin.txt
@@ -1 +1 @@
-af098c31b6f8d5f38e40a5cf35784b0969d97df8
+286799c9c844ce6427b8eca260f9b2f28be03291
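
The only change in this file is the pinned ExecuTorch commit, moving the vendored checkout forward. As a rough sketch of how a pin file like this is typically consumed during install (the ET_SRC_DIR variable and checkout flow below are illustrative assumptions, not torchchat's actual script):

    # Hypothetical sketch: sync an ExecuTorch checkout to the pinned commit.
    # ET_SRC_DIR is an assumed variable name, not one defined by torchchat.
    ET_PIN="$(cat install/.pins/et-pin.txt)"
    git -C "${ET_SRC_DIR}" fetch origin
    git -C "${ET_SRC_DIR}" checkout "${ET_PIN}"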
34 changes: 18 additions & 16 deletions torchchat/export.py
@@ -28,7 +28,7 @@


"""
Export for Server
"""


@@ -78,7 +78,7 @@ def export_for_server(
"""
Export for ExecuTorch

TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
replace_attention_with_custom_sdpa_attention with ET's implementation
"""

@@ -94,6 +94,9 @@ def export_for_server(
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackDynamicallyQuantizedPartitioner,
    )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
    from executorch.exir import EdgeProgramManager, to_edge

    from executorch.exir.capture._config import (
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
        _skip_type_promotion=bool(target_precision == torch.float16),
    )

-    if target_precision == torch.float16 or target_precision == torch.bfloat16:
-        if state_dict_dtype != torch.float16:
-            print("model.to torch.float16")
-            model = model.to(dtype=torch.float16)
-            state_dict_dtype = torch.float16
-    elif target_precision == torch.float32:
-        if state_dict_dtype != torch.float32:
-            print("model.to torch.float32")
-            model = model.to(dtype=torch.float32)
-    elif target_precision == torch.bfloat16:
-        print("model.to torch.bfloat16")
-        model = model.to(dtype=torch.bfloat16)
-    else:
+    if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")

-    replace_attention_with_custom_sdpa_attention(model)
+    if state_dict_dtype != target_precision:
+        print(f"model.to {target_precision}")
+        model = model.to(dtype=target_precision)
+        state_dict_dtype = target_precision
+
+    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+    # support anything but float32, and our attempt to use it anyway by converting
+    # to and from float causes other errors.)
+    if target_precision != torch.bfloat16:
+        replace_attention_with_custom_sdpa_attention(model)

    with torch.nn.attention.sdpa_kernel(
        [torch.nn.attention.SDPBackend.MATH]
    ), torch.no_grad():
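
Read as a whole, this hunk replaces the per-dtype branches with a single validate-then-convert step and gates the custom SDPA swap on the target precision. A minimal standalone sketch of that pattern, with a toy module standing in for the real model and a stub standing in for torchchat's replace_attention_with_custom_sdpa_attention:

    import torch
    import torch.nn as nn

    SUPPORTED_DTYPES = (torch.float16, torch.float32, torch.bfloat16)

    def replace_attention_with_custom_sdpa_attention(model: nn.Module) -> None:
        # Stub standing in for torchchat's helper of the same name.
        pass

    def prepare_for_et_export(model: nn.Module, target_precision: torch.dtype) -> nn.Module:
        # Validate once up front instead of branching per dtype.
        if target_precision not in SUPPORTED_DTYPES:
            raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
        # Convert only when the weights are not already in the target dtype.
        state_dict_dtype = next(model.parameters()).dtype
        if state_dict_dtype != target_precision:
            model = model.to(dtype=target_precision)
        # Skip the custom SDPA swap for bfloat16, which it does not support on CPU.
        if target_precision != torch.bfloat16:
            replace_attention_with_custom_sdpa_attention(model)
        return model

    model = prepare_for_et_export(nn.Linear(4, 4), torch.float16)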
@@ -306,6 +307,7 @@ def export_for_et(model, device, output_path) -> str:
        ExecutorchBackendConfig(
            extract_delegate_segments=True,
            passes=[
+                ConvertToLinearPass(),
                QuantFusionPass(),
            ],
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
2 changes: 1 addition & 1 deletion torchchat/utils/scripts/install_utils.sh
@@ -150,7 +150,7 @@ install_executorch_cpp_libs() {
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=${EXECUTORCH_BUILD_KERNELS_CUSTOM_VAR} \
${CROSS_COMPILE_ARGS} \
-S . -B ${CMAKE_OUT_DIR} -G Ninja
-cmake --build ${CMAKE_OUT_DIR}
+cmake --build ${CMAKE_OUT_DIR} -j16
cmake --install ${CMAKE_OUT_DIR} --prefix ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install
popd
}
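
The new -j16 pins the ExecuTorch C++ build to 16 parallel compile jobs; with the Ninja generator the build is already parallel by default, so this makes the level explicit. A hedged variant, if scaling to the host rather than hard-coding 16 is preferred (nproc is Linux-specific; the sysctl fallback below is an assumption for macOS, not the script's actual behavior):

    # Scale build parallelism to the machine instead of a fixed 16 jobs.
    JOBS="$(nproc 2>/dev/null || sysctl -n hw.ncpu)"
    cmake --build ${CMAKE_OUT_DIR} -j${JOBS}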