@@ -59,7 +59,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }

@@ -84,7 +84,7 @@ void save_logits(

 } // namespace

-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
+std::unique_ptr<::tokenizers::Tokenizer> load_qnn_llama_tokenizer(
     const std::string& tokenizer_path,
     Version version) {
   auto special_tokens = get_special_tokens(version);
@@ -175,7 +175,7 @@ Error Runner::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ = load_qnn_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -313,13 +313,29 @@ Error Runner::load() {
 }

 Error Runner::generate(
+    const std::string& prompt,
+    const executorch::extension::llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const executorch::extension::llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(prompt, false, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const executorch::extension::llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -328,6 +344,7 @@ Error Runner::generate(
   }
   stats_.inference_start_ms = time_in_ms();

+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;

@@ -366,7 +383,7 @@ Error Runner::generate(
366383 " sequence length exceeded - please increase the seq_len value" );
367384
368385 // Prompt Processor first
369- if (token_callback) {
386+ if (token_callback && config. echo ) {
370387 token_callback (prompt);
371388 }
372389 bool dump_logits = dump_logits_path_.empty () ? false : true ;
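
For reference, a minimal caller sketch against the reworked interface. This is hedged and not part of the change itself: the include path and the example::Runner namespace are assumptions, while config.seq_len and config.echo are the two GenerationConfig fields the diff itself reads.

// Hypothetical caller sketch (not part of this diff).
#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>  // assumed path

#include <iostream>
#include <string>

void run_prompt(example::Runner& runner, const std::string& prompt) {
  executorch::extension::llm::GenerationConfig config;
  config.seq_len = 128;  // clamped to context_len_ inside the runner
  config.echo = true;    // prompt is echoed through the token callback

  auto err = runner.generate(
      prompt,
      config,
      [](const std::string& piece) { std::cout << piece << std::flush; },
      {});  // no stats callback in this sketch
  (void)err;  // a real caller would check err against Error::Ok
}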