@@ -32,8 +32,6 @@ LICENSE file in the root directory of this source tree.
3232
3333#ifdef __AOTI_MODEL__
3434#include < torch/csrc/inductor/aoti_package/model_package_loader.h>
35- torch::Device aoti_device (torch::kCPU );
36-
3735#else // __ET_MODEL__
3836#include < executorch/extension/module/module.h>
3937#include < executorch/extension/tensor/tensor_ptr.h>
@@ -89,9 +87,11 @@ typedef struct {
8987typedef struct {
9088 Config config; // the hyperparameters of the architecture (the blueprint)
9189 RunState state; // buffers for the "wave" of activations in the forward pass
90+ std::unordered_map<std::string, std::string> metadata;
9291
9392#ifdef __AOTI_MODEL__
9493 torch::inductor::AOTIModelPackageLoader* runner;
94+
9595#else // __ET_MODEL__
9696 Module* runner;
9797#endif
@@ -130,19 +130,9 @@ void read_checkpoint(char* checkpoint, Config* config) {
130130
131131void build_transformer (
132132 Transformer* t,
133- char * model_path,
134- int vocab_size,
135- int seq_len) {
136- // read in the Config and the Weights from the model
137- // read_checkpoint(model_path, &t->config);
138- // allocate the RunState buffers
139- t->config .vocab_size = vocab_size;
140- t->config .seq_len = seq_len;
141- malloc_run_state (&t->state , &t->config );
142-
133+ char * model_path) {
143134#ifdef __AOTI_MODEL__
144135 t->runner = new torch::inductor::AOTIModelPackageLoader (model_path);
145- aoti_device = t->runner ->get_metadata ()[" AOTI_DEVICE_KEY" ] == " cpu" ? torch::Device (torch::kCPU ) : torch::Device (torch::kCUDA );
146136#else // __ET_MODEL__
147137 t->runner = new Module (
148138 /* path to PTE model */ model_path,
@@ -194,6 +184,9 @@ float* forward(Transformer* transformer, int token, int pos) {
194184 torch::Tensor token_tensor =
195185 torch::from_blob (token_buffer, {1 , 1 }, torch::kLong );
196186 torch::Tensor pos_tensor = torch::from_blob (pos_buffer, {1 }, torch::kLong );
187+ torch::Device aoti_device = transformer->runner ->get_metadata ()[" AOTI_DEVICE_KEY" ] == " cpu"
188+ ? torch::Device (torch::kCPU )
189+ : torch::Device (torch::kCUDA );
197190 std::vector<torch::Tensor> inputs{
198191 token_tensor.to (aoti_device), pos_tensor.to (aoti_device)};
199192
@@ -895,26 +888,25 @@ int main(int argc, char* argv[]) {
895888 system_prompt = argv[i + 1 ];
896889 } else if (argv[i][1 ] == ' l' ) {
897890 llama_ver = atoi (argv[i + 1 ]);
898- #ifdef __AOTI_MODEL__
899- } else if (argv[i][1 ] == ' d' ) {
900- #ifdef USE_CUDA
901- if (strcasecmp (argv[i + 1 ], " CUDA" ) == 0 ) {
902- aoti_device = torch::Device (torch::kCUDA );
903- } else
904- #endif
905- if (strcasecmp (argv[i + 1 ], " CPU" ) == 0 ) {
906- aoti_device = torch::Device (torch::kCPU );
907- } else {
908- fprintf (stderr, " Unknown device %s" , argv[i + 1 ]);
909- exit (1 );
910- }
911- #endif
912891 } else {
913892 error_usage ();
914893 }
915894 }
916895
896+ if (model_path == NULL ) {
897+ fprintf (stderr, " No model_path provided." );
898+ error_usage ();
899+ }
900+
901+ Transformer transformer;
902+ build_transformer (&transformer, model_path);
903+
904+ #ifdef __AOTI_MODEL__
905+ ModelType model_type = get_model_type (std::stoi (transformer.runner ->get_metadata ()[" tokenizer_type" ]));
906+ #else // __ET_MODEL__
917907 ModelType model_type = get_model_type (llama_ver);
908+ #endif
909+
918910 if (model_type == UNKNOWN_MODEL) {
919911 fprintf (
920912 stderr,
@@ -923,11 +915,6 @@ int main(int argc, char* argv[]) {
923915 error_usage ();
924916 }
925917
926- if (model_path == NULL ) {
927- fprintf (stderr, " No model_path provided." );
928- error_usage ();
929- }
930-
931918 if (tokenizer_path == NULL ) {
932919 fprintf (stderr, " No tokenizer_path provided." );
933920 error_usage ();
@@ -950,8 +937,12 @@ int main(int argc, char* argv[]) {
950937 vocab_size = tokenizer->vocab_size ();
951938 }
952939
953- Transformer transformer;
954- build_transformer (&transformer, model_path, vocab_size, steps);
940+ // read in the Config and the Weights from the model
941+ // read_checkpoint(model_path, &t->config);
942+ // allocate the RunState buffers
943+ transformer.config .vocab_size = vocab_size;
944+ transformer.config .seq_len = steps;
945+ malloc_run_state (&transformer.state , &transformer.config );
955946
956947 Sampler sampler;
957948 build_sampler (&sampler, vocab_size, temperature, topp, rng_seed);
0 commit comments