Status: Complete (3 of 4 optimizations; encoder caching deferred)
Priority: High (latency-sensitive)
Last Updated: 2026-01-26
This document outlines optimization opportunities for the ONNX neural network inference pipeline used in swipe typing prediction. The target is <100ms end-to-end latency from swipe gesture completion to word prediction display.
Swipe Gesture → SwipeTrajectoryProcessor → EncoderWrapper → BeamSearchEngine → Word Predictions
                         ↓                       ↓                  ↓
                 Feature Extraction         ONNX Encoder      ONNX Decoder (per step)
                     (~5-10ms)               (~15-25ms)        (~10-20ms × N steps)
| File | Purpose |
|---|---|
| onnx/BeamSearchEngine.kt | Beam search decoding with decoder calls |
| onnx/EncoderWrapper.kt | Encoder model inference |
| onnx/DecoderWrapper.kt | Decoder model inference |
| onnx/TensorFactory.kt | Tensor creation utilities |
| onnx/SessionConfigurator.kt | ONNX Runtime session options |
| onnx/SwipePredictorOrchestrator.kt | Pipeline coordinator |
| SwipeTrajectoryProcessor.kt | Feature extraction pipeline |
Problem: Creating new OnnxTensor objects for every beam search step causes GC pressure and allocation overhead.
Current Code (BeamSearchEngine.kt:246-258):
// Created ONCE per step (good):
val actualSrcLengthTensor = OnnxTensor.createTensor(ortEnvironment, intArrayOf(actualSrcLength))
// Created PER BEAM (bad - 5 beams × 10 steps = 50 allocations):
for (beam in activeBeams) {
val tgtTokens = IntArray(DECODER_SEQ_LEN) { PAD_IDX } // NEW allocation
// ... populate tokens ...
val targetTokensTensor = OnnxTensor.createTensor(ortEnvironment,
java.nio.IntBuffer.wrap(tgtTokens), longArrayOf(1, DECODER_SEQ_LEN.toLong()))
}

Impact: ~50 tensor allocations per prediction (5 beams × 10 steps)
Problem: Sequential processing of beams makes N separate ONNX inference calls per step.
Current Code (BeamSearchEngine.kt:139-140):
// SEQUENTIAL PROCESSING (current)
val nextBeams = processSequential(activeBeams, memory, actualSrcLength, step)

Existing Infrastructure (DecoderWrapper.kt:107-168):
// ALREADY EXISTS but NOT USED by BeamSearchEngine!
fun decodeBatched(
memory: OnnxTensor,
beamTokens: List<List<Long>>,
actualSrcLength: Int,
decoderSeqLength: Int,
step: Int = 0
): DecoderResult

Impact: 5 beams × 10 steps = 50 decoder calls → 1 call per step = 10 decoder calls (5x reduction)
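For illustration, a hedged call-site sketch of wiring beam search to this existing API. The BeamState token shape and the decoderWrapper reference are assumptions (Option A in the plan below introduces such a reference), but the decodeBatched signature matches the one shown above:

```kotlin
// Sketch only: assumes BeamState.tokens is List<Long> and that the engine
// holds a DecoderWrapper reference (see Option A below).
val beamTokens: List<List<Long>> = activeBeams.map { it.tokens }
val result = decoderWrapper.decodeBatched(
    memory = memory,                    // encoder output for this prediction
    beamTokens = beamTokens,            // one token sequence per active beam
    actualSrcLength = actualSrcLength,
    decoderSeqLength = DECODER_SEQ_LEN,
    step = step
)
// One ONNX call now produces logits for every beam in this step.
```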
Problem: Encoder output (memory tensor) is computed fresh for every prediction, even for similar swipe patterns.
Current State: No caching between predictions.
Potential: Similar swipe patterns (same start/end keys, similar trajectory shape) could reuse encoder output.
Problem: Fixed thread count (4) may not be optimal for all devices.
Current Code (SessionConfigurator.kt:84-90):
// HARDCODED to 4 threads
xnnOptions["intra_op_num_threads"] = "4"
sessionOptions.setIntraOpNumThreads(4)

Solution: Make the thread count configurable via user settings (see Section 4.4).
Goal: Cache the actual_src_length tensor across beam search steps.
Background: The initial implementation included pre-allocated IntArray buffers and Direct ByteBuffers for target tokens. Performance analysis revealed that the JVM is highly optimized for small, short-lived allocations (80-byte arrays), and the overhead of Direct ByteBuffer management outweighed the savings.
Final Implementation - only cachedSrcLengthTensor is reused:
class BeamSearchEngine(...) {
// Cached actualSrcLength tensor (recreated only when length changes)
// Saves ~15 OnnxTensor creations per prediction (one per step -> one per search)
private var cachedSrcLength: Int = -1
private var cachedSrcLengthTensor: OnnxTensor? = null
fun search(...): List<BeamSearchCandidate> {
try {
// Main decoding loop...
} finally {
cleanup() // Release native memory
}
}
private fun processSequential(...) {
// OPTIMIZATION: Reuse actualSrcLengthTensor if length unchanged
if (actualSrcLength != cachedSrcLength) {
cachedSrcLengthTensor?.close()
cachedSrcLengthTensor = OnnxTensor.createTensor(ortEnvironment, intArrayOf(actualSrcLength))
cachedSrcLength = actualSrcLength
}
for (beam in activeBeams) {
// Simple allocation - JVM optimized for small arrays
val tgtTokens = IntArray(DECODER_SEQ_LEN) { PAD_IDX }
// ... populate and create tensor
}
}
fun cleanup() {
cachedSrcLengthTensor?.close()
cachedSrcLengthTensor = null
cachedSrcLength = -1
}
}

Why only srcLength caching?
- actual_src_length is a native OnnxTensor wrapping a single int
- It is created ~15 times per prediction (once per decoding step)
- Native tensor allocation has higher overhead than JVM arrays
- Target token IntArrays (80 bytes) are efficiently allocated by the JVM
Estimated Impact: ~15 native tensor allocations saved per prediction
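To sanity-check the claim that native tensor creation dominates small JVM allocations, a rough micro-benchmark sketch; the loop count and array size are illustrative assumptions, not measured values:

```kotlin
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment

// Rough methodology sketch: compare per-op cost of a small JVM array
// allocation against creating (and closing) a single-int OnnxTensor.
fun compareAllocationCosts(env: OrtEnvironment, iterations: Int = 10_000) {
    var t0 = System.nanoTime()
    repeat(iterations) {
        val arr = IntArray(20) // ~80-byte array, like the target-token buffer
        arr[0] = it            // keep the allocation observable
    }
    val jvmNs = System.nanoTime() - t0

    t0 = System.nanoTime()
    repeat(iterations) {
        OnnxTensor.createTensor(env, intArrayOf(it)).close() // native alloc + free
    }
    val nativeNs = System.nanoTime() - t0

    println("IntArray: ${jvmNs / iterations} ns/op, OnnxTensor: ${nativeNs / iterations} ns/op")
}
```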
Goal: Use existing DecoderWrapper.decodeBatched() instead of sequential processing.
Changes to BeamSearchEngine.kt:
Option A: Refactor to use DecoderWrapper (Cleaner, but requires constructor change)
class BeamSearchEngine(
private val decoderWrapper: DecoderWrapper, // CHANGED: Use wrapper instead of raw session
private val ortEnvironment: OrtEnvironment,
// ... rest of params ...
) {
// Remove: private val decoderSession: OrtSession
}

Option B: Add batched processing alongside existing code (Less invasive)
private fun processBatched(
activeBeams: List<BeamState>,
memory: OnnxTensor,
actualSrcLength: Int,
step: Int
): List<BeamState> {
val newCandidates = ArrayList<BeamState>()
val numBeams = activeBeams.size
// 1. Prepare batched input tensor [numBeams, DECODER_SEQ_LEN]
val batchedTokens = IntArray(numBeams * DECODER_SEQ_LEN)
for (b in activeBeams.indices) {
val beam = activeBeams[b]
val len = min(beam.tokens.size, DECODER_SEQ_LEN)
for (i in 0 until len) {
batchedTokens[b * DECODER_SEQ_LEN + i] = beam.tokens[i].toInt()
}
// Remaining slots stay 0; assumes PAD_IDX == 0 (otherwise pre-fill with PAD_IDX)
}
val batchedShape = longArrayOf(numBeams.toLong(), DECODER_SEQ_LEN.toLong())
val batchedTokensTensor = OnnxTensor.createTensor(
ortEnvironment,
java.nio.IntBuffer.wrap(batchedTokens),
batchedShape
)
// 2. Create src_length tensor (broadcast model uses single value)
val srcLengthTensor = OnnxTensor.createTensor(ortEnvironment, intArrayOf(actualSrcLength))
try {
// 3. Single decoder call for ALL beams
val inputs = mapOf(
"memory" to memory,
"actual_src_length" to srcLengthTensor,
"target_tokens" to batchedTokensTensor
)
val result = decoderSession.run(inputs)
val logitsTensor = result.get(0) as OnnxTensor
// 4. Extract logits [numBeams, DECODER_SEQ_LEN, vocabSize]
@Suppress("UNCHECKED_CAST")
val logits3D = logitsTensor.value as Array<Array<FloatArray>>
// 5. Process each beam's logits
for (b in activeBeams.indices) {
val beam = activeBeams[b]
val currentPos = beam.tokens.size - 1
if (currentPos in 0 until DECODER_SEQ_LEN) {
val logits = logits3D[b][currentPos]
// Apply Trie Masking, Prefix Boosts, etc. (same as sequential)
applyTrieMasking(beam, logits)
val appliedBoosts = applyPrefixBoosts(beam, logits)
val logProbs = logSoftmax(logits)
val topIndices = getTopKIndices(logProbs, beamWidth)
// Create new beam candidates (same as sequential)
for (idx in topIndices) {
// ... same logic as processSequential ...
}
}
}
result.close()
} finally {
batchedTokensTensor.close()
srcLengthTensor.close()
}
return newCandidates
}

Modify search() to use batched mode:
fun search(memory: OnnxTensor, actualSrcLength: Int, useBatched: Boolean = false): List<BeamSearchCandidate> {
// ... existing setup ...
while (step < maxLength) {
// ... existing beam filtering ...
try {
val startInf = System.nanoTime()
// CHANGED: Use batched processing when enabled
val nextBeams = if (useBatched && activeBeams.size > 1) {
processBatched(activeBeams, memory, actualSrcLength, step)
} else {
processSequential(activeBeams, memory, actualSrcLength, step)
}
candidates.addAll(nextBeams)
// ... rest unchanged ...
}
}
}

Estimated Impact: 5x reduction in decoder calls (50 → 10 per prediction)
Goal: Cache encoder output for similar swipe trajectories.
Design Considerations:
- Key: Hash of trajectory fingerprint (start/end keys, sampled waypoints, length)
- Value: Encoder memory tensor
- Size: LRU cache with ~10 entries (encoder output is ~200KB per entry)
- Invalidation: TTL-based (5 minutes) or explicit clear on keyboard resize
New class: EncoderCache.kt:
package tribixbite.cleverkeys.onnx
import ai.onnxruntime.OnnxTensor
import android.util.LruCache
import tribixbite.cleverkeys.SwipeTrajectoryProcessor
/**
* LRU cache for encoder memory tensors.
*
* Caches encoder output based on trajectory fingerprint to avoid
* re-encoding similar swipe patterns.
*/
class EncoderCache(
maxEntries: Int = 10,
private val ttlMs: Long = 5 * 60 * 1000 // 5 minutes
) {
data class CacheEntry(
val memory: OnnxTensor,
val actualLength: Int,
val timestamp: Long = System.currentTimeMillis()
) {
// Nested data classes cannot see the outer class's ttlMs, so it is passed in
fun isExpired(ttlMs: Long): Boolean = System.currentTimeMillis() - timestamp > ttlMs
}
private val cache = object : LruCache<Long, CacheEntry>(maxEntries) {
override fun entryRemoved(
evicted: Boolean,
key: Long,
oldValue: CacheEntry,
newValue: CacheEntry?
) {
// Close the outgoing tensor to prevent a native memory leak.
// Fires on eviction, explicit remove(), and replacement via put().
if (oldValue !== newValue) {
oldValue.memory.close()
}
}
}
/**
* Generate trajectory fingerprint hash.
*
* Uses: start key, end key, path length, 3 sampled waypoint keys
*/
fun computeFingerprint(features: SwipeTrajectoryProcessor.TrajectoryFeatures): Long {
val keys = features.nearestKeys
if (keys.isEmpty()) return 0L
var hash = 17L
// Start key
hash = hash * 31 + keys.first()
// End key
hash = hash * 31 + keys.last()
// Path length (binned to 10-point buckets)
hash = hash * 31 + (features.actualLength / 10)
// 3 sampled waypoints (25%, 50%, 75%)
val step = keys.size / 4
if (step > 0) {
for (i in 1..3) {
hash = hash * 31 + keys.getOrElse(i * step) { 0 }
}
}
return hash
}
fun get(fingerprint: Long): CacheEntry? {
val entry = cache.get(fingerprint) ?: return null
if (entry.isExpired(ttlMs)) {
// entryRemoved() closes the expired tensor
cache.remove(fingerprint)
return null
}
return entry
}
fun put(fingerprint: Long, memory: OnnxTensor, actualLength: Int) {
cache.put(fingerprint, CacheEntry(memory, actualLength))
}
fun clear() {
// evictAll() triggers entryRemoved() for every entry, which closes each
// tensor exactly once (closing manually first would double-close)
cache.evictAll()
}
fun stats(): String {
return "EncoderCache: ${cache.size()}/${cache.maxSize()} entries"
}
}

Integration in SwipePredictorOrchestrator.kt:
class SwipePredictorOrchestrator private constructor(private val context: Context) {
// ... existing fields ...
// NEW: Encoder cache
private val encoderCache = EncoderCache(maxEntries = 10)
fun predict(input: SwipeInput): PredictionPostProcessor.Result {
// ... existing feature extraction ...
// NEW: Check encoder cache
val fingerprint = encoderCache.computeFingerprint(features)
val cachedEntry = encoderCache.get(fingerprint)
val memory: OnnxTensor
val encoderTime: Long
if (cachedEntry != null) {
// Cache hit - reuse encoder output
memory = cachedEntry.memory
encoderTime = 0L
if (debugModeActive) {
logDebug("⚡ Encoder CACHE HIT (fingerprint=$fingerprint)\n")
}
} else {
// Cache miss - run encoder
val encoderStartTime = System.currentTimeMillis()
val encoderResult = encoderWrapper!!.encode(features)
memory = encoderResult.memory
encoderTime = System.currentTimeMillis() - encoderStartTime
// Store in cache; a refinement could skip entries whose actualLength differs significantly
encoderCache.put(fingerprint, memory, features.actualLength)
if (debugModeActive) {
logDebug("⚡ Encoder: ${encoderTime}ms (cached for fingerprint=$fingerprint)\n")
}
}
// ... rest of prediction (unchanged) ...
}
fun cleanup() {
encoderCache.clear() // NEW
// ... existing cleanup ...
}
}

Estimated Impact: ~20-30% latency reduction for repeated similar swipes
Risk: Tensor lifecycle management; cached tensors must not be closed while still in use by an in-flight prediction.
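One mitigation is a reference-counted wrapper around the cached tensor, sketched below; RefCountedTensor and its call sites are hypothetical, not part of the codebase:

```kotlin
import ai.onnxruntime.OnnxTensor
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical wrapper: the cache holds one reference, each in-flight
// prediction holds another, and the native tensor is closed only when
// the last holder releases it.
class RefCountedTensor(private val tensor: OnnxTensor) {
    private val refs = AtomicInteger(1) // initial reference owned by the cache

    fun acquire(): OnnxTensor {
        check(refs.incrementAndGet() > 1) { "tensor already released" }
        return tensor
    }

    fun release() {
        if (refs.decrementAndGet() == 0) {
            tensor.close() // last holder gone; free native memory
        }
    }
}
```

Under this scheme, EncoderCache.entryRemoved() would call release() instead of close(), and predict() would acquire() before beam search and release() in a finally block.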
Goal: Allow users to configure XNNPACK thread count based on their device.
Step 1: Add to Defaults (Config.kt):
object Defaults {
// ... existing defaults ...
// ONNX Runtime settings
const val ONNX_XNNPACK_THREADS = 2 // Default to 2 (good for most ARM devices)
}

Step 2: Add to Config class (Config.kt):
class Config private constructor(prefs: SharedPreferences, res: Resources) {
// ... existing fields ...
@JvmField var onnx_xnnpack_threads = Defaults.ONNX_XNNPACK_THREADS
private fun refresh(prefs: SharedPreferences, res: Resources) {
// ... existing refresh code ...
onnx_xnnpack_threads = prefs.getInt("onnx_xnnpack_threads", Defaults.ONNX_XNNPACK_THREADS)
}
}

Step 3: Modify SessionConfigurator.kt:
object SessionConfigurator {
private const val TAG = "SessionConfigurator"
fun createOptimizedSessionOptions(
context: Context?,
sessionName: String,
xnnpackThreads: Int = 2 // NEW parameter with default
): OrtSession.SessionOptions {
// ... existing code ...
}
private fun tryEnableHardwareAcceleration(
sessionOptions: OrtSession.SessionOptions,
sessionName: String,
xnnpackThreads: Int // NEW parameter
) {
// ... NNAPI and QNN attempts ...
// Try XNNPACK
try {
val xnnOptions = HashMap<String, String>()
xnnOptions["intra_op_num_threads"] = xnnpackThreads.toString() // CHANGED
sessionOptions.addXnnpack(xnnOptions)
sessionOptions.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL)
sessionOptions.setIntraOpNumThreads(xnnpackThreads) // CHANGED
Log.i(TAG, "✅ XNNPACK enabled for $sessionName (threads=$xnnpackThreads)")
} catch (e: Exception) {
Log.w(TAG, "XNNPACK failed, using CPU", e)
}
}
}

Step 4: Add to SettingsActivity.kt (in Neural Settings section):
// In the neural prediction settings section:
SearchableSetting(
"ONNX Thread Count",
listOf("performance", "threads", "xnnpack", "cpu"),
{ settings ->
val current = _config.onnx_xnnpack_threads
val threadCounts = listOf(1, 2, 4, 6, 8)
val displayValues = threadCounts.map {
if (it == 2) "$it threads (Recommended)" else "$it threads"
}
showListDialog(
"ONNX Inference Threads",
"Number of CPU threads for neural inference. Lower values reduce battery usage, higher values may improve speed on multi-core devices.",
displayValues.toTypedArray(),
threadCounts.indexOf(current).coerceAtLeast(0)
) { index ->
_prefs.edit().putInt("onnx_xnnpack_threads", threadCounts[index]).apply()
Toast.makeText(this, "Restart app for thread change to take effect", Toast.LENGTH_LONG).show()
}
}
),

Note: Session options are set at model load time, so changes require an app restart or explicit session recreation.
Estimated Impact: Variable and device-dependent; 2 threads is often optimal on ARM.
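If an adaptive default were ever wanted instead of the fixed 2, one possible heuristic is sketched below; it is purely illustrative and not part of the current implementation:

```kotlin
// Illustrative heuristic: big.LITTLE phones often perform best on the
// "big" cores only, so cap at half the reported cores and at 4 overall.
fun defaultXnnpackThreads(): Int {
    val cores = Runtime.getRuntime().availableProcessors()
    return (cores / 2).coerceIn(1, 4)
}
```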
| Optimization | Impact | Effort | Priority | Status |
|---|---|---|---|---|
| Batched Decoding | High (5x fewer decoder calls) | Medium | 1 | ✅ Implemented - disabled by default |
| Tensor Reuse | Medium (15 allocations saved) | Low | 2 | ✅ Simplified - srcLength only |
| XNNPACK Threads Setting | Low-Medium | Low | 3 | ✅ Implemented with UI slider (1-8 threads) |
| Encoder Caching | Medium (cache hits only) | Medium | 4 | ⏸️ Deferred (complexity vs. benefit) |
Implementation Notes:
- Batched Decoding: Uses broadcast-enabled model (memory [1,...] broadcasts to [N,...])
- Tensor Reuse: IntArray/ByteBuffer pre-allocation reverted (JVM optimized for small allocations)
- XNNPACK UI: Settings slider + backup/export support + reset-to-defaults profile
- Memory Safety: try-finally ensures cleanup() called after every search()
// In SwipePredictionTest.kt or new benchmark test
@Test
fun benchmarkInferenceLatency() {
val orchestrator = SwipePredictorOrchestrator.getInstance(context)
orchestrator.initialize()
val testInputs = loadTestSwipeInputs() // Various lengths/patterns
val times = mutableListOf<Long>()
repeat(100) {
val input = testInputs.random()
val start = System.nanoTime()
orchestrator.predict(input)
times.add((System.nanoTime() - start) / 1_000_000)
}
Log.i("Benchmark", "Latency: min=${times.min()}ms, avg=${times.average()}ms, max=${times.max()}ms")
Log.i("Benchmark", "P50=${times.sorted()[50]}ms, P95=${times.sorted()[95]}ms")
}

@Test
fun verifyOptimizationDoesNotChangePredictions() {
// Compare sequential vs batched predictions on same inputs
val sequentialResults = runPredictions(useBatched = false)
val batchedResults = runPredictions(useBatched = true)
assertEquals(sequentialResults, batchedResults, "Batched mode changed predictions!")
}

- Use Android Profiler to compare GC events before/after tensor reuse
- Monitor OnnxTensor allocations via allocation tracking
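Beyond Android Studio's allocation tracker, a lightweight in-app counter can confirm the per-prediction numbers; the counter below is a hypothetical sketch that would be incremented wherever OnnxTensor.createTensor is called (e.g., inside TensorFactory):

```kotlin
import java.util.concurrent.atomic.AtomicLong

// Hypothetical debug counter for tensor allocations.
object TensorAllocationStats {
    private val created = AtomicLong()

    fun onTensorCreated() { created.incrementAndGet() }

    // Call after a prediction to read (and reset) allocations for that run
    fun snapshotAndReset(): Long = created.getAndSet(0)
}
```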
| Risk | Mitigation |
|---|---|
| Batched decoding changes predictions | Diff test (5.2) validates equivalence |
| Tensor reuse causes stale data | Clear/fill arrays at start of each use |
| Encoder cache returns stale result | Fingerprint includes trajectory shape, TTL expiry |
| XNNPACK thread change degrades performance | Default to safe value (2), user override |
| Cached tensor closed while in use | Reference counting or defensive copy |
- FP16 Quantization: Convert models to FP16 for ~2x speedup on supported hardware
- Dynamic Batching: Adjust batch size based on available beams
- Async Encoding: Start encoding while the user is still swiping (see the sketch after this list)
- NNAPI Caching: Leverage NNAPI compilation caching for faster cold starts
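For the async-encoding idea, a minimal coroutine sketch; the callback names and the suspend-function stand-in for EncoderWrapper are assumptions:

```kotlin
import ai.onnxruntime.OnnxTensor
import kotlinx.coroutines.*

// Sketch: start encoding speculatively while the swipe is in progress,
// then reuse the result if the trajectory did not change before gesture end.
class SpeculativeEncoder(
    private val scope: CoroutineScope,
    private val encode: suspend (FloatArray) -> OnnxTensor // stand-in for EncoderWrapper
) {
    private var pending: Deferred<OnnxTensor>? = null

    // Hypothetical hook: called while the user is still swiping
    fun onSwipeProgress(partialFeatures: FloatArray) {
        pending?.cancel() // trajectory changed; discard the stale attempt
        // NOTE: a production version must close tensors from cancelled jobs
        pending = scope.async(Dispatchers.Default) { encode(partialFeatures) }
    }

    // Hypothetical hook: called when the gesture completes
    suspend fun onSwipeEnd(finalFeatures: FloatArray): OnnxTensor {
        // A real implementation would compare trajectory fingerprints here
        // before trusting the speculative result.
        return pending?.await() ?: encode(finalFeatures)
    }
}
```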