#include "LLMInference.h"
-#include "llama.h"
-#include "gguf.h"
#include <android/log.h>
#include <cstring>
#include <iostream>
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)

-std::vector<llama_token> common_tokenize(
-    const struct llama_vocab* vocab,
-    const std::string& text,
-    bool add_special,
-    bool parse_special = false);
-
-std::string common_token_to_piece(
-    const struct llama_context* ctx,
-    llama_token token,
-    bool special = true);
-
void
-LLMInference::loadModel(const char* model_path, float minP, float temperature, bool storeChats, long contextSize,
-                        const char* chatTemplate, int nThreads, bool useMmap, bool useMlock) {
+LLMInference::loadModel(const char* model_path, float minP, float temperature, bool storeChats, long contextSize,
+                        const char* chatTemplate, int nThreads, bool useMmap, bool useMlock) {
    LOGi("loading model with"
         "\n\tmodel_path = %s"
         "\n\tminP = %f"
@@ -35,56 +22,57 @@ LLMInference::loadModel(const char* model_path, float minP, float temperature, b
         "\n\tuseMlock = %d",
         model_path, minP, temperature, storeChats, contextSize, chatTemplate, nThreads, useMmap, useMlock);

+    // load dynamic backends
+    ggml_backend_load_all();
+
    // create an instance of llama_model
    llama_model_params model_params = llama_model_default_params();
-    model_params.use_mmap = useMmap;
-    model_params.use_mlock = useMlock;
-    _model = llama_model_load_from_file(model_path, model_params);
-
+    model_params.use_mmap = useMmap;
+    model_params.use_mlock = useMlock;
+    _model = llama_model_load_from_file(model_path, model_params);
    if (!_model) {
        LOGe("failed to load model from %s", model_path);
        throw std::runtime_error("loadModel() failed");
    }

    // create an instance of llama_context
    llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.n_ctx = contextSize;
-    ctx_params.n_threads = nThreads;
-    ctx_params.no_perf = true; // disable performance metrics
-    _ctx = llama_init_from_model(_model, ctx_params);
-
+    ctx_params.n_ctx = contextSize;
+    ctx_params.n_batch = contextSize;
+    ctx_params.n_threads = nThreads;
+    ctx_params.no_perf = true; // disable performance metrics
+    _ctx = llama_init_from_model(_model, ctx_params);
    if (!_ctx) {
        LOGe("llama_new_context_with_model() returned null");
        throw std::runtime_error("llama_new_context_with_model() returned null");
    }

-    // initialize sampler
+    // create an instance of llama_sampler
    llama_sampler_chain_params sampler_params = llama_sampler_chain_default_params();
-    sampler_params.no_perf = true; // disable performance metrics
-    _sampler = llama_sampler_chain_init(sampler_params);
-
+    sampler_params.no_perf = true; // disable performance metrics
+    _sampler = llama_sampler_chain_init(sampler_params);
    llama_sampler_chain_add(_sampler, llama_sampler_init_temp(temperature));
    llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
-    if (minP >= 0.01f) {
-        // minP = 0.0 (disabled)
-        // minP can be adjusted across 100 steps between [0.0,1.0], the smallest step being 0.01
-        llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(minP, 1));
-    }

    _formattedMessages = std::vector<char>(llama_n_ctx(_ctx));
    _messages.clear();
-    _chatTemplate = strdup(chatTemplate);
+
+    if (chatTemplate == nullptr) {
+        _chatTemplate = llama_model_chat_template(_model, nullptr);
+    } else {
+        _chatTemplate = strdup(chatTemplate);
+    }
    this->_storeChats = storeChats;
}

void
-LLMInference::addChatMessage(const char* message, const char* role) {
-    _messages.push_back({ strdup(role), strdup(message) });
+LLMInference::addChatMessage(const char* message, const char* role) {
+    _messages.push_back({strdup(role), strdup(message)});
}

float
LLMInference::getResponseGenerationTime() const {
-    return (float)_responseNumTokens / (_responseGenerationTime / 1e6);
+    return (float) _responseNumTokens / (_responseGenerationTime / 1e6);
}

int
@@ -93,19 +81,19 @@ LLMInference::getContextSizeUsed() const {
}

void
-LLMInference::startCompletion(const char* query) {
+LLMInference::startCompletion(const char* query) {
    if (!_storeChats) {
        _prevLen = 0;
        _formattedMessages.clear();
        _formattedMessages = std::vector<char>(llama_n_ctx(_ctx));
    }
    _responseGenerationTime = 0;
-    _responseNumTokens = 0;
+    _responseNumTokens = 0;
    addChatMessage(query, "user");
    // apply the chat-template
    int newLen = llama_chat_apply_template(_chatTemplate, _messages.data(), _messages.size(), true,
                                           _formattedMessages.data(), _formattedMessages.size());
-    if (newLen > (int)_formattedMessages.size()) {
+    if (newLen > (int) _formattedMessages.size()) {
        // resize the output buffer `_formattedMessages`
        // and re-apply the chat template
        _formattedMessages.resize(newLen);
@@ -120,19 +108,20 @@ LLMInference::startCompletion(const char* query) {

    // create a llama_batch containing a single sequence
    // see llama_batch_init for more details
-    _batch.token = _promptTokens.data();
-    _batch.n_tokens = _promptTokens.size();
+    _batch = new llama_batch();
+    _batch->token = _promptTokens.data();
+    _batch->n_tokens = _promptTokens.size();
}

// taken from:
// https://github.com/ggerganov/llama.cpp/blob/master/examples/llama.android/llama/src/main/cpp/llama-android.cpp#L38
bool
-LLMInference::_isValidUtf8(const char* response) {
+LLMInference::_isValidUtf8(const char* response) {
    if (!response) {
        return true;
    }
-    const unsigned char* bytes = (const unsigned char*) response;
-    int num;
+    const unsigned char* bytes = (const unsigned char*) response;
+    int num;
    while (*bytes != 0x00) {
        if ((*bytes & 0x80) == 0x00) {
            // U+0000 to U+007F
@@ -166,14 +155,14 @@ LLMInference::completionLoop() {
    // check if the length of the inputs to the model
    // have exceeded the context size of the model
    uint32_t contextSize = llama_n_ctx(_ctx);
-    _nCtxUsed = llama_kv_self_used_cells(_ctx);
-    if (_nCtxUsed + _batch.n_tokens > contextSize) {
+    _nCtxUsed = llama_memory_seq_pos_max(llama_get_memory(_ctx), 0) + 1;
+    if (_nCtxUsed + _batch->n_tokens > contextSize) {
        throw std::runtime_error("context size reached");
    }

    auto start = ggml_time_us();
    // run the model
-    if (llama_decode(_ctx, _batch) < 0) {
+    if (llama_decode(_ctx, *_batch) < 0) {
        throw std::runtime_error("llama_decode() failed");
    }

@@ -186,7 +175,6 @@
        return "[EOG]";
    }
    std::string piece = common_token_to_piece(_ctx, _currToken, true);
-    LOGi("common_token_to_piece: %s", piece.c_str());
    auto end = ggml_time_us();
    _responseGenerationTime += (end - start);
    _responseNumTokens += 1;
@@ -195,8 +183,8 @@
    // re-init the batch with the newly predicted token
    // key, value pairs of all previous tokens have been cached
    // in the KV cache
-    _batch.token = &_currToken;
-    _batch.n_tokens = 1;
+    _batch->token = &_currToken;
+    _batch->n_tokens = 1;

    if (_isValidUtf8(_cacheResponseTokens.c_str())) {
        _response += _cacheResponseTokens;
@@ -214,23 +202,21 @@ LLMInference::stopCompletion() {
        addChatMessage(_response.c_str(), "assistant");
    }
    _response.clear();
-    const char* tmpl = llama_model_chat_template(_model, nullptr);
-    _prevLen = llama_chat_apply_template(tmpl, _messages.data(), _messages.size(), false, nullptr, 0);
+    _prevLen = llama_chat_apply_template(_chatTemplate, _messages.data(), _messages.size(), false, nullptr, 0);
    if (_prevLen < 0) {
        throw std::runtime_error("llama_chat_apply_template() in LLMInference::stopCompletion() failed");
    }
}

LLMInference::~LLMInference() {
-    LOGi("deallocating LLMInference instance");
    // free memory held by the message text in messages
    // (as we had used strdup() to create a malloc'ed copy)
-    for (llama_chat_message& message : _messages) {
-        free(const_cast<char*>(message.role));
-        free(const_cast<char*>(message.content));
+    for (llama_chat_message& message: _messages) {
+        free(const_cast<char*>(message.role));
+        free(const_cast<char*>(message.content));
    }
-    free(const_cast<char*>(_chatTemplate));
-    llama_sampler_free(_sampler);
    llama_free(_ctx);
    llama_model_free(_model);
-}
+    delete _batch;
+    llama_sampler_free(_sampler);
+}
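
For reference, a minimal usage sketch of the updated API follows. It assumes the LLMInference header declares these methods as they appear in the diff, that completionLoop() returns the next decoded piece as a std::string with "[EOG]" as the end-of-generation sentinel, and that runSingleTurn and its arguments are illustrative placeholders rather than part of the actual project.

#include "LLMInference.h"
#include <string>

// Drives one user -> assistant turn. Everything here other than the LLMInference
// methods themselves is a hypothetical name used only for illustration.
std::string runSingleTurn(LLMInference& llm, const std::string& query) {
    std::string answer;
    llm.startCompletion(query.c_str());           // applies the chat template and fills _batch with the prompt tokens
    while (true) {
        std::string piece = llm.completionLoop(); // decodes one token per call
        if (piece == "[EOG]") {                   // end-of-generation sentinel returned in the diff
            break;
        }
        answer += piece;                          // stream or accumulate the partial response
    }
    llm.stopCompletion();                         // records the assistant message when storeChats is enabled
    return answer;
}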