Commit bb26957

cpp: fix memory bug causing leak when reloading models

1 parent 243605e

File tree

3 files changed: +50 -63 lines

smollm/src/main/cpp/CMakeLists.txt
Lines changed: 1 addition & 1 deletion

@@ -100,11 +100,11 @@ function(build_library target_name)
 target_include_directories(
     ${target_name}
     PUBLIC
+    ${COMMON_DIR}
     ${GGML_DIR}/include
     ${GGML_DIR}/src
     ${GGML_DIR}/src/ggml-cpu
     ${LLAMA_DIR}/include
-    ${COMMON_DIR}
     ${VENDOR_DIR}
 )

smollm/src/main/cpp/LLMInference.cpp
Lines changed: 46 additions & 60 deletions
@@ -1,6 +1,4 @@
 #include "LLMInference.h"
-#include "llama.h"
-#include "gguf.h"
 #include <android/log.h>
 #include <cstring>
 #include <iostream>
@@ -9,20 +7,9 @@
 #define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
 #define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)

-std::vector<llama_token> common_tokenize(
-        const struct llama_vocab * vocab,
-        const std::string & text,
-        bool add_special,
-        bool parse_special = false);
-
-std::string common_token_to_piece(
-        const struct llama_context * ctx,
-        llama_token token,
-        bool special = true);
-
 void
-LLMInference::loadModel(const char* model_path, float minP, float temperature, bool storeChats, long contextSize,
-                        const char* chatTemplate, int nThreads, bool useMmap, bool useMlock) {
+LLMInference::loadModel(const char *model_path, float minP, float temperature, bool storeChats, long contextSize,
+                        const char *chatTemplate, int nThreads, bool useMmap, bool useMlock) {
     LOGi("loading model with"
          "\n\tmodel_path = %s"
          "\n\tminP = %f"
@@ -35,56 +22,57 @@ LLMInference::loadModel(const char* model_path, float minP, float temperature, b
          "\n\tuseMlock = %d",
          model_path, minP, temperature, storeChats, contextSize, chatTemplate, nThreads, useMmap, useMlock);

+    // load dynamic backends
+    ggml_backend_load_all();
+
     // create an instance of llama_model
     llama_model_params model_params = llama_model_default_params();
-    model_params.use_mmap = useMmap;
-    model_params.use_mlock = useMlock;
-    _model = llama_model_load_from_file(model_path, model_params);
-
+    model_params.use_mmap = useMmap;
+    model_params.use_mlock = useMlock;
+    _model = llama_model_load_from_file(model_path, model_params);
     if (!_model) {
         LOGe("failed to load model from %s", model_path);
         throw std::runtime_error("loadModel() failed");
     }

     // create an instance of llama_context
     llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.n_ctx = contextSize;
-    ctx_params.n_threads = nThreads;
-    ctx_params.no_perf = true; // disable performance metrics
-    _ctx = llama_init_from_model(_model, ctx_params);
-
+    ctx_params.n_ctx = contextSize;
+    ctx_params.n_batch = contextSize;
+    ctx_params.n_threads = nThreads;
+    ctx_params.no_perf = true; // disable performance metrics
+    _ctx = llama_init_from_model(_model, ctx_params);
     if (!_ctx) {
         LOGe("llama_new_context_with_model() returned null)");
         throw std::runtime_error("llama_new_context_with_model() returned null");
     }

-    // initialize sampler
+    // create an instance of llama_sampler
     llama_sampler_chain_params sampler_params = llama_sampler_chain_default_params();
-    sampler_params.no_perf = true; // disable performance metrics
-    _sampler = llama_sampler_chain_init(sampler_params);
-
+    sampler_params.no_perf = true; // disable performance metrics
+    _sampler = llama_sampler_chain_init(sampler_params);
     llama_sampler_chain_add(_sampler, llama_sampler_init_temp(temperature));
     llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
-    if (minP >= 0.01f) {
-        // minP = 0.0 (disabled)
-        // minP can be adjusted across 100 steps between [0.0,1.0], the smallest step being 0.01
-        llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(minP, 1));
-    }

     _formattedMessages = std::vector<char>(llama_n_ctx(_ctx));
     _messages.clear();
-    _chatTemplate = strdup(chatTemplate);
+
+    if (chatTemplate == nullptr) {
+        _chatTemplate = llama_model_chat_template(_model, nullptr);
+    } else {
+        _chatTemplate = strdup(chatTemplate);
+    }
     this->_storeChats = storeChats;
 }

 void
-LLMInference::addChatMessage(const char* message, const char* role) {
-    _messages.push_back({ strdup(role), strdup(message) });
+LLMInference::addChatMessage(const char *message, const char *role) {
+    _messages.push_back({strdup(role), strdup(message)});
 }

 float
 LLMInference::getResponseGenerationTime() const {
-    return (float)_responseNumTokens / (_responseGenerationTime / 1e6);
+    return (float) _responseNumTokens / (_responseGenerationTime / 1e6);
 }

 int
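A note on the new ctx_params.n_batch setting above: in llama.cpp, n_batch is the maximum number of tokens that llama_decode() accepts in a single call, so matching it to the context size lets the whole formatted prompt be decoded in one batch. A minimal sketch of that constraint; the guard and the names used here are illustrative, not part of this commit:

    // llama_n_batch() reports the logical batch limit configured via ctx_params.n_batch
    if (_promptTokens.size() > llama_n_batch(_ctx)) {
        // a prompt longer than n_batch would have to be split across several llama_decode() calls
        throw std::runtime_error("prompt exceeds n_batch; decode it in chunks");
    }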
@@ -93,19 +81,19 @@ LLMInference::getContextSizeUsed() const {
 }

 void
-LLMInference::startCompletion(const char* query) {
+LLMInference::startCompletion(const char *query) {
     if (!_storeChats) {
         _prevLen = 0;
         _formattedMessages.clear();
         _formattedMessages = std::vector<char>(llama_n_ctx(_ctx));
     }
     _responseGenerationTime = 0;
-    _responseNumTokens = 0;
+    _responseNumTokens = 0;
     addChatMessage(query, "user");
     // apply the chat-template
     int newLen = llama_chat_apply_template(_chatTemplate, _messages.data(), _messages.size(), true,
                                            _formattedMessages.data(), _formattedMessages.size());
-    if (newLen > (int)_formattedMessages.size()) {
+    if (newLen > (int) _formattedMessages.size()) {
         // resize the output buffer `_formattedMessages`
         // and re-apply the chat template
         _formattedMessages.resize(newLen);
@@ -120,19 +108,20 @@ LLMInference::startCompletion(const char* query) {

     // create a llama_batch containing a single sequence
     // see llama_batch_init for more details
-    _batch.token = _promptTokens.data();
-    _batch.n_tokens = _promptTokens.size();
+    _batch = new llama_batch();
+    _batch->token = _promptTokens.data();
+    _batch->n_tokens = _promptTokens.size();
 }

 // taken from:
 // https://github.com/ggerganov/llama.cpp/blob/master/examples/llama.android/llama/src/main/cpp/llama-android.cpp#L38
 bool
-LLMInference::_isValidUtf8(const char* response) {
+LLMInference::_isValidUtf8(const char *response) {
     if (!response) {
         return true;
     }
-    const unsigned char* bytes = (const unsigned char*)response;
-    int num;
+    const unsigned char *bytes = (const unsigned char *) response;
+    int num;
     while (*bytes != 0x00) {
         if ((*bytes & 0x80) == 0x00) {
             // U+0000 to U+007F
@@ -166,14 +155,14 @@ LLMInference::completionLoop() {
     // check if the length of the inputs to the model
     // have exceeded the context size of the model
     uint32_t contextSize = llama_n_ctx(_ctx);
-    _nCtxUsed = llama_kv_self_used_cells(_ctx);
-    if (_nCtxUsed + _batch.n_tokens > contextSize) {
+    _nCtxUsed = llama_memory_seq_pos_max(llama_get_memory(_ctx), 0) + 1;
+    if (_nCtxUsed + _batch->n_tokens > contextSize) {
         throw std::runtime_error("context size reached");
     }

     auto start = ggml_time_us();
     // run the model
-    if (llama_decode(_ctx, _batch) < 0) {
+    if (llama_decode(_ctx, *_batch) < 0) {
         throw std::runtime_error("llama_decode() failed");
     }
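The context-usage check above replaces llama_kv_self_used_cells() with the newer llama.cpp memory API: llama_memory_seq_pos_max() returns the highest position stored for a sequence, or -1 when the sequence is empty, so adding 1 yields the number of positions already occupied. A standalone sketch of the same check, with illustrative variable names:

    // assuming a single sequence with id 0, as completionLoop() uses
    llama_memory_t mem    = llama_get_memory(ctx);
    int32_t usedPositions = llama_memory_seq_pos_max(mem, /*seq_id=*/0) + 1;  // -1 + 1 == 0 for an empty cache
    if (usedPositions + nNewTokens > (int32_t) llama_n_ctx(ctx)) {
        // the new tokens would not fit into the remaining context window
    }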

@@ -186,7 +175,6 @@ LLMInference::completionLoop() {
         return "[EOG]";
     }
     std::string piece = common_token_to_piece(_ctx, _currToken, true);
-    LOGi("common_token_to_piece: %s", piece.c_str());
     auto end = ggml_time_us();
     _responseGenerationTime += (end - start);
     _responseNumTokens += 1;
@@ -195,8 +183,8 @@ LLMInference::completionLoop() {
     // re-init the batch with the newly predicted token
     // key, value pairs of all previous tokens have been cached
     // in the KV cache
-    _batch.token = &_currToken;
-    _batch.n_tokens = 1;
+    _batch->token = &_currToken;
+    _batch->n_tokens = 1;

     if (_isValidUtf8(_cacheResponseTokens.c_str())) {
         _response += _cacheResponseTokens;
@@ -214,23 +202,21 @@ LLMInference::stopCompletion() {
         addChatMessage(_response.c_str(), "assistant");
     }
     _response.clear();
-    const char* tmpl = llama_model_chat_template(_model, nullptr);
-    _prevLen = llama_chat_apply_template(tmpl, _messages.data(), _messages.size(), false, nullptr, 0);
+    _prevLen = llama_chat_apply_template(_chatTemplate, _messages.data(), _messages.size(), false, nullptr, 0);
     if (_prevLen < 0) {
         throw std::runtime_error("llama_chat_apply_template() in LLMInference::stopCompletion() failed");
     }
 }

 LLMInference::~LLMInference() {
-    LOGi("deallocating LLMInference instance");
     // free memory held by the message text in messages
     // (as we had used strdup() to create a malloc'ed copy)
-    for (llama_chat_message& message : _messages) {
-        free(const_cast<char*>(message.role));
-        free(const_cast<char*>(message.content));
+    for (llama_chat_message &message: _messages) {
+        free(const_cast<char *>(message.role));
+        free(const_cast<char *>(message.content));
     }
-    free(const_cast<char*>(_chatTemplate));
-    llama_sampler_free(_sampler);
     llama_free(_ctx);
     llama_model_free(_model);
-}
+    delete _batch;
+    llama_sampler_free(_sampler);
+}
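With the destructor now also releasing the heap-allocated llama_batch and the sampler chain, a load, run, destroy, reload cycle should no longer leak native memory, which is what the commit title refers to. A minimal usage sketch, assuming LLMInference can be default-constructed and driven directly from C++; the paths and sampling values below are illustrative:

    #include "LLMInference.h"

    int main() {
        auto *llm = new LLMInference();
        llm->loadModel("/path/to/model-a.gguf", /*minP=*/0.05f, /*temperature=*/0.8f,
                       /*storeChats=*/true, /*contextSize=*/4096, /*chatTemplate=*/nullptr,
                       /*nThreads=*/4, /*useMmap=*/true, /*useMlock=*/false);
        // ... startCompletion() / completionLoop() / stopCompletion() ...
        delete llm;   // ~LLMInference() frees the context, model, batch and sampler

        // reloading: a fresh instance shares no native state with the previous one
        llm = new LLMInference();
        llm->loadModel("/path/to/model-b.gguf", 0.05f, 0.8f, true, 4096, nullptr, 4, true, false);
        // ... run another chat ...
        delete llm;
        return 0;
    }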

smollm/src/main/cpp/LLMInference.h
Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
+#pragma once
 #include "llama.h"
-#include <jni.h>
+#include "common.h"
 #include <string>
 #include <vector>

@@ -9,7 +10,7 @@ class LLMInference {
     llama_model* _model;
     llama_sampler* _sampler;
     llama_token _currToken;
-    llama_batch _batch;
+    llama_batch* _batch;

     // container to store user/assistant messages in the chat
     std::vector<llama_chat_message> _messages;
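The header now pulls in "common.h" (made resolvable by the COMMON_DIR include-path change in CMakeLists.txt), which is what lets LLMInference.cpp drop its hand-written forward declarations of common_tokenize() and common_token_to_piece(). A minimal sketch of calling those helpers directly, assuming a llama_vocab and llama_context obtained elsewhere:

    #include "common.h"   // declares common_tokenize() and common_token_to_piece()
    #include "llama.h"
    #include <string>
    #include <vector>

    // tokenize a prompt and convert it back to text, piece by piece
    static std::string roundTrip(const llama_vocab *vocab, llama_context *ctx, const std::string &text) {
        std::vector<llama_token> tokens = common_tokenize(vocab, text, /*add_special=*/true, /*parse_special=*/true);
        std::string out;
        for (llama_token token : tokens) {
            out += common_token_to_piece(ctx, token, /*special=*/true);
        }
        return out;
    }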
