@@ -102,6 +102,7 @@ typedef struct {
 typedef struct {
   Config config; // the hyperparameters of the architecture (the blueprint)
   RunState state; // buffers for the "wave" of activations in the forward pass
+  std::unordered_map<std::string, std::string> metadata;

 #ifdef __AOTI_MODEL__
   torch::inductor::AOTIModelPackageLoader *runner;
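
The Transformer struct gains a metadata cache next to the config and run state. None of the hunks below populate the new field (main() keeps its own local copy), so the following is only a sketch of the caching pattern the member suggests, assuming get_metadata() returns a std::unordered_map<std::string, std::string> to match the field's type:

    // Sketch (assumption): fill the cache once at load time so later
    // lookups don't have to go back through the runner.
    #ifdef __AOTI_MODEL__
      t->metadata = t->runner->get_metadata();
      const std::string &dev = t->metadata["AOTI_DEVICE_KEY"]; // "cpu" or "cuda"
    #endif
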
@@ -141,20 +142,9 @@ void read_checkpoint(char *checkpoint, Config *config) {
   config->vocab_size = abs(config->vocab_size);
 }

-void build_transformer(Transformer *t, char *model_path, int vocab_size,
-                       int seq_len) {
-  // read in the Config and the Weights from the model
-  // read_checkpoint(model_path, &t->config);
-  // allocate the RunState buffers
-  t->config.vocab_size = vocab_size;
-  t->config.seq_len = seq_len;
-  malloc_run_state(&t->state, &t->config);
-
+void build_transformer(Transformer *t, char *model_path) {
 #ifdef __AOTI_MODEL__
   t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
-  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
-                    ? torch::Device(torch::kCPU)
-                    : torch::Device(torch::kCUDA);
 #else // __ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,
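
build_transformer() loses the vocab_size and seq_len parameters along with all config and buffer setup, and the hard-coded AOTI device selection moves out as well; the function now does nothing but construct the runner. The call shape before and after, as seen in the main() hunks further down:

    // Before: the caller needed the final vocab size and sequence length
    // up front.
    //   build_transformer(&transformer, model_path, vocab_size, steps);
    // After: only the path; config and RunState are filled in later in
    // main(), once the tokenizer has reported its vocab size.
    build_transformer(&transformer, model_path);
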
@@ -776,9 +766,6 @@ void error_usage() {
           " -v <int> (optional) vocab size, default is model-specific.\n");
   fprintf(stderr,
           " -l <int> (optional) llama version (2 or 3), default 2.\n");
-  fprintf(
-      stderr,
-      " -d <string> (optional) device(CUDA or CPU) model was exported for\n");
   exit(EXIT_FAILURE);
 }

@@ -848,37 +835,35 @@ int main(int argc, char *argv[]) {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
       llama_ver = atoi(argv[i + 1]);
-#ifdef __AOTI_MODEL__
-    } else if (argv[i][1] == 'd') {
-#ifdef USE_CUDA
-      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
-        aoti_device = torch::Device(torch::kCUDA);
-      } else
-#endif
-          if (strcasecmp(argv[i + 1], "CPU") == 0) {
-        aoti_device = torch::Device(torch::kCPU);
-      } else {
-        fprintf(stderr, "Unknown device %s", argv[i + 1]);
-        exit(1);
-      }
-#endif
     } else {
       error_usage();
     }
   }

+  if (model_path == NULL) {
+    fprintf(stderr, "No model_path provided.");
+    error_usage();
+  }
+
+  Transformer transformer;
+  build_transformer(&transformer, model_path);
+
+#ifdef __AOTI_MODEL__
+  auto aoti_metadata = transformer.runner->get_metadata();
+  aoti_device = aoti_metadata["AOTI_DEVICE_KEY"] == "cpu"
+                    ? torch::Device(torch::kCPU)
+                    : torch::Device(torch::kCUDA);
+  ModelType model_type = get_model_type(std::stoi(aoti_metadata["tokenizer_type"]));
+#else // __ET_MODEL__
   ModelType model_type = get_model_type(llama_ver);
+#endif
+
   if (model_type == UNKNOWN_MODEL) {
     fprintf(stderr, "Unknown model type passed by -l argument. Received l=%d.",
             llama_ver);
     error_usage();
   }

-  if (model_path == NULL) {
-    fprintf(stderr, "No model_path provided.");
-    error_usage();
-  }
-
   if (tokenizer_path == NULL) {
     fprintf(stderr, "No tokenizer_path provided.");
     error_usage();
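
This hunk is the core of the change: main() validates model_path and builds the transformer first, then reads both the target device (AOTI_DEVICE_KEY) and the model type (tokenizer_type) from the AOTI package metadata instead of from the removed -d flag and, in the AOTI build, the -l flag. One caveat: std::stoi throws if tokenizer_type is absent or non-numeric. A more defensive variant might look like the sketch below; the missing-key fallback is an assumption, not part of this commit:

    auto aoti_metadata = transformer.runner->get_metadata();
    aoti_device = aoti_metadata["AOTI_DEVICE_KEY"] == "cpu"
                      ? torch::Device(torch::kCPU)
                      : torch::Device(torch::kCUDA);
    ModelType model_type = UNKNOWN_MODEL;
    auto it = aoti_metadata.find("tokenizer_type");
    if (it != aoti_metadata.end() && !it->second.empty()) {
      model_type = get_model_type(std::stoi(it->second));
    } // a missing key falls through to the UNKNOWN_MODEL check that follows
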
@@ -901,8 +886,12 @@ int main(int argc, char *argv[]) {
     vocab_size = tokenizer->vocab_size();
   }

-  Transformer transformer;
-  build_transformer(&transformer, model_path, vocab_size, steps);
+  // read in the Config and the Weights from the model
+  // read_checkpoint(model_path, &t->config);
+  // allocate the RunState buffers
+  transformer.config.vocab_size = vocab_size;
+  transformer.config.seq_len = steps;
+  malloc_run_state(&transformer.state, &transformer.config);

   Sampler sampler;
   build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);
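
The last hunk relocates the config and RunState setup that used to live in build_transformer() to after tokenizer construction, since vocab_size can come from the tokenizer rather than the -v flag. The resulting order in main(), with the tokenizer step elided because this diff doesn't show it:

    Transformer transformer;
    build_transformer(&transformer, model_path); // 1. load the model

    // 2. build the tokenizer; it may supply vocab_size (not shown here)

    transformer.config.vocab_size = vocab_size;  // 3. size the buffers last
    transformer.config.seq_len = steps;
    malloc_run_state(&transformer.state, &transformer.config);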