@@ -252,6 +252,92 @@ extern "C" JNIEXPORT jstring JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_de
252252 return jString;
253253}
254254
255+ /* *
256+ * Function to calculate the logit matrix (i.e. predictions for every token in the prompt).
257+ *
258+ * Only the last row of the `n_tokens` x `n_vocab` matrix is actually needed as it contains the logits corresponding to the last token of the prompt.
259+ *
260+ * @param env The JNI environment.
261+ * @param thiz Java object this function was called with.
262+ * @param jTokens Token IDs from tokenization of the prompt.
263+ * @param jCtx Memory address of the context.
264+ * @return The logit matrix.
265+ */
266+ extern " C" JNIEXPORT jobjectArray JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_getLogits (JNIEnv* env, jobject thiz, jintArray jTokens, jlong jCtx) {
267+ // Cast memory addresses of context from Java long to C++ pointer
268+ auto cppCtx = reinterpret_cast <llama_context*>(jCtx);
269+
270+ // Get model the context was created with
271+ // No need to specify cppModel in variable name as there is no jModel
272+ const llama_model* model = llama_get_model (cppCtx);
273+
274+ // Copy token IDs from Java array to C++ array
275+ // Data types jint, jsize and int32_t are all equivalent
276+ jint* cppTokens = env -> GetIntArrayElements (jTokens, nullptr );
277+
278+ // C++ allows accessing illegal array indices and returns garbage values, doesn't throw IndexOutOfBoundsException like Java/Kotlin
279+ // Manually ensure that indices stay within dimensions n_tokens x n_vocab of the logit matrix
280+ jsize n_tokens = env -> GetArrayLength (jTokens);
281+ int32_t n_vocab = llama_n_vocab (model);
282+
283+ // Store tokens to be processed in batch data structure
284+ // llama.cpp example cited below stores multiple tokens from tokenization of the prompt in the first run, single last sampled token in subsequent runs
285+ // TODO
286+ // llama.cpp docs: "NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it"
287+ // But is used like this in https://github.com/ggerganov/llama.cpp/blob/master/examples/simple/simple.cpp
288+ llama_batch batch = llama_batch_get_one (cppTokens, n_tokens);
289+
290+ // Run decoder to calculate logits for the next token
291+ int32_t decode = llama_decode (cppCtx, batch);
292+
293+ // Log success or error message
294+ if (decode == 0 ) {
295+ LOGi (" Java_org_vonderheidt_hips_utils_LlamaCpp_getLogits: decode = %d, success" , decode);
296+ }
297+ else if (decode == 1 ) {
298+ LOGw (" Java_org_vonderheidt_hips_utils_LlamaCpp_getLogits: decode = %d, could not find a KV slot for the batch" , decode);
299+ }
300+ else {
301+ LOGe (" Java_org_vonderheidt_hips_utils_LlamaCpp_getLogits: decode = %d, error. the KV cache state is restored to the state before this call" , decode);
302+ }
303+
304+ // Get pointer to the logit matrix
305+ float * cppLogits = llama_get_logits (cppCtx);
306+
307+ // Copy logits from C++ matrix to Java matrix
308+ // Initialize outer Java array as object array with as many components as there are rows
309+ // TODO
310+ // n_rows should be n_tokens, but llama_get_logits consistently returns 1 x n_vocab matrix that generates reasonable text, so likely already is last row of actual logit matrix
311+ // Possibly related to dimension of batch when using llama_batch_get_one
312+ jsize n_rows = 1 ;
313+ jsize n_columns = n_vocab;
314+
315+ jobjectArray jLogits = env -> NewObjectArray (n_rows, env -> FindClass (" [F" ), nullptr );
316+
317+ // Fill outer array with inner arrays
318+ for (jsize i = 0 ; i < n_rows; i++) {
319+ // Initialize inner Java arrays as float arrays with as many components as there are columns
320+ jfloatArray row = env -> NewFloatArray (n_columns);
321+
322+ // Float array for row i contains the n_columns elements from cppLogits matrix that start at offset i * n_columns
323+ env -> SetFloatArrayRegion (row, 0 , n_columns, cppLogits + (i * n_columns));
324+
325+ // Put float array as row i into outer Java array
326+ env -> SetObjectArrayElement (jLogits, i, row);
327+
328+ // Free local reference to float array from memory
329+ env -> DeleteLocalRef (row);
330+ }
331+
332+ // Free C++ array of token IDs from memory
333+ env -> ReleaseIntArrayElements (jTokens, cppTokens, 0 );
334+
335+ // llama.cpp example cited above doesn't use llama_batch_free(batch) to free the batch from memory, actually causes crash if done here
336+ // Does only free model, context and sampler after text generation finishes, but those are managed by the LlamaCpp Kotlin object
337+
338+ return jLogits;
339+ }
340+
255341/* *
256342 * Function to sample the next token based on the last one.
257343 *
0 commit comments