This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit b2b93c5

[aoti] Remove need for -l in cmake

1 parent 4510ba0 · commit b2b93c5
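In short: torchchat/export.py now stamps a tokenizer_type entry into the .pt2's AOT Inductor metadata, and runner/run.cpp reads it back through AOTIModelPackageLoader::get_metadata() to pick the model type, so aoti_run no longer needs the -l flag. The global aoti_device and the -d device flag are replaced the same way, by the package's AOTI_DEVICE_KEY metadata entry.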

File tree

3 files changed (+37, -37 lines):

- README.md
- runner/run.cpp
- torchchat/export.py

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -332,7 +332,7 @@ torchchat/utils/scripts/build_native.sh aoti
 
 Then run the compiled executable, with the pt2.
 ```bash
-cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -i "Once upon a time"
 ```
 
 ## Mobile Execution
````

runner/run.cpp

Lines changed: 25 additions & 34 deletions
```diff
@@ -32,8 +32,6 @@ LICENSE file in the root directory of this source tree.
 
 #ifdef __AOTI_MODEL__
 #include <torch/csrc/inductor/aoti_package/model_package_loader.h>
-torch::Device aoti_device(torch::kCPU);
-
 #else // __ET_MODEL__
 #include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor_ptr.h>
@@ -89,9 +87,11 @@ typedef struct {
 typedef struct {
   Config config; // the hyperparameters of the architecture (the blueprint)
   RunState state; // buffers for the "wave" of activations in the forward pass
+  std::unordered_map<std::string, std::string> metadata;
 
 #ifdef __AOTI_MODEL__
   torch::inductor::AOTIModelPackageLoader* runner;
+
 #else // __ET_MODEL__
   Module* runner;
 #endif
@@ -130,19 +130,9 @@ void read_checkpoint(char* checkpoint, Config* config) {
 
 void build_transformer(
     Transformer* t,
-    char* model_path,
-    int vocab_size,
-    int seq_len) {
-  // read in the Config and the Weights from the model
-  // read_checkpoint(model_path, &t->config);
-  // allocate the RunState buffers
-  t->config.vocab_size = vocab_size;
-  t->config.seq_len = seq_len;
-  malloc_run_state(&t->state, &t->config);
-
+    char* model_path) {
 #ifdef __AOTI_MODEL__
   t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
-  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,
@@ -194,6 +184,9 @@ float* forward(Transformer* transformer, int token, int pos) {
   torch::Tensor token_tensor =
       torch::from_blob(token_buffer, {1, 1}, torch::kLong);
   torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
+  torch::Device aoti_device = transformer->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
+      ? torch::Device(torch::kCPU)
+      : torch::Device(torch::kCUDA);
   std::vector<torch::Tensor> inputs{
       token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};
 
@@ -895,26 +888,25 @@ int main(int argc, char* argv[]) {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
       llama_ver = atoi(argv[i + 1]);
-#ifdef __AOTI_MODEL__
-    } else if (argv[i][1] == 'd') {
-#ifdef USE_CUDA
-      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
-        aoti_device = torch::Device(torch::kCUDA);
-      } else
-#endif
-          if (strcasecmp(argv[i + 1], "CPU") == 0) {
-        aoti_device = torch::Device(torch::kCPU);
-      } else {
-        fprintf(stderr, "Unknown device %s", argv[i + 1]);
-        exit(1);
-      }
-#endif
     } else {
       error_usage();
     }
   }
 
+  if (model_path == NULL) {
+    fprintf(stderr, "No model_path provided.");
+    error_usage();
+  }
+
+  Transformer transformer;
+  build_transformer(&transformer, model_path);
+
+#ifdef __AOTI_MODEL__
+  ModelType model_type = get_model_type(std::stoi(transformer.runner->get_metadata()["tokenizer_type"]));
+#else // __ET_MODEL__
   ModelType model_type = get_model_type(llama_ver);
+#endif
+
   if (model_type == UNKNOWN_MODEL) {
     fprintf(
         stderr,
@@ -923,11 +915,6 @@ int main(int argc, char* argv[]) {
     error_usage();
   }
 
-  if (model_path == NULL) {
-    fprintf(stderr, "No model_path provided.");
-    error_usage();
-  }
-
   if (tokenizer_path == NULL) {
     fprintf(stderr, "No tokenizer_path provided.");
     error_usage();
@@ -950,8 +937,12 @@ int main(int argc, char* argv[]) {
     vocab_size = tokenizer->vocab_size();
   }
 
-  Transformer transformer;
-  build_transformer(&transformer, model_path, vocab_size, steps);
+  // read in the Config and the Weights from the model
+  // read_checkpoint(model_path, &t->config);
+  // allocate the RunState buffers
+  transformer.config.vocab_size = vocab_size;
+  transformer.config.seq_len = steps;
+  malloc_run_state(&transformer.state, &transformer.config);
 
   Sampler sampler;
   build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);
```
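Taken together, the runner-side change reduces to a small metadata lookup. Below is a minimal, self-contained sketch of that dispatch, not torchchat's actual code: a plain std::unordered_map stands in for the map returned by AOTIModelPackageLoader::get_metadata(), and the ModelType enum values are illustrative assumptions (mirroring the llama versions previously passed via -l).

```cpp
// Sketch of the metadata-driven dispatch. Assumptions: enum values and
// get_model_type() mimic torchchat's; the map stands in for
// AOTIModelPackageLoader::get_metadata().
#include <cstdio>
#include <string>
#include <unordered_map>

enum ModelType { UNKNOWN_MODEL = 0, LLAMA2_MODEL = 2, LLAMA3_MODEL = 3 };

// Map the integer stored in the package metadata to a model type,
// as get_model_type() did for the old -l flag.
ModelType get_model_type(int version) {
  switch (version) {
    case 2:
      return LLAMA2_MODEL;
    case 3:
      return LLAMA3_MODEL;
    default:
      return UNKNOWN_MODEL;
  }
}

int main() {
  // The exporter bakes these strings into the .pt2 package.
  std::unordered_map<std::string, std::string> metadata = {
      {"tokenizer_type", "3"}, {"AOTI_DEVICE_KEY", "cpu"}};

  // No -l flag: the model/tokenizer type comes from the package itself.
  ModelType model_type = get_model_type(std::stoi(metadata["tokenizer_type"]));

  // No -d flag and no global device: the device key is read per package.
  std::string device = metadata["AOTI_DEVICE_KEY"] == "cpu" ? "cpu" : "cuda";

  std::printf("model_type=%d device=%s\n", (int)model_type, device.c_str());
  return 0;
}
```

Note that forward() now rebuilds aoti_device on each call instead of consulting a global: a map lookup per token in exchange for dropping the file-scope torch::Device state.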

torchchat/export.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,7 +68,7 @@ def export_for_server(
         dynamic_shapes = None
 
     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {} # TODO: put more metadata here
+        metadata = metadata or {}
         options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
         if not package:
             options = {"aot_inductor.output_path": output_path}
@@ -372,6 +373,7 @@ def main(args):
 
     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
@@ -446,11 +448,18 @@ def main(args):
 
     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
+
+        tokenizer_type = "0"
+        if tokenizer_args is not None:
+            tokenizer_type = "2" if tokenizer_args.is_sentencepiece else "3"
+
+        metadata = {"tokenizer_type": tokenizer_type}
         print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
```
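A note on the encoding: AOT Inductor metadata is a Dict[str, str], so the tokenizer type travels as a string. "0" marks the case where no tokenizer could be determined, "2" a SentencePiece tokenizer, and "3" a non-SentencePiece (TikToken) one; these match the integer values get_model_type() previously received from the runner's -l flag, which is what lets run.cpp pass std::stoi(metadata["tokenizer_type"]) straight into it.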
