
Commit 8946d80

Fix text_llm_runner kv cache pos count and use it for generate() (pytorch#15286)
pos_ should advance by both the prefill prompt size and the number of generated tokens.
1 parent 9152f0a · commit 8946d80

File tree

1 file changed: 3 additions, 1 deletion

extension/llm/runner/text_llm_runner.cpp

Lines changed: 3 additions & 1 deletion
@@ -190,7 +190,7 @@ Error TextLLMRunner::generate(
   // Generate max_new_tokens - 1 because prefill already generated 1 token.
   auto generate_result = text_token_generator_->generate(
       prompt_tokens,
-      num_prompt_tokens,
+      pos_,
       max_new_tokens - 1,
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback);
@@ -199,6 +199,8 @@ Error TextLLMRunner::generate(
   }
   int64_t num_generated_tokens = generate_result.get();
 
+  pos_ += num_generated_tokens;
+
   stats_->inference_end_ms = time_in_ms();
   if (!config.warming) {
     printf("\n");
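
Below is a minimal, self-contained sketch of the KV-cache position bookkeeping this commit fixes. It is not the ExecuTorch implementation: the RunnerSketch class and its prefill/decode helpers are hypothetical stand-ins, while pos_, prompt_tokens, max_new_tokens, and num_generated_tokens mirror the names in the diff. The point it illustrates: decode must start at pos_ (the next free cache slot) rather than at the size of the current prompt alone, and pos_ must then advance by the generated-token count, so a second generate() call on the same cache resumes at the right position.

#include <cstdint>
#include <vector>

class RunnerSketch {
 public:
  void generate(const std::vector<uint64_t>& prompt_tokens,
                int64_t max_new_tokens) {
    // Prefill writes one KV-cache entry per prompt token, so the
    // cache position advances by the prompt size.
    prefill(prompt_tokens, pos_);
    pos_ += static_cast<int64_t>(prompt_tokens.size());

    // Decode must start at pos_, the current cache position, not at
    // the size of this prompt alone. On a second generate() call the
    // two differ, because earlier tokens already occupy the cache;
    // this is what the diff's num_prompt_tokens -> pos_ change fixes.
    int64_t num_generated_tokens = decode(pos_, max_new_tokens - 1);

    // Advance past the generated tokens so the next call resumes at
    // the correct slot; this is the added "pos_ += ..." line.
    pos_ += num_generated_tokens;
  }

 private:
  void prefill(const std::vector<uint64_t>& tokens, int64_t start_pos) {
    // Placeholder: run the model over the prompt starting at start_pos.
    (void)tokens;
    (void)start_pos;
  }

  int64_t decode(int64_t start_pos, int64_t max_tokens) {
    // Placeholder: autoregressive decode; pretend the full budget was used.
    (void)start_pos;
    return max_tokens;
  }

  int64_t pos_ = 0;  // next free position in the KV cache
};

int main() {
  RunnerSketch r;
  r.generate({1, 2, 3}, 8);  // first turn: pos_ ends at 3 + 7 = 10
  r.generate({4, 5}, 8);     // second turn resumes at pos_ = 10, not 2
  return 0;
}

With the old code, the second call would tell the token generator to start at position 2 (its own prompt length), clobbering cache entries from the first turn; tracking pos_ across calls keeps the cache consistent.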
