Skip to content

Commit 70ce330

Browse files
Add function to get logits
1 parent 84e9a39 commit 70ce330

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

app/src/main/cpp/hips.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,92 @@ extern "C" JNIEXPORT jstring JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_de
252252
return jString;
253253
}
254254

255+
/**
256+
* Function to calculate the logit matrix (i.e. predictions for every token in the prompt).
257+
*
258+
* Only the last row of the `n_tokens` x `n_vocab` matrix is actually needed as it contains the logits corresponding to the last token of the prompt.
259+
*
260+
* @param env The JNI environment.
261+
* @param thiz Java object this function was called with.
262+
* @param jTokens Token IDs from tokenization of the prompt.
263+
* @param jCtx Memory address of the context.
264+
* @return The logit matrix.
265+
*/
266+
extern "C" JNIEXPORT jobjectArray JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_getLogits(JNIEnv* env, jobject thiz, jintArray jTokens, jlong jCtx) {
267+
// Cast memory addresses of context from Java long to C++ pointer
268+
auto cppCtx = reinterpret_cast<llama_context*>(jCtx);
269+
270+
// Get model the context was created with
271+
// No need to specify cppModel in variable name as there is no jModel
272+
const llama_model* model = llama_get_model(cppCtx);
273+
274+
// Copy token IDs from Java array to C++ array
275+
// Data types jint, jsize and int32_t are all equivalent
276+
jint* cppTokens = env -> GetIntArrayElements(jTokens, nullptr);
277+
278+
// C++ allows accessing illegal array indices and returns garbage values, doesn't throw IndexOutOfBoundsException like Java/Kotlin
279+
// Manually ensure that indices stay within dimensions n_tokens x n_vocab of the logit matrix
280+
jsize n_tokens = env -> GetArrayLength(jTokens);
281+
int32_t n_vocab = llama_n_vocab(model);
282+
283+
// Store tokens to be processed in batch data structure
284+
// llama.cpp example cited below stores multiple tokens from tokenization of the prompt in the first run, single last sampled token in subsequent runs
285+
// TODO
286+
// llama.cpp docs: "NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it"
287+
// But is used like this in https://github.com/ggerganov/llama.cpp/blob/master/examples/simple/simple.cpp
288+
llama_batch batch = llama_batch_get_one(cppTokens, n_tokens);
289+
290+
// Run decoder to calculate logits for the next token
291+
int32_t decode = llama_decode(cppCtx, batch);
292+
293+
// Log success or error message
294+
if (decode == 0) {
295+
LOGi("Java_org_vonderheidt_hips_utils_LlamaCpp_getLogits: decode = %d, success", decode);
296+
}
297+
else if (decode == 1) {
298+
LOGw("Java_org_vonderheidt_hips_utils_LlamaCpp_getLogits: decode = %d, could not find a KV slot for the batch", decode);
299+
}
300+
else {
301+
LOGe("Java_org_vonderheidt_hips_utils_LlamaCpp_getLogits: decode = %d, error. the KV cache state is restored to the state before this call", decode);
302+
}
303+
304+
// Get pointer to the logit matrix
305+
float* cppLogits = llama_get_logits(cppCtx);
306+
307+
// Copy logits from C++ matrix to Java matrix
308+
// Initialize outer Java array as object array with as many components as there are rows
309+
// TODO
310+
// n_rows should be n_tokens, but llama_get_logits consistently returns 1 x n_vocab matrix that generates reasonable text, so likely already is last row of actual logit matrix
311+
// Possibly related to dimension of batch when using llama_batch_get_one
312+
jsize n_rows = 1;
313+
jsize n_columns = n_vocab;
314+
315+
jobjectArray jLogits = env -> NewObjectArray(n_rows, env -> FindClass("[F"), nullptr);
316+
317+
// Fill outer array with inner arrays
318+
for (jsize i = 0; i < n_rows; i++) {
319+
// Initialize inner Java arrays as float arrays with as many components as there are columns
320+
jfloatArray row = env -> NewFloatArray(n_columns);
321+
322+
// Float array for row i contains the n_columns elements from cppLogits matrix that start at offset i * n_columns
323+
env -> SetFloatArrayRegion(row, 0, n_columns, cppLogits + (i * n_columns));
324+
325+
// Put float array as row i into outer Java array
326+
env -> SetObjectArrayElement(jLogits, i, row);
327+
328+
// Free local reference to float array from memory
329+
env -> DeleteLocalRef(row);
330+
}
331+
332+
// Free C++ array of token IDs from memory
333+
env -> ReleaseIntArrayElements(jTokens, cppTokens, 0);
334+
335+
// llama.cpp example cited above doesn't use llama_batch_free(batch) to free the batch from memory, actually causes crash if done here
336+
// Does only free model, context and sampler after text generation finishes, but those are managed by the LlamaCpp Kotlin object
337+
338+
return jLogits;
339+
}
340+
255341
/**
256342
* Function to sample the next token based on the last one.
257343
*

app/src/main/java/org/vonderheidt/hips/utils/LlamaCpp.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,17 @@ object LlamaCpp {
182182
*/
183183
external fun detokenize(tokens: IntArray, ctx: Long = this.ctx): String
184184

185+
/**
 * Wrapper for the `llama_get_logits` function of llama.cpp. Calculates the logit matrix (i.e. predictions for every token in the prompt).
 *
 * The native implementation first runs `llama_decode` on the given tokens, then copies the logits into a Java matrix.
 *
 * Only the last row of the `n_tokens` x `n_vocab` matrix is actually needed as it contains the logits corresponding to the last token of the prompt.
 *
 * @param tokens Token IDs from tokenization of the prompt.
 * @param ctx Memory address of the context. Defaults to the context held by this object.
 * @return The logit matrix (currently a single `1 x n_vocab` row — see native-side note).
 */
external fun getLogits(tokens: IntArray, ctx: Long = this.ctx): Array<FloatArray>
195+
185196
/**
186197
* Wrapper for the `llama_sampler_sample` function of llama.cpp. Samples the next token based on the last one.
187198
*

0 commit comments

Comments
 (0)