 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.

+#include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h>
@@ -59,7 +60,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }

@@ -84,13 +85,6 @@ void save_logits(

 } // namespace

-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 Runner::Runner(
     const std::string& decoder_model_version,
     const std::string& model_path,
@@ -177,7 +171,8 @@ Error Runner::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -317,13 +312,30 @@ Error Runner::load() {
 }

 Error Runner::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -332,6 +344,7 @@ Error Runner::generate(
   }
   stats_.inference_start_ms = time_in_ms();

+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;

@@ -370,7 +383,7 @@ Error Runner::generate(
       "sequence length exceeded - please increase the seq_len value");

   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;
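For context, a rough idea of the call site after this change. This is a sketch, not part of the commit: the header path, the fully qualified GenerationConfig namespace, and the prompt/seq_len values are assumptions; only seq_len and echo are shown because those are the fields the runner reads above.

// Illustrative only -- not from this commit. Header path and the
// executorch::extension::llm namespace for GenerationConfig are assumptions.
#include <cstdio>

#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>

void run_example(example::Runner& runner) {
  executorch::extension::llm::GenerationConfig config;
  config.seq_len = 128;  // clamped to the model's context length by the runner
  config.echo = true;    // prompt is echoed back through token_callback
  runner.generate(
      "Tell me a short story.",
      config,
      [](const std::string& piece) { fputs(piece.c_str(), stdout); },
      [](const auto& stats) { (void)stats; /* inference timing stats */ });
}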