Commit a877db4

Add functions to suppress special tokens
1 parent e7907f2 commit a877db4

3 files changed, +61 -0 lines changed

app/src/main/cpp/hips.cpp

Lines changed: 27 additions & 0 deletions
@@ -252,6 +252,33 @@ extern "C" JNIEXPORT jstring JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_de
     return jString;
 }
 
+/**
+ * Function to check if a token is a special token.
+ *
+ * @param env The JNI environment.
+ * @param thiz Java object this function was called with.
+ * @param token Token ID to check.
+ * @param jCtx Memory address of the context.
+ * @return Boolean that is true if the token is special, false otherwise.
+ */
+extern "C" JNIEXPORT jboolean JNICALL Java_org_vonderheidt_hips_utils_LlamaCpp_isSpecial(JNIEnv* env, jobject thiz, jint token, jlong jCtx) {
+    // Cast memory address of the context from Java long to C++ pointer
+    auto cppCtx = reinterpret_cast<llama_context*>(jCtx);
+
+    // Get model the context was created with
+    const llama_model* model = llama_get_model(cppCtx);
+
+    // Check if token is special
+    // Token ID doesn't need casting because jint and llama_token are both just int32_t
+    bool cppIsSpecial = llama_token_is_eog(model, token) || llama_token_is_control(model, token);
+
+    // Cast boolean to return it
+    // static_cast because casting booleans is type safe, unlike reinterpret_cast for casting C++ pointers to Java long
+    auto jIsSpecial = static_cast<jboolean>(cppIsSpecial);
+
+    return jIsSpecial;
+}
+
 /**
  * Function to calculate the logit matrix (i.e. predictions for every token in the prompt).
  *
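Note: the JNI symbol name above encodes the Kotlin binding. `Java_org_vonderheidt_hips_utils_LlamaCpp_isSpecial` resolves to an `isSpecial` member of the `LlamaCpp` object in package `org.vonderheidt.hips.utils`, with `jint`, `jlong` and `jboolean` mapping to Kotlin's `Int`, `Long` and `Boolean`. The matching declaration is the one added in LlamaCpp.kt further down in this commit:

    // Kotlin-side declaration the native symbol resolves to (see LlamaCpp.kt below)
    private external fun isSpecial(token: Int, ctx: Long = this.ctx): Boolean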

app/src/main/java/org/vonderheidt/hips/utils/Huffman.kt

Lines changed: 6 additions & 0 deletions
@@ -41,6 +41,9 @@ object Huffman {
         // Only last row of logit matrix is needed as it contains logits corresponding to last token of the prompt
         val logits = LlamaCpp.getLogits(if (isFirstRun) contextTokens else intArrayOf(sampledToken)).last()
 
+        // Suppress special tokens to avoid early termination before all bits of secret message are encoded
+        LlamaCpp.suppressSpecialTokens(logits)
+
         // Get top 2^bitsPerToken logits for last token of prompt (= height of Huffman tree)
         val topLogits = getTopLogits(logits)

@@ -124,6 +127,9 @@ object Huffman {
         // Calculate the logit matrix again initially from context tokens, then from last cover text token, and get last row
         val logits = LlamaCpp.getLogits(if (isFirstRun) contextTokens else intArrayOf(coverTextToken)).last()
 
+        // Suppress special tokens
+        LlamaCpp.suppressSpecialTokens(logits)
+
         // Get top 2^bitsPerToken logits
         val topLogits = getTopLogits(logits)
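In both call sites the suppression runs before `getTopLogits`, so a token whose logit was pushed down to -100f effectively cannot appear among the top 2^bitsPerToken candidates that form the Huffman tree. The repository's actual `getTopLogits` is not part of this diff; a minimal hypothetical sketch of such a selection, just to illustrate the interaction, could look like this:

    // Hypothetical sketch, not the repository's implementation:
    // pick the 2^bitsPerToken highest logits together with their token IDs.
    fun topLogitsSketch(logits: FloatArray, bitsPerToken: Int): List<Pair<Int, Float>> {
        val k = 1 shl bitsPerToken                  // 2^bitsPerToken candidates
        return logits.withIndex()
            .sortedByDescending { it.value }        // highest logits first
            .take(k)                                // tokens suppressed to -100f effectively never make the cut
            .map { it.index to it.value }           // token ID paired with its logit
    }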

app/src/main/java/org/vonderheidt/hips/utils/LlamaCpp.kt

Lines changed: 28 additions & 0 deletions
@@ -95,6 +95,25 @@ object LlamaCpp {
         }
     }
 
+    /**
+     * Function to suppress special tokens, i.e. eog (end-of-generation) and control tokens.
+     *
+     * Suppressing eog tokens is needed to avoid early termination when generating a cover text.
+     *
+     * Additionally, suppressing control tokens is beneficial because the cover text then can't contain any invisible tokens.
+     * This ensures integrity when using a non-digital communication medium.
+     *
+     * @param logits Logits for the last token of the prompt (= last row of the logit matrix).
+     */
+    fun suppressSpecialTokens(logits: FloatArray) {
+        // Suppress special tokens by setting their logits to negative values
+        for (token in logits.indices) {
+            if (isSpecial(token)) {
+                logits[token] = -100f
+            }
+        }
+    }
+
     /**
      * Function to check if a token is the end of a sentence. Needed to complete the last sentence of the cover text.
      *

@@ -193,6 +212,15 @@ object LlamaCpp {
      */
     external fun getLogits(tokens: IntArray, ctx: Long = this.ctx): Array<FloatArray>
 
+    /**
+     * Wrapper for the `llama_token_is_eog` and `llama_token_is_control` functions of llama.cpp. Checks if a token is a special token.
+     *
+     * @param token Token ID to check.
+     * @param ctx Memory address of the context.
+     * @return Boolean that is true if the token is special, false otherwise.
+     */
+    private external fun isSpecial(token: Int, ctx: Long = this.ctx): Boolean
+
     /**
      * Wrapper for the `llama_sampler_sample` function of llama.cpp. Samples the next token based on the last one.
      *
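Setting a logit to -100f is, for practical purposes, the same as setting it to minus infinity: after softmax, the suppressed token's probability is on the order of e^(-100 - max logit), which is negligible next to ordinary logits. A minimal standalone sketch (not part of the commit) illustrating this:

    import kotlin.math.exp

    // Minimal sketch, not part of the commit: why a logit of -100f acts like minus infinity after softmax.
    fun softmax(logits: FloatArray): FloatArray {
        val max = logits.maxOrNull() ?: 0f                                  // subtract max for numerical stability
        val exps = DoubleArray(logits.size) { exp((logits[it] - max).toDouble()) }
        val sum = exps.sum()
        return FloatArray(logits.size) { (exps[it] / sum).toFloat() }
    }

    fun main() {
        val logits = floatArrayOf(2.0f, 1.5f, -100f, 0.5f)                  // third entry suppressed
        println(softmax(logits).toList())                                   // suppressed entry comes out as effectively 0.0
    }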
