Skip to content

Commit 2b6dfe8

Browse files
authored
llama : remove write/read of output ids/logits/embeddings (ggml-org#18862)
* llama : remove write/read of output ids/logits/embeddings This commit removes the write/read of output ids, logits and embeddings from the llama context state. Refs: ggml-org#18862 (comment) * completion : add replaying of session state This commit updates the session handling in the completion tool to handle that logits are no longer stored in the session file. Instead, we need to replay the last token to get the logits for sampling. * common : add common_prompt_batch_decode function This commit adds a new function which is responsible for decoding a prompt and optionally handling the saving of session data. * update save-state.cpp to use llama_state_load_file This commit updates the save-load-state example to utilize the new llama_state_load_file function for loading the model state from a file. And it also replays the last token after loading since this state is now stored before the last token is processed. * examples : set n_seq_max = 2 for ctx3 This commit updates the save-load-state example to set the n_seq_max parameter to 2 when initializing the ctx3 context. The motivation for this change is that using 1 as n_parallel/n_seq_max the context only supports one sequence, but the test later tries to use a second sequence which results in the following error: ```console main : loaded state with 4 tokens main : seq 0 copied, 225760 bytes main : kv cache cleared find_slot: seq_id=1 >= n_seq_max=1 Try using a bigger --parallel value state_read_meta: failed to find available cells in kv cache ``` This seems to only happen for recurrent/hybrid models.
1 parent e8e2616 commit 2b6dfe8

File tree

5 files changed

+132
-200
lines changed

5 files changed

+132
-200
lines changed

common/common.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,3 +1760,65 @@ float lr_opt::get_lr(float epoch) const {
17601760
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
17611761
return r;
17621762
}
1763+
1764+
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
1765+
llama_batch batch = llama_batch_get_one(&last_token, 1);
1766+
batch.pos = &pos;
1767+
if (llama_decode(ctx, batch)) {
1768+
LOG_ERR("%s: failed to replay last token\n", __func__);
1769+
return false;
1770+
}
1771+
return true;
1772+
}
1773+
1774+
bool common_prompt_batch_decode(
1775+
struct llama_context * ctx,
1776+
const std::vector<llama_token> & tokens,
1777+
int & n_past,
1778+
int n_batch,
1779+
std::string_view state_path,
1780+
bool save_state) {
1781+
const int n_eval = tokens.size();
1782+
if (n_eval == 0) {
1783+
return true;
1784+
}
1785+
1786+
if (save_state && n_eval > 1) {
1787+
const int n_tokens_before_last = n_eval - 1;
1788+
1789+
GGML_ASSERT(n_eval <= n_batch);
1790+
1791+
// Decode all but the last token so we can save the memory state before decoding the last token.
1792+
// This is done so we can restore the session state later and replay the last token.
1793+
// Memory implementations in recurrent/hybrid models don't support removing tokens from their
1794+
// memory, so we can't just remove the last token from the memory and replay the last token which
1795+
// is the reason for this logic.
1796+
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
1797+
LOG_ERR("%s : failed to eval\n", __func__);
1798+
return false;
1799+
}
1800+
n_past += n_tokens_before_last;
1801+
1802+
llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
1803+
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
1804+
1805+
llama_token last_token = tokens.back();
1806+
llama_batch batch = llama_batch_get_one(&last_token, 1);
1807+
int32_t pos = n_past;
1808+
batch.pos = &pos;
1809+
1810+
if (llama_decode(ctx, batch)) {
1811+
LOG_ERR("%s : failed to eval last token\n", __func__);
1812+
return false;
1813+
}
1814+
n_past++;
1815+
} else {
1816+
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
1817+
LOG_ERR("%s : failed to eval\n", __func__);
1818+
return false;
1819+
}
1820+
n_past += n_eval;
1821+
}
1822+
1823+
return true;
1824+
}

common/common.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,23 @@ void common_batch_add(
804804
const std::vector<llama_seq_id> & seq_ids,
805805
bool logits);
806806

807+
// decodes a single batch of tokens for a prompt and manages session tokens
808+
//
809+
// Note: We save state before the last token so that we can replay it to ensure
810+
// compatibility with all memory types. Recurrent/hybrid models cannot remove
811+
// tokens from memory, so this approach works across all model architectures.
812+
bool common_prompt_batch_decode(
813+
struct llama_context * ctx,
814+
const std::vector<llama_token> & embd,
815+
int & n_past,
816+
int n_batch,
817+
std::string_view state_path,
818+
bool save_state);
819+
820+
// replays the last token after loading state to regenerate logits
821+
// used after loading session state to ensure the sampling context has valid logits
822+
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
823+
807824
//
808825
// Vocab utils
809826
//

examples/save-load-state/save-load-state.cpp

Lines changed: 35 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
#include <vector>
66
#include <cstdio>
77

8+
89
int main(int argc, char ** argv) {
910
common_params params;
1011

1112
params.prompt = "The quick brown fox";
1213
params.sampling.seed = 1234;
1314

15+
const std::string_view state_file = "dump_state.bin";
16+
1417
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
1518
return 1;
1619
}
@@ -53,35 +56,16 @@ int main(int argc, char ** argv) {
5356
// tokenize prompt
5457
auto tokens = common_tokenize(ctx, params.prompt, true);
5558

56-
// prepare the batch
57-
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
58-
for (size_t i = 0; i < tokens.size(); i++) {
59-
common_batch_add(batch, tokens[i], i, {0}, false);
60-
}
61-
batch.logits[batch.n_tokens - 1] = true; // generate next token
62-
63-
// evaluate prompt
64-
llama_decode(ctx, batch);
65-
n_past += batch.n_tokens;
66-
67-
// save state (rng, logits, embedding and kv_cache) to file
68-
{
69-
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
70-
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
71-
72-
FILE *fp_write = fopen("dump_state.bin", "wb");
73-
fwrite(state_mem.data(), 1, written, fp_write);
74-
fclose(fp_write);
75-
76-
fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
59+
const bool save_state = true;
60+
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
61+
return 1;
7762
}
7863

79-
// save state (last tokens)
80-
const auto n_past_saved = n_past;
81-
8264
// first run
8365
printf("\nfirst run: %s", params.prompt.c_str());
8466

67+
llama_batch batch = llama_batch_init(1, 0, 1);
68+
8569
for (auto i = 0; i < params.n_predict; i++) {
8670
auto next_token = llama_sampler_sample(smpl, ctx, -1);
8771
auto next_token_str = common_token_to_piece(ctx, next_token);
@@ -111,27 +95,23 @@ int main(int argc, char ** argv) {
11195

11296
printf("\nsecond run: %s", params.prompt.c_str());
11397

114-
// load state (rng, logits, embedding and kv_cache) from file
115-
{
116-
std::vector<uint8_t> state_mem;
117-
118-
FILE * fp_read = fopen("dump_state.bin", "rb");
119-
fseek(fp_read, 0, SEEK_END);
120-
state_mem.resize(ftell(fp_read));
121-
fseek(fp_read, 0, SEEK_SET);
122-
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
123-
fclose(fp_read);
124-
125-
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
126-
fprintf(stderr, "\n%s : failed to read state\n", __func__);
127-
return 1;
128-
}
98+
// load state from file
99+
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
100+
size_t n_token_count_out = 0;
129101

130-
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
102+
if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
103+
fprintf(stderr, "\n%s : failed to load state\n", __func__);
104+
return 1;
131105
}
132106

107+
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
108+
133109
// restore state (last tokens)
134-
n_past = n_past_saved;
110+
n_past = n_token_count_out;
111+
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
112+
return 1;
113+
}
114+
++n_past;
135115

136116
// second run
137117
for (auto i = 0; i < params.n_predict; i++) {
@@ -160,7 +140,9 @@ int main(int argc, char ** argv) {
160140
}
161141

162142
// make new context
163-
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
143+
auto params_ctx3 = common_context_params_to_llama(params);
144+
params_ctx3.n_seq_max = 2;
145+
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
164146

165147
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
166148

@@ -169,26 +151,21 @@ int main(int argc, char ** argv) {
169151
printf("\nsingle seq run: %s", params.prompt.c_str());
170152

171153
// load state (rng, logits, embedding and kv_cache) from file
172-
{
173-
std::vector<uint8_t> state_mem;
174-
175-
FILE * fp_read = fopen("dump_state.bin", "rb");
176-
fseek(fp_read, 0, SEEK_END);
177-
state_mem.resize(ftell(fp_read));
178-
fseek(fp_read, 0, SEEK_SET);
179-
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
180-
fclose(fp_read);
154+
n_token_count_out = 0;
181155

182-
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
183-
fprintf(stderr, "\n%s : failed to read state\n", __func__);
184-
return 1;
185-
}
186-
187-
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
156+
if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
157+
fprintf(stderr, "\n%s : failed to load state\n", __func__);
158+
return 1;
188159
}
189160

161+
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
162+
190163
// restore state (last tokens)
191-
n_past = n_past_saved;
164+
n_past = n_token_count_out;
165+
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
166+
return 1;
167+
}
168+
++n_past;
192169

193170
// save seq 0 and load into seq 1
194171
{

src/llama-context.cpp

Lines changed: 0 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -2440,64 +2440,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
24402440
// TODO: add more model-specific info which should prevent loading the session file if not identical
24412441
}
24422442

2443-
// write output ids
2444-
{
2445-
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
2446-
2447-
const auto n_outputs = this->n_outputs;
2448-
const auto & output_ids = this->output_ids;
2449-
2450-
std::vector<int32_t> w_output_pos;
2451-
2452-
w_output_pos.resize(n_outputs);
2453-
2454-
// build a more compact representation of the output ids
2455-
for (size_t i = 0; i < n_batch(); ++i) {
2456-
// map an output id to a position in the batch
2457-
int64_t pos = output_ids[i];
2458-
if (pos >= 0) {
2459-
GGML_ASSERT(pos < n_outputs);
2460-
w_output_pos[pos] = i;
2461-
}
2462-
}
2463-
2464-
io.write(&n_outputs, sizeof(n_outputs));
2465-
2466-
if (n_outputs) {
2467-
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
2468-
}
2469-
}
2470-
2471-
// [TAG_CONTEXT_STATE_LOGITS]
2472-
// write logits
2473-
{
2474-
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
2475-
2476-
const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
2477-
2478-
io.write(&logits_size, sizeof(logits_size));
2479-
2480-
if (logits_size) {
2481-
io.write(logits.data, logits_size * sizeof(float));
2482-
}
2483-
}
2484-
2485-
// write embeddings
2486-
{
2487-
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
2488-
2489-
const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
2490-
2491-
io.write(&embd_size, sizeof(embd_size));
2492-
2493-
if (embd_size) {
2494-
io.write(embd.data, embd_size * sizeof(float));
2495-
}
2496-
}
2497-
2498-
// TODO: handle sampling buffers and samplers state ?
2499-
// https://github.com/ggml-org/llama.cpp/pull/17004
2500-
25012443
if (memory != nullptr) {
25022444
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
25032445
memory->state_write(io);
@@ -2523,70 +2465,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
25232465
// TODO: add more info which needs to be identical but which is not verified otherwise
25242466
}
25252467

2526-
// read output ids
2527-
{
2528-
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
2529-
2530-
auto n_outputs = this->n_outputs;
2531-
io.read_to(&n_outputs, sizeof(n_outputs));
2532-
2533-
if (n_outputs > output_reserve(n_outputs)) {
2534-
throw std::runtime_error("could not reserve outputs");
2535-
}
2536-
2537-
std::vector<int32_t> output_pos;
2538-
2539-
if (n_outputs) {
2540-
output_pos.resize(n_outputs);
2541-
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
2542-
2543-
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
2544-
int32_t id = output_pos[i];
2545-
if ((uint32_t) id >= n_batch()) {
2546-
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
2547-
}
2548-
this->output_ids[id] = i;
2549-
}
2550-
2551-
this->n_outputs = n_outputs;
2552-
}
2553-
}
2554-
2555-
// read logits
2556-
{
2557-
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
2558-
2559-
uint64_t logits_size;
2560-
io.read_to(&logits_size, sizeof(logits_size));
2561-
2562-
if (this->logits.size < logits_size) {
2563-
throw std::runtime_error("logits buffer too small");
2564-
}
2565-
2566-
if (logits_size) {
2567-
io.read_to(this->logits.data, logits_size * sizeof(float));
2568-
}
2569-
}
2570-
2571-
// read embeddings
2572-
{
2573-
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
2574-
2575-
uint64_t embd_size;
2576-
io.read_to(&embd_size, sizeof(embd_size));
2577-
2578-
if (this->embd.size < embd_size) {
2579-
throw std::runtime_error("embeddings buffer too small");
2580-
}
2581-
2582-
if (embd_size) {
2583-
io.read_to(this->embd.data, embd_size * sizeof(float));
2584-
}
2585-
}
2586-
2587-
// TODO: handle sampling buffers and samplers state ?
2588-
// https://github.com/ggml-org/llama.cpp/pull/17004
2589-
25902468
if (memory) {
25912469
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
25922470

0 commit comments

Comments
 (0)