This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 247fd20

[aoti] Remove need for -l in cmake
1 parent 54455a3 commit 247fd20
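
What this buys: torchchat/export.py now embeds a "tokenizer_type" entry in the .pt2 package metadata, and runner/run.cpp reads it back through AOTIModelPackageLoader::get_metadata(), so the runner no longer needs the -l flag (or the old -d device flag) when running an AOTI package. The snippet below is an illustration only, not part of this diff; it assumes a .pt2 exported by torchchat with --output-aoti-package-path.

```cpp
// Illustration only: inspect the metadata this commit relies on.
// Assumes a .pt2 package exported by torchchat with --output-aoti-package-path.
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>

#include <iostream>
#include <string>

int main(int argc, char* argv[]) {
  if (argc < 2) {
    std::cerr << "usage: inspect_pt2 <model.pt2>\n";
    return 1;
  }
  torch::inductor::AOTIModelPackageLoader loader(argv[1]);
  // get_metadata() returns std::unordered_map<std::string, std::string>.
  auto metadata = loader.get_metadata();
  // "0", "2" (llama2) or "3" (llama3), as written by torchchat/export.py.
  std::cout << "tokenizer_type: " << metadata["tokenizer_type"] << "\n";
  // "cpu" selects CPU in run.cpp; any other value selects CUDA.
  std::cout << "AOTI_DEVICE_KEY: " << metadata["AOTI_DEVICE_KEY"] << "\n";
  return 0;
}
```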

File tree

3 files changed (+68, -57 lines):
README.md
runner/run.cpp
torchchat/export.py

README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -332,7 +332,7 @@ torchchat/utils/scripts/build_native.sh aoti

 Then run the compiled executable, with the pt2.
 ```bash
-cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -i "Once upon a time"
 ```

 ## Mobile Execution
````

runner/run.cpp

Lines changed: 25 additions & 34 deletions

```diff
@@ -32,8 +32,6 @@ LICENSE file in the root directory of this source tree.

 #ifdef __AOTI_MODEL__
 #include <torch/csrc/inductor/aoti_package/model_package_loader.h>
-torch::Device aoti_device(torch::kCPU);
-
 #else // __ET_MODEL__
 #include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor_ptr.h>
@@ -89,9 +87,11 @@ typedef struct {
 typedef struct {
   Config config; // the hyperparameters of the architecture (the blueprint)
   RunState state; // buffers for the "wave" of activations in the forward pass
+  std::unordered_map<std::string, std::string> metadata;

 #ifdef __AOTI_MODEL__
   torch::inductor::AOTIModelPackageLoader* runner;
+
 #else // __ET_MODEL__
   Module* runner;
 #endif
@@ -130,19 +130,9 @@ void read_checkpoint(char* checkpoint, Config* config) {

 void build_transformer(
     Transformer* t,
-    char* model_path,
-    int vocab_size,
-    int seq_len) {
-  // read in the Config and the Weights from the model
-  // read_checkpoint(model_path, &t->config);
-  // allocate the RunState buffers
-  t->config.vocab_size = vocab_size;
-  t->config.seq_len = seq_len;
-  malloc_run_state(&t->state, &t->config);
-
+    char* model_path) {
 #ifdef __AOTI_MODEL__
   t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
-  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,
@@ -194,6 +184,9 @@ float* forward(Transformer* transformer, int token, int pos) {
   torch::Tensor token_tensor =
       torch::from_blob(token_buffer, {1, 1}, torch::kLong);
   torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
+  torch::Device aoti_device = transformer->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
+      ? torch::Device(torch::kCPU)
+      : torch::Device(torch::kCUDA);
   std::vector<torch::Tensor> inputs{
       token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};

@@ -895,26 +888,25 @@ int main(int argc, char* argv[]) {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
      llama_ver = atoi(argv[i + 1]);
-#ifdef __AOTI_MODEL__
-    } else if (argv[i][1] == 'd') {
-#ifdef USE_CUDA
-      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
-        aoti_device = torch::Device(torch::kCUDA);
-      } else
-#endif
-          if (strcasecmp(argv[i + 1], "CPU") == 0) {
-        aoti_device = torch::Device(torch::kCPU);
-      } else {
-        fprintf(stderr, "Unknown device %s", argv[i + 1]);
-        exit(1);
-      }
-#endif
     } else {
       error_usage();
     }
   }

+  if (model_path == NULL) {
+    fprintf(stderr, "No model_path provided.");
+    error_usage();
+  }
+
+  Transformer transformer;
+  build_transformer(&transformer, model_path);
+
+#ifdef __AOTI_MODEL__
+  ModelType model_type = get_model_type(std::stoi(transformer.runner->get_metadata()["tokenizer_type"]));
+#else // __ET_MODEL__
   ModelType model_type = get_model_type(llama_ver);
+#endif
+
   if (model_type == UNKNOWN_MODEL) {
     fprintf(
         stderr,
@@ -923,11 +915,6 @@ int main(int argc, char* argv[]) {
     error_usage();
   }

-  if (model_path == NULL) {
-    fprintf(stderr, "No model_path provided.");
-    error_usage();
-  }
-
   if (tokenizer_path == NULL) {
     fprintf(stderr, "No tokenizer_path provided.");
     error_usage();
@@ -950,8 +937,12 @@ int main(int argc, char* argv[]) {
     vocab_size = tokenizer->vocab_size();
   }

-  Transformer transformer;
-  build_transformer(&transformer, model_path, vocab_size, steps);
+  // read in the Config and the Weights from the model
+  // read_checkpoint(model_path, &t->config);
+  // allocate the RunState buffers
+  transformer.config.vocab_size = vocab_size;
+  transformer.config.seq_len = steps;
+  malloc_run_state(&transformer.state, &transformer.config);

   Sampler sampler;
   build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);
```
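
A note on the device handling in this file: the global aoti_device and the -d CPU|CUDA flag are removed; under __AOTI_MODEL__ the device is now derived from the package's AOTI_DEVICE_KEY metadata inside forward(). A minimal sketch of that resolution, mirroring the ternary added above (illustration only, not part of the diff):

```cpp
// Sketch: resolve the inference device from package metadata, as forward()
// now does under __AOTI_MODEL__ in runner/run.cpp.
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
#include <torch/torch.h>

torch::Device resolve_aoti_device(torch::inductor::AOTIModelPackageLoader* runner) {
  // "cpu" in the metadata means run on CPU; any other value falls back to CUDA.
  return runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
      ? torch::Device(torch::kCPU)
      : torch::Device(torch::kCUDA);
}

// Inputs are then moved to that device before invoking the compiled model:
//   std::vector<torch::Tensor> inputs{
//       token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};
```

Resolving the key on each call trades a small map lookup per forward pass for not having to keep a mutable global in sync with whichever package was loaded.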

torchchat/export.py

Lines changed: 42 additions & 22 deletions

```diff
@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.

 import os
-from typing import Optional
+from typing import Dict, Optional

 import torch
+import torch._inductor
 import torch.nn as nn

 from torch.export import Dim
-import torch._inductor

 from torchchat.cli.builder import (
     _initialize_model,
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Dict[str, str] = {},
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,7 +68,6 @@ def export_for_server(
         dynamic_shapes = None

     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
         options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
         if not package:
             options = {"aot_inductor.output_path": output_path}
@@ -81,6 +81,7 @@ def export_for_server(

     if package:
         from torch._inductor.package import package_aoti
+
         path = package_aoti(output_path, path)

     print(f"The generated packaged model can be found at: {path}")
@@ -102,13 +103,13 @@ def export_for_server(
 from typing import Any, Dict, Tuple, Union

 import executorch.exir as exir
+from executorch.backends.xnnpack._passes.convert_to_linear import (
+    ConvertToLinearPass,
+)

 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
 )
-from executorch.backends.xnnpack._passes.convert_to_linear import (
-    ConvertToLinearPass,
-)
 from executorch.exir import EdgeProgramManager, to_edge

 from executorch.exir.capture._config import (
@@ -166,18 +167,22 @@ def __init__(self, attention: Attention):

         self.wo = attention.wo

-        max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache[0].k_cache.shape
-        )
+        max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+            0
+        ].k_cache.shape
         cache_dtype = attention.kv_cache[0].k_cache.dtype
         # The `Attention` module being replaced can have multiple KV caches
         # (denoted by `cache_lanes`). Thus we follow the same setup format
         # as in `Attention.setup_cache`.
         cache_lanes = len(attention.kv_cache)
-        self.kv_cache = nn.ModuleList([
-            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = nn.ModuleList(
+            [
+                CustomKVCache(
+                    max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                )
+                for _ in range(cache_lanes)
+            ]
+        )

         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
@@ -215,9 +220,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         return self.wo(output)

 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.extension.llm.custom_ops import (  # noqa
-        sdpa_with_kv_cache,
-    )
+    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa

     for name, child in module.named_children():
         if isinstance(child, Attention):
@@ -350,7 +353,11 @@ def main(args):

     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )

     builder_args.dso_path = None
     builder_args.pte_path = None
@@ -372,6 +379,7 @@ def main(args):

     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
@@ -382,9 +390,8 @@ def main(args):

     if builder_args.max_seq_length is None:
         if (
-            (output_dso_path is not None or output_aoti_package_path is not None)
-            and not builder_args.dynamic_shapes
-        ):
+            output_dso_path is not None or output_aoti_package_path is not None
+        ) and not builder_args.dynamic_shapes:
             print("Setting max_seq_length to 300 for DSO export.")
             builder_args.max_seq_length = 300
         elif output_pte_path is not None:
@@ -397,7 +404,8 @@ def main(args):
             quantize,
             tokenizer,
             max_seq_length=builder_args.max_seq_length,
-            support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+            support_tensor_subclass=output_dso_path is None
+            and output_aoti_package_path is None,
         )
         model_to_pte = model
         model_to_dso = model
@@ -435,7 +443,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,
@@ -446,11 +456,21 @@ def main(args):

     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
         print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
```
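
On the export side, the tokenizer family is recorded in the package as a string: "0" when tokenizer_args is unavailable, "2" for SentencePiece (llama2), and "3" otherwise (llama3); runner/run.cpp converts it with std::stoi and passes it to get_model_type. The helper below is hypothetical (it does not exist in torchchat) and only spells out that value contract:

```cpp
// Hypothetical helper, for illustration only: decode the "tokenizer_type"
// metadata string written by torchchat/export.py.
#include <string>

std::string tokenizer_family(const std::string& tokenizer_type) {
  const int v = std::stoi(tokenizer_type);  // run.cpp uses the same std::stoi conversion
  switch (v) {
    case 2:
      return "llama2";   // exporter writes "2" when the tokenizer is SentencePiece
    case 3:
      return "llama3";   // exporter writes "3" otherwise
    default:
      return "unknown";  // exporter writes "0" when no tokenizer_args were available
  }
}
```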
