25 | 25 |
26 | 26 | from torch.export import Dim |
27 | 27 |
28 | | -try: |
29 | | - executorch_export_available = True |
30 | | - from export_util.export_et import export_model as export_model_et |
31 | | -except Exception as e: |
32 | | - executorch_exception = f"ET EXPORT EXCEPTION: {e}" |
33 | | - executorch_export_available = False |
34 | | - |
35 | 28 |
36 | 29 | default_device = "cpu" |
37 | 30 |
38 | 31 |
| 32 | +""" |
| 33 | +Export for Server |
| 34 | +""" |
| 35 | + |
| 36 | + |
39 | 37 | def export_for_server( |
40 | 38 | model: nn.Module, |
41 | 39 | device: Optional[str] = "cpu", |
@@ -79,6 +77,260 @@ def export_for_server( |
79 | 77 | return so |
80 | 78 |
81 | 79 |
| 80 | +""" |
| 81 | +Export for ExecuTorch |
| 82 | +
| 83 | +TODO (https://github.com/pytorch/torchchat/issues/1058): Replace |
| 84 | +replace_attention_with_custom_sdpa_attention with ET's implementation |
| 85 | +""" |
| 86 | + |
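| | + # ExecuTorch imports are optional: if anything below fails to import, the
| | + + # except clause at the bottom records the error and disables the ET export path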
| 87 | +try: |
| 88 | + executorch_export_available = True |
| 89 | + |
| 90 | + import logging |
| 91 | + |
| 92 | + from typing import Any, Dict, Tuple, Union |
| 93 | + |
| 94 | + import executorch.exir as exir |
| 95 | + |
| 96 | + from build.model import apply_rotary_emb, Attention |
| 97 | + from build.utils import get_precision |
| 98 | + |
| 99 | + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( |
| 100 | + XnnpackDynamicallyQuantizedPartitioner, |
| 101 | + ) |
| 102 | + from executorch.exir import EdgeProgramManager, to_edge |
| 103 | + |
| 104 | + from executorch.exir.capture._config import ( |
| 105 | + EdgeCompileConfig, |
| 106 | + ExecutorchBackendConfig, |
| 107 | + ) |
| 108 | + from executorch.exir.passes.quant_fusion_pass import QuantFusionPass |
| 109 | + from executorch.exir.passes.sym_shape_eval_pass import ( |
| 110 | + ConstraintBasedSymShapeEvalPass, |
| 111 | + ) |
| 112 | + from executorch.exir.tracer import Value |
| 113 | + |
| 114 | + from torch._export import capture_pre_autograd_graph |
| 115 | + from torch.export import export, ExportedProgram |
| 116 | + |
| 117 | + default_device = "cpu" |
| 118 | + |
| 119 | + _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( |
| 120 | + _check_ir_validity=True, |
| 121 | + ) |
| 122 | + |
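| | + # KV cache laid out as (bsz, seqlen, n_heads, head_dim), the layout expected
| | + # by the ET sdpa_with_kv_cache custom op (build.model's KVCache is (bsz, n_heads, seqlen, head_dim))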
| 123 | + class CustomKVCache(nn.Module): |
| 124 | + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype): |
| 125 | + super().__init__() |
| 126 | + |
| 127 | + dtype = torch.float # the cache is always float32; update() casts incoming values to float
| 128 | + |
| 129 | + # This is flipped around from what is in build.model's KVCache |
| 130 | + cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) |
| 131 | + self.register_buffer( |
| 132 | + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") |
| 133 | + ) |
| 134 | + self.register_buffer( |
| 135 | + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") |
| 136 | + ) |
| 137 | + |
| 138 | + def update(self, input_pos, k_val, v_val): |
| 139 | + k_out = self.k_cache |
| 140 | + v_out = self.v_cache |
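| | + # the cache layout is (bsz, seqlen, n_heads, head_dim), so write along the sequence dim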
| 141 | + k_out[:, input_pos] = k_val.float()
| 142 | + v_out[:, input_pos] = v_val.float()
| 143 | + |
| 144 | + return k_out, v_out |
| 145 | + |
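| | + # Attention module that routes through the ET fused SDPA + KV-cache custom op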
| 146 | + class CustomSDPAAttention(nn.Module): |
| 147 | + def __init__(self, attention: Attention): |
| 148 | + super().__init__() |
| 149 | + |
| 150 | + self.wq = attention.wq |
| 151 | + self.wk = attention.wk |
| 152 | + self.wv = attention.wv |
| 153 | + |
| 154 | + self.wo = attention.wo |
| 155 | + |
| 156 | + max_batch_size, n_heads, max_seq_length, head_dim = ( |
| 157 | + attention.kv_cache.k_cache.shape |
| 158 | + ) |
| 159 | + cache_dtype = attention.kv_cache.k_cache.dtype |
| 160 | + self.kv_cache = CustomKVCache( |
| 161 | + max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype |
| 162 | + ) |
| 163 | + |
| 164 | + self.n_heads = attention.n_heads |
| 165 | + self.head_dim = attention.head_dim |
| 166 | + self.n_local_heads = attention.n_local_heads |
| 167 | + self.dim = attention.dim |
| 168 | + |
| 169 | + def forward(self, x, freqs_cis, mask, input_pos=None): |
| 170 | + bsz, seqlen, _ = x.shape |
| 171 | + |
| 172 | + q = self.wq(x) |
| 173 | + k = self.wk(x) |
| 174 | + v = self.wv(x) |
| 175 | + |
| 176 | + q = q.view(bsz, seqlen, self.n_heads, self.head_dim) |
| 177 | + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 178 | + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 179 | + |
| 180 | + q = apply_rotary_emb(q, freqs_cis).to(dtype=torch.float) |
| 181 | + k = apply_rotary_emb(k, freqs_cis).to(dtype=torch.float) |
| 182 | + v = v.to(dtype=torch.float) |
| 183 | + |
| 184 | + # KV cache should always be enabled |
| 185 | + assert self.kv_cache is not None |
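| | + # The fused custom op writes k/v into the cache at input_pos and computes attention in one call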
| 186 | + output = torch.ops.llama.sdpa_with_kv_cache( |
| 187 | + q, |
| 188 | + k, |
| 189 | + v, |
| 190 | + self.kv_cache.k_cache, |
| 191 | + self.kv_cache.v_cache, |
| 192 | + input_pos[-1].item(), |
| 193 | + seqlen, |
| 194 | + ) |
| | + # cast back to the activation dtype (q was up-cast to float above, so q.dtype would be a no-op)
| 195 | + output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
| 196 | + return self.wo(output) |
| 197 | + |
| 198 | + def replace_attention_with_custom_sdpa_attention(module: nn.Module): |
| 199 | + from executorch.examples.models.llama2.custom_ops import ( # noqa |
| 200 | + sdpa_with_kv_cache, |
| 201 | + ) |
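| | + # the import above registers torch.ops.llama.sdpa_with_kv_cache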
| 202 | + |
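| | + # recursively swap every Attention module for the custom SDPA version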
| 203 | + for name, child in module.named_children(): |
| 204 | + if isinstance(child, Attention): |
| 205 | + setattr(module, name, CustomSDPAAttention(child)) |
| 206 | + else: |
| 207 | + replace_attention_with_custom_sdpa_attention(child) |
| 208 | + |
| 209 | + def _to_core_aten( |
| 210 | + model: Union[torch.fx.GraphModule, torch.nn.Module], |
| 211 | + example_inputs: Tuple[Value, ...], |
| 212 | + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, |
| 213 | + verbose=True, |
| 214 | + ) -> ExportedProgram: |
| 215 | + # Post-autograd export; eventually this will become .to_core_aten
| 216 | + if not isinstance(model, torch.fx.GraphModule) and not isinstance( |
| 217 | + model, torch.nn.Module |
| 218 | + ): |
| 219 | + raise ValueError( |
| 220 | + f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}" |
| 221 | + ) |
| 222 | + core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes) |
| 223 | + if verbose: |
| 224 | + logging.info(f"Core ATen graph:\n{core_aten_ep.graph}") |
| 225 | + return core_aten_ep |
| 226 | + |
| 227 | + def _core_aten_to_edge( |
| 228 | + core_aten_exir_ep: ExportedProgram, |
| 229 | + edge_constant_methods: Optional[Dict[str, Any]] = None, |
| 230 | + edge_compile_config=None, |
| 231 | + verbose=True, |
| 232 | + ) -> EdgeProgramManager: |
| 233 | + if not edge_compile_config: |
| 234 | + edge_compile_config = exir.EdgeCompileConfig( |
| 235 | + _check_ir_validity=False, # quant ops currently break ir verification |
| 236 | + ) |
| 237 | + edge_manager: EdgeProgramManager = to_edge( |
| 238 | + core_aten_exir_ep, |
| 239 | + constant_methods=edge_constant_methods, |
| 240 | + compile_config=edge_compile_config, |
| 241 | + ) |
| 242 | + if verbose: |
| 243 | + logging.info(f"Exported graph:\n{edge_manager.exported_program().graph}") |
| 244 | + return edge_manager |
| 245 | + |
| 246 | + def export_to_edge( |
| 247 | + model: Union[torch.fx.GraphModule, torch.nn.Module], |
| 248 | + example_inputs: Tuple[Value, ...], |
| 249 | + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, |
| 250 | + edge_constant_methods: Optional[Dict[str, Any]] = None, |
| 251 | + edge_compile_config=_EDGE_COMPILE_CONFIG, |
| 252 | + verbose=True, |
| 253 | + ) -> EdgeProgramManager: |
| 254 | + core_aten_ep = _to_core_aten( |
| 255 | + model, example_inputs, dynamic_shapes, verbose=verbose |
| 256 | + ) |
| 257 | + return _core_aten_to_edge( |
| 258 | + core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose |
| 259 | + ) |
| 260 | + |
| 261 | + def export_for_et(model, device, output_path, args=None) -> str: # noqa: C901 |
| 262 | + |
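| | + # example inputs for tracing: a single token id and its position (batch 1, seqlen 1)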
| 263 | + input = ( |
| 264 | + torch.tensor([[1]], dtype=torch.long, device=device), |
| 265 | + torch.tensor([0], dtype=torch.long, device=device), |
| 266 | + ) |
| 267 | + |
| 268 | + state_dict = model.state_dict() |
| 269 | + state_dict_dtype = state_dict[next(iter(state_dict))].dtype |
| 270 | + target_precision = get_precision() |
| 271 | + dynamic_shapes = None |
| 272 | + |
| 273 | + # TODO: need to use kv sdpa? |
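| | + # _skip_type_promotion (enabled for fp16) presumably keeps half-precision ops from being promoted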
| 274 | + edge_config = EdgeCompileConfig( |
| 275 | + _check_ir_validity=False, |
| 276 | + _skip_type_promotion=bool(target_precision == torch.float16), |
| 277 | + ) |
| 278 | + |
| | + # bfloat16 is handled by down-casting: the model is exported as float16
| 279 | + if target_precision == torch.float16 or target_precision == torch.bfloat16:
| 280 | + if state_dict_dtype != torch.float16:
| 281 | + print("model.to torch.float16")
| 282 | + model = model.to(dtype=torch.float16)
| 283 | + state_dict_dtype = torch.float16
| 284 | + elif target_precision == torch.float32:
| 285 | + if state_dict_dtype != torch.float32:
| 286 | + print("model.to torch.float32")
| 287 | + model = model.to(dtype=torch.float32)
| 288 | + else:
| 289 | + raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
| 293 | + |
| 294 | + replace_attention_with_custom_sdpa_attention(model) |
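| | + # trace with the math SDPA backend so attention is captured as export-friendly ops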
| 295 | + with torch.nn.attention.sdpa_kernel( |
| 296 | + [torch.nn.attention.SDPBackend.MATH] |
| 297 | + ), torch.no_grad(): |
| 298 | + m = capture_pre_autograd_graph(model, input, dynamic_shapes=dynamic_shapes) |
| 299 | + |
| 300 | + edge_manager = export_to_edge( |
| 301 | + m, |
| 302 | + input, |
| 303 | + dynamic_shapes=dynamic_shapes, |
| 304 | + edge_compile_config=edge_config, |
| 305 | + ) |
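| | + # delegate dynamically quantized ops to the XNNPACK backend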
| 306 | + edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner()) |
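| | + # lower to an ExecuTorch program; QuantFusionPass fuses remaining quant/dequant patterns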
| 307 | + export_program = edge_manager.to_executorch( |
| 308 | + ExecutorchBackendConfig( |
| 309 | + extract_constant_segment=True, |
| 310 | + extract_delegate_segments=True, |
| 311 | + passes=[ |
| 312 | + QuantFusionPass(), |
| 313 | + ], |
| 314 | + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), |
| 315 | + ) |
| 316 | + ) |
| 317 | + |
| 318 | + print("The methods are: ", export_program.methods) |
| 319 | + with open(output_path, "wb") as f: |
| 320 | + export_program.write_to_file(f) |
| 321 | + |
| 322 | + return output_path |
| 323 | + |
| 324 | +except Exception as e: |
| 325 | + executorch_exception = f"ET EXPORT EXCEPTION: {e}" |
| 326 | + executorch_export_available = False |
| 327 | + |
| 328 | + |
| 329 | +""" |
| 330 | +Exporting Flow |
| 331 | +""" |
| 332 | + |
| 333 | + |
82 | 334 | def main(args): |
83 | 335 | builder_args = BuilderArgs.from_args(args) |
84 | 336 | quantize = args.quantize |
@@ -153,7 +405,7 @@ def main(args): |
153 | 405 | output_pte_path = str(os.path.abspath(output_pte_path)) |
154 | 406 | if executorch_export_available: |
155 | 407 | print(f"Exporting model using ExecuTorch to {output_pte_path}") |
156 | | - export_model_et( |
| 408 | + export_for_et( |
157 | 409 | model_to_pte, builder_args.device, args.output_pte_path, args |
158 | 410 | ) |
159 | 411 | else: |