Skip to content

Commit 48b432b

Browse files
Add function for greedy sampling
1 parent 1f3b811 commit 48b432b

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

app/src/main/cpp/hips.cpp

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424

2525
#define TAG "hips.cpp" // Logcat tag attached to every log entry emitted from hips.cpp
2626
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) // Log a printf-style message with INFO priority under TAG
27+
#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) // Log a printf-style message with WARN priority under TAG
2728
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) // Log a printf-style message with ERROR priority under TAG
2829

2930
/**
@@ -249,4 +250,47 @@ extern "C" JNIEXPORT jstring JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_de
249250
jstring jString = env -> NewStringUTF(cppString.c_str());
250251

251252
return jString;
253+
}
254+
255+
/**
256+
* Function to sample the next token based on the last one.
257+
*
258+
* @param env The JNI environment.
259+
* @param thiz Java object this function was called with.
260+
* @param lastToken ID of the last token.
261+
* @param jCtx Memory address of the context.
262+
* @param jSmpl Memory address of the sampler.
263+
* @return ID of the next token.
264+
*/
265+
extern "C" JNIEXPORT jint JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_sample(JNIEnv* env, jobject thiz, jint lastToken, jlong jCtx, jlong jSmpl) {
266+
// Cast memory addresses of context and sampler from Java long to C++ pointers
267+
// Casting the last token ID from jint to llama_token is not necessary since both is just int32_t
268+
auto cppCtx = reinterpret_cast<llama_context*>(jCtx);
269+
auto cppSmpl = reinterpret_cast<llama_sampler*>(jSmpl);
270+
271+
// Create a batch containing only the last token
272+
// TODO
273+
// llama.cpp docs: "NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it"
274+
// But is used like this in https://github.com/ggerganov/llama.cpp/blob/master/examples/simple/simple.cpp
275+
llama_batch batch = llama_batch_get_one(&lastToken, 1);
276+
277+
// Run decoder to calculate logits for the next token
278+
int32_t decode = llama_decode(cppCtx, batch);
279+
280+
// Log success or error message
281+
if (decode == 0) {
282+
LOGi("Java_org_vonderheidt_hips_utils_LlamaCpp_sample: decode = %d, success", decode);
283+
}
284+
else if (decode == 1) {
285+
LOGw("Java_org_vonderheidt_hips_utils_LlamaCpp_sample: decode = %d, could not find a KV slot for the batch", decode);
286+
}
287+
else {
288+
LOGe("Java_org_vonderheidt_hips_utils_LlamaCpp_sample: decode = %d, error. the KV cache state is restored to the state before this call", decode);
289+
}
290+
291+
// Sample next token from logits with given sampler and return it
292+
// Again, casting the next token ID is not necessary
293+
llama_token nextToken = llama_sampler_sample(cppSmpl, cppCtx, -1);
294+
295+
return nextToken;
252296
}

app/src/main/java/org/vonderheidt/hips/utils/LlamaCpp.kt

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -181,4 +181,13 @@ object LlamaCpp {
181181
* @return Detokenization as a string.
182182
*/
183183
external fun detokenize(tokens: IntArray, ctx: Long = this.ctx): String
184+
185+
/**
186+
* Wrapper for the `llama_decode` and `llama_sampler_sample` functions of llama.cpp. Decodes the last token and samples the next token based on it.
187+
*
188+
* @param lastToken ID of the last token.
189+
* @param ctx Memory address of the context. Defaults to the `ctx` property of this object.
* @param smpl Memory address of the sampler. Defaults to the `smpl` property of this object.
190+
* @return ID of the next token.
191+
*/
192+
external fun sample(lastToken: Int, ctx: Long = this.ctx, smpl: Long = this.smpl): Int
184193
}

0 commit comments

Comments
 (0)