2424
2525#define TAG " hips.cpp" // Logcat tag to identify entries from hips.cpp
2626#define LOGi (...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) // Log info message
27+ #define LOGw (...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) // Log warning message
2728#define LOGe (...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) // Log error message
2829
2930/* *
@@ -249,4 +250,47 @@ extern "C" JNIEXPORT jstring JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_de
249250 jstring jString = env -> NewStringUTF (cppString.c_str ());
250251
251252 return jString;
253+ }
254+
255+ /* *
256+ * Function to sample the next token based on the last one.
257+ *
258+ * @param env The JNI environment.
259+ * @param thiz Java object this function was called with.
260+ * @param lastToken ID of the last token.
261+ * @param jCtx Memory address of the context.
262+ * @param jSmpl Memory address of the sampler.
263+ * @return ID of the next token.
264+ */
265+ extern " C" JNIEXPORT jint JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_sample (JNIEnv* env, jobject thiz, jint lastToken, jlong jCtx, jlong jSmpl) {
266+ // Cast memory addresses of context and sampler from Java long to C++ pointers
267+ // Casting the last token ID from jint to llama_token is not necessary since both is just int32_t
268+ auto cppCtx = reinterpret_cast <llama_context*>(jCtx);
269+ auto cppSmpl = reinterpret_cast <llama_sampler*>(jSmpl);
270+
271+ // Create a batch containing only the last token
272+ // TODO
273+ // llama.cpp docs: "NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it"
274+ // But is used like this in https://github.com/ggerganov/llama.cpp/blob/master/examples/simple/simple.cpp
275+ llama_batch batch = llama_batch_get_one (&lastToken, 1 );
276+
277+ // Run decoder to calculate logits for the next token
278+ int32_t decode = llama_decode (cppCtx, batch);
279+
280+ // Log success or error message
281+ if (decode == 0 ) {
282+ LOGi (" Java_org_vonderheidt_hips_utils_LlamaCpp_sample: decode = %d, success" , decode);
283+ }
284+ else if (decode == 1 ) {
285+ LOGw (" Java_org_vonderheidt_hips_utils_LlamaCpp_sample: decode = %d, could not find a KV slot for the batch" , decode);
286+ }
287+ else {
288+ LOGe (" Java_org_vonderheidt_hips_utils_LlamaCpp_sample: decode = %d, error. the KV cache state is restored to the state before this call" , decode);
289+ }
290+
291+ // Sample next token from logits with given sampler and return it
292+ // Again, casting the next token ID is not necessary
293+ llama_token nextToken = llama_sampler_sample (cppSmpl, cppCtx, -1 );
294+
295+ return nextToken;
252296}
0 commit comments