Skip to content

Commit 33f3d1a

Browse files
Add getTopLogits function to Huffman
1 parent 20f8dac commit 33f3d1a

File tree

1 file changed

+19
-0
lines changed
  • app/src/main/java/org/vonderheidt/hips/utils

1 file changed

+19
-0
lines changed

app/src/main/java/org/vonderheidt/hips/utils/Huffman.kt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,23 @@ object Huffman {
3636

3737
return cipherBits
3838
}
39+
40+
/**
41+
* Function to get the top 2^bitsPerToken logits for the last token of the prompt. Keeps track of the corresponding token IDs in a map.
42+
*
43+
* Parameter `bits_per_word` from Stegasuras was renamed to `bitsPerToken`.
44+
*
45+
* @param logits Logits for the last token of the prompt (= last row of logits matrix).
46+
* @param bitsPerToken Number of bits to encode/decode per cover text token (= height of Huffman tree). Determined by Settings object.
47+
* @return Map of top 2^bitsPerToken logits and the corresponding token IDs.
48+
*/
49+
private fun getTopLogits(logits: FloatArray, bitsPerToken: Int = Settings.bitsPerToken): Map<Int, Float> {
50+
val topLogits = logits
51+
.mapIndexed{ token, logit -> token to logit } // Convert to List<Pair<Int, Float>> so token IDs won't get lost
52+
.sortedByDescending { it.second } // Sort pairs descending based on logits
53+
.take(1 shl bitsPerToken) // Take top 2^bitsPerToken pairs
54+
.toMap() // Convert to Map<Int, Float> for Huffman tree (ensures there can't be any duplicate token IDs)
55+
56+
return topLogits
57+
}
3958
}

0 commit comments

Comments
 (0)