This repository was archived by the owner on Sep 10, 2025. It is now read-only.
[aoti] Remove need for -l in cmake call #1159
Merged
Changes from all commits
```diff
@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.

 import os
-from typing import Optional
+from typing import Dict, Optional

 import torch
+import torch._inductor
 import torch.nn as nn

 from torch.export import Dim
-import torch._inductor

 from torchchat.cli.builder import (
     _initialize_model,
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,8 +68,10 @@ def export_for_server(
         dynamic_shapes = None

     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
+        options = {
+            "aot_inductor.package": package,
+            "aot_inductor.metadata": metadata or {},
+        }
         if not package:
             options = {"aot_inductor.output_path": output_path}
@@ -81,6 +84,7 @@ def export_for_server(
         if package:
            from torch._inductor.package import package_aoti

            path = package_aoti(output_path, path)

    print(f"The generated packaged model can be found at: {path}")
```
> Comment: sorry, somehow my editor added a bunch of formatting changes here; hope it's not too confusing, otherwise I can try to remove them.
>
> Reply: No worries, these are good lint fixes.
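Aside from the lint fixes, the substantive change above is that `metadata` is now a parameter instead of a hard-coded empty dict, and it is normalized before being handed to AOT Inductor. A minimal sketch of that options-building logic, under the assumption that callers may pass `None` (the option keys come from the diff; the helper name is hypothetical, not part of the PR):

```python
from typing import Dict, Optional


def build_aoti_options(
    package: bool,
    output_path: str,
    metadata: Optional[Dict[str, str]] = None,
) -> Dict[str, object]:
    """Hypothetical helper mirroring the options dict built in the diff."""
    if not package:
        # Non-packaged exports only need an output path.
        return {"aot_inductor.output_path": output_path}
    return {
        "aot_inductor.package": package,
        "aot_inductor.metadata": metadata or {},  # never store None in the options
    }


print(build_aoti_options(True, "model.pt2"))  # metadata defaults to {}
print(build_aoti_options(True, "model.pt2", {"tokenizer_type": "3"}))
```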
```diff
@@ -102,13 +106,13 @@ def export_for_server(
 from typing import Any, Dict, Tuple, Union

 import executorch.exir as exir
+from executorch.backends.xnnpack._passes.convert_to_linear import (
+    ConvertToLinearPass,
+)

 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
 )
-from executorch.backends.xnnpack._passes.convert_to_linear import (
-    ConvertToLinearPass,
-)
 from executorch.exir import EdgeProgramManager, to_edge

 from executorch.exir.capture._config import (
@@ -166,18 +170,22 @@ def __init__(self, attention: Attention):
         self.wo = attention.wo

-        max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache[0].k_cache.shape
-        )
+        max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+            0
+        ].k_cache.shape
         cache_dtype = attention.kv_cache[0].k_cache.dtype
         # The `Attention` module being replaced can have multiple KV caches
         # (denoted by `cache_lanes`). Thus we follow the same setup format
         # as in `Attention.setup_cache`.
         cache_lanes = len(attention.kv_cache)
-        self.kv_cache = nn.ModuleList([
-            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = nn.ModuleList(
+            [
+                CustomKVCache(
+                    max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                )
+                for _ in range(cache_lanes)
+            ]
+        )

         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
@@ -215,9 +223,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         return self.wo(output)

 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.extension.llm.custom_ops import (  # noqa
-        sdpa_with_kv_cache,
-    )
+    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa

     for name, child in module.named_children():
         if isinstance(child, Attention):
@@ -238,7 +244,9 @@ def _to_core_aten(
         raise ValueError(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
-    core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
+    core_aten_ep = export_for_training(
+        model, example_inputs, dynamic_shapes=dynamic_shapes
+    )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
     return core_aten_ep
@@ -350,7 +358,11 @@ def main(args):
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )

     builder_args.dso_path = None
     builder_args.pte_path = None
@@ -372,6 +384,7 @@ def main(args):
     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
@@ -382,9 +395,8 @@ def main(args):
     if builder_args.max_seq_length is None:
         if (
-            (output_dso_path is not None or output_aoti_package_path is not None)
-            and not builder_args.dynamic_shapes
-        ):
+            output_dso_path is not None or output_aoti_package_path is not None
+        ) and not builder_args.dynamic_shapes:
             print("Setting max_seq_length to 300 for DSO export.")
             builder_args.max_seq_length = 300
         elif output_pte_path is not None:
@@ -397,7 +409,8 @@ def main(args):
         quantize,
         tokenizer,
         max_seq_length=builder_args.max_seq_length,
-        support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+        support_tensor_subclass=output_dso_path is None
+        and output_aoti_package_path is None,
     )
     model_to_pte = model
     model_to_dso = model
@@ -435,7 +448,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,
@@ -446,11 +461,23 @@ def main(args):
     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
-        print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
+        print(
+            "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
+        )
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
```
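To summarize what the last two hunks wire together: the exporter tags the packaged `.pt2` with the tokenizer family, so downstream consumers no longer need that information passed separately. A minimal sketch of the tagging logic; `tokenizer_type_tag` and the `TokenizerArgs` stand-in are hypothetical names for illustration, while the "0"/"2"/"3" values and the `metadata` dict come straight from the diff:

```python
from typing import Optional


class TokenizerArgs:
    """Stand-in for torchchat's tokenizer args; only the flag used here."""

    def __init__(self, is_sentencepiece: bool):
        self.is_sentencepiece = is_sentencepiece


def tokenizer_type_tag(tokenizer_args: Optional[TokenizerArgs]) -> str:
    # Mirrors the branching added in the last hunk:
    # "0" = no tokenizer, "2" = SentencePiece (llama2), "3" = otherwise (llama3).
    if tokenizer_args is None:
        return "0"
    if tokenizer_args.is_sentencepiece:
        return "2"
    return "3"


# The tag travels inside the .pt2 via export_for_server(..., metadata=...):
metadata = {"tokenizer_type": tokenizer_type_tag(TokenizerArgs(is_sentencepiece=True))}
assert metadata == {"tokenizer_type": "2"}
```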
> Comment: We can delete `metadata = metadata or {}`.
>
> Reply: General advice is to not have mutable structures as default args, because they survive across invocations: https://docs.python-guide.org/writing/gotchas/
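For illustration, here is a minimal sketch of the gotcha the linked guide describes (function names are hypothetical, not code from this PR): a default value is evaluated once, when the `def` statement runs, so a mutable default is the same object on every call.

```python
def export_bad(name, metadata={}):  # BUG: one dict object shared by every call
    metadata.setdefault("exported", []).append(name)
    return metadata


print(export_bad("a"))  # {'exported': ['a']}
print(export_bad("b"))  # {'exported': ['a', 'b']}  <- 'a' leaked from the first call


def export_good(name, metadata=None):  # idiomatic fix: default to None
    metadata = metadata or {}          # fresh dict per call when none is passed
    metadata.setdefault("exported", []).append(name)
    return metadata


print(export_good("a"))  # {'exported': ['a']}
print(export_good("b"))  # {'exported': ['b']}
```

This is why the PR keeps `metadata: Optional[Dict[str, str]] = None` in the signature and normalizes with `metadata or {}` at the point of use instead of defaulting the parameter to `{}`.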