diff --git a/torchchat/export.py b/torchchat/export.py
index 21e7fcaa8..6b06f1df1 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -28,7 +28,7 @@
 
 
 """
-Export for Server 
+Export for Server
 """
 
 
@@ -78,7 +78,7 @@ def export_for_server(
 """
 Export for ExecuTorch
 
-TODO (https://github.com/pytorch/torchchat/issues/1058): Replace 
+TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
 replace_attention_with_custom_sdpa_attention with ET's implementation
 """
 
@@ -94,6 +94,9 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
 
     from executorch.exir import EdgeProgramManager, to_edge
     from executorch.exir.capture._config import (
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
         _skip_type_promotion=bool(target_precision == torch.float16),
     )
 
-    if target_precision == torch.float16 or target_precision == torch.bfloat16:
-        if state_dict_dtype != torch.float16:
-            print("model.to torch.float16")
-            model = model.to(dtype=torch.float16)
-            state_dict_dtype = torch.float16
-    elif target_precision == torch.float32:
-        if state_dict_dtype != torch.float32:
-            print("model.to torch.float32")
-            model = model.to(dtype=torch.float32)
-    elif target_precision == torch.bfloat16:
-        print("model.to torch.bfloat16")
-        model = model.to(dtype=torch.bfloat16)
-    else:
+    if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
         raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-    replace_attention_with_custom_sdpa_attention(model)
+    if state_dict_dtype != target_precision:
+        print(f"model.to {target_precision}")
+        model = model.to(dtype=target_precision)
+        state_dict_dtype = target_precision
+
+    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+    # support anything but float32, and our attempt to use it anyway by converting
+    # to and from float causes other errors.)
+    if target_precision != torch.bfloat16:
+        replace_attention_with_custom_sdpa_attention(model)
+
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
     ), torch.no_grad():
@@ -306,6 +307,7 @@ def export_for_et(model, device, output_path) -> str:
         ExecutorchBackendConfig(
             extract_delegate_segments=True,
             passes=[
+                ConvertToLinearPass(),
                 QuantFusionPass(),
             ],
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
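
For reference, a minimal standalone sketch of the consolidated dtype handling that the @@ -274,22 +277,20 @@ hunk introduces. The normalize_for_et_export helper and the torch.nn.Linear example are illustrative stand-ins, not torchchat code:

    import torch

    def normalize_for_et_export(model, target_precision, state_dict_dtype):
        # Reject anything outside the three dtypes ET export handles.
        if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
            raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
        # A single conversion path replaces the old per-dtype if/elif chain,
        # whose first branch also sent bfloat16 targets down the float16 path.
        if state_dict_dtype != target_precision:
            model = model.to(dtype=target_precision)
            state_dict_dtype = target_precision
        return model, state_dict_dtype

    # Example: a float32 module exported with a bfloat16 target now actually
    # ends up in bfloat16 rather than float16.
    m, dtype = normalize_for_et_export(
        torch.nn.Linear(4, 4), torch.bfloat16, torch.float32
    )
    assert next(m.parameters()).dtype == torch.bfloat16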