Commit 5523789

core: Add topK, topP, XTC sampling parameters
1 parent bb26957 commit 5523789

11 files changed, +183 -18 lines changed

app/src/main/java/io/shubham0204/smollmandroid/data/AppDB.kt

Lines changed: 17 additions & 2 deletions

@@ -1,19 +1,33 @@
 package io.shubham0204.smollmandroid.data
 
 import android.content.Context
+import androidx.room.AutoMigration
 import androidx.room.Database
 import androidx.room.Room
 import androidx.room.RoomDatabase
 import androidx.room.TypeConverters
+import androidx.room.migration.Migration
+import androidx.sqlite.db.SupportSQLiteDatabase
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.runBlocking
 import org.koin.core.annotation.Single
 import java.util.Date
 
+val MIGRATION_1_2 =
+    object : Migration(1, 2) {
+        override fun migrate(database: SupportSQLiteDatabase) {
+            database.execSQL("ALTER TABLE Chat ADD COLUMN topK INTEGER NOT NULL DEFAULT 40")
+            database.execSQL("ALTER TABLE Chat ADD COLUMN topP REAL NOT NULL DEFAULT 0.9")
+            database.execSQL("ALTER TABLE Chat ADD COLUMN xtcP REAL NOT NULL DEFAULT 0.0")
+            database.execSQL("ALTER TABLE Chat ADD COLUMN xtcT REAL NOT NULL DEFAULT 1.0")
+        }
+    }
+
 @Database(
     entities = [Chat::class, ChatMessage::class, LLMModel::class, Task::class, Folder::class],
-    version = 1,
+    version = 2,
+    exportSchema = true,
 )
 @TypeConverters(Converters::class)
 abstract class AppRoomDatabase : RoomDatabase() {

@@ -38,7 +52,8 @@ class AppDB(
                 context,
                 AppRoomDatabase::class.java,
                 "app-database",
-            ).build()
+            ).addMigrations(MIGRATION_1_2)
+                .build()
 
    /** Get all chats from the database sorted by dateUsed in descending order. */
    fun getChats(): Flow<List<Chat>> = db.chatsDao().getChats()
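
Because the version bump is handled by a hand-written Migration rather than an AutoMigration, it can be verified with Room's MigrationTestHelper. The sketch below is not part of this commit: it assumes the room-testing artifact is on the androidTest classpath, that the exported schema JSONs for versions 1 and 2 (enabled above via exportSchema = true) are available to the instrumented test, and a hypothetical test-class name.

    import androidx.room.testing.MigrationTestHelper
    import androidx.test.platform.app.InstrumentationRegistry
    import org.junit.Rule
    import org.junit.Test

    class Migration1To2Test {
        @get:Rule
        val helper = MigrationTestHelper(
            InstrumentationRegistry.getInstrumentation(),
            AppRoomDatabase::class.java,
        )

        @Test
        fun migrate1To2_addsSamplingColumns() {
            // create a database at schema version 1, then close it
            helper.createDatabase("migration-test", 1).close()
            // run MIGRATION_1_2 and validate the result against the exported v2 schema
            val db = helper.runMigrationsAndValidate("migration-test", 2, true, MIGRATION_1_2)
            // the new columns should now exist on the Chat table
            db.query("SELECT topK, topP, xtcP, xtcT FROM Chat").close()
            db.close()
        }
    }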

app/src/main/java/io/shubham0204/smollmandroid/data/ChatsDB.kt

Lines changed: 6 additions & 2 deletions

@@ -59,11 +59,15 @@ data class Chat(
     /**
      * LLM inference parameters that are used for this chat.
      */
-    var minP: Float = 0.1f,
-    var temperature: Float = 0.8f,
+    var minP: Float = 0.05f,
+    var temperature: Float = 1.0f,
     var nThreads: Int = 4,
     var useMmap: Boolean = true,
     var useMlock: Boolean = false,
+    var topK: Int = 50,
+    var topP: Float = 1.0f,
+    var xtcP: Float = 0.0f,
+    var xtcT: Float = 1.0f,
     /**
      * The maximum number of tokens that can be used as context to the model
     * This is editable by users in the EditChatSettingsScreen.kt.
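
The new fields are plain columns on the Chat entity, so nothing in the schema constrains them to the ranges the settings sliders expose. A small sketch of clamping them before persisting; sanitizeSamplingParams is a hypothetical helper, not part of this commit, and relies only on Chat being a data class:

    // Hypothetical helper: clamp sampling parameters to the ranges exposed by
    // the sliders in EditChatSettingsScreen.kt before persisting the Chat.
    fun sanitizeSamplingParams(chat: Chat): Chat =
        chat.copy(
            topK = chat.topK.coerceIn(0, 128),      // slider range 0..128
            topP = chat.topP.coerceIn(0.0f, 1.0f),  // 1.0 leaves top-p disabled
            xtcP = chat.xtcP.coerceIn(0.0f, 1.0f),  // 0.0 leaves XTC disabled
            xtcT = chat.xtcT.coerceIn(0.0f, 1.0f),
        )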

app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatScreenViewModel.kt

Lines changed: 4 additions & 0 deletions

@@ -329,6 +329,10 @@ class ChatScreenViewModel(
                 chat.nThreads,
                 chat.useMmap,
                 chat.useMlock,
+                chat.topP,
+                chat.topK,
+                chat.xtcP,
+                chat.xtcT,
             ),
         onError = { e ->
             _modelLoadState.value = ModelLoadingState.FAILURE

app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/EditChatSettingsScreen.kt

Lines changed: 90 additions & 0 deletions

@@ -78,6 +78,10 @@ fun EditChatSettingsScreen(
     var chatTemplate by remember { mutableStateOf(chat.chatTemplate) }
     var useMmap by remember { mutableStateOf(chat.useMmap) }
     var useMlock by remember { mutableStateOf(chat.useMlock) }
+    var topP by remember { mutableStateOf(chat.topP) }
+    var topK by remember { mutableStateOf(chat.topK) }
+    var xtcP by remember { mutableStateOf(chat.xtcP) }
+    var xtcT by remember { mutableStateOf(chat.xtcT) }
     val context = LocalContext.current
     val llmModel = viewModel.modelsRepository.getModelFromId(chat.llmModelId)

@@ -110,6 +114,8 @@ fun EditChatSettingsScreen(
                     nThreads = nThreads,
                     useMmap = useMmap,
                     useMlock = useMlock,
+                    topP = topP,
+                    topK = topK,
                 )
             if (chat != updatedChat) {
                 viewModel.updateChat(updatedChat)

@@ -234,6 +240,90 @@ fun EditChatSettingsScreen(
 
             Spacer(modifier = Modifier.height(24.dp))
 
+            Text(
+                stringResource(R.string.chat_settings_label_topP),
+                style = MaterialTheme.typography.titleMedium,
+            )
+            Text(
+                stringResource(R.string.chat_settings_desc_topP),
+                style = MaterialTheme.typography.labelSmall,
+            )
+            Slider(
+                value = topP,
+                onValueChange = { topP = it },
+                valueRange = 0.0f..1.0f,
+                steps = 100,
+            )
+            Text(
+                text = "%.1f".format(topP),
+                style = MaterialTheme.typography.labelSmall,
+            )
+
+            Spacer(modifier = Modifier.height(24.dp))
+
+            Text(
+                stringResource(R.string.chat_settings_label_topK),
+                style = MaterialTheme.typography.titleMedium,
+            )
+            Text(
+                stringResource(R.string.chat_settings_desc_topK),
+                style = MaterialTheme.typography.labelSmall,
+            )
+            Slider(
+                value = topK.toFloat(),
+                onValueChange = { topK = it.toInt() },
+                valueRange = 0.0f..128.0f,
+                steps = 128,
+            )
+            Text(
+                text = topK.toString(),
+                style = MaterialTheme.typography.labelSmall,
+            )
+
+            Spacer(modifier = Modifier.height(24.dp))
+
+            Text(
+                stringResource(R.string.chat_settings_label_xtcT),
+                style = MaterialTheme.typography.titleMedium,
+            )
+            Text(
+                stringResource(R.string.chat_settings_desc_xtcT),
+                style = MaterialTheme.typography.labelSmall,
+            )
+            Slider(
+                value = xtcT,
+                onValueChange = { xtcT = it },
+                valueRange = 0.0f..1.0f,
+                steps = 100,
+            )
+            Text(
+                text = "%.1f".format(xtcT),
+                style = MaterialTheme.typography.labelSmall,
+            )
+
+            Spacer(modifier = Modifier.height(24.dp))
+
+            Text(
+                stringResource(R.string.chat_settings_label_xtcP),
+                style = MaterialTheme.typography.titleMedium,
+            )
+            Text(
+                stringResource(R.string.chat_settings_desc_xtcP),
+                style = MaterialTheme.typography.labelSmall,
+            )
+            Slider(
+                value = xtcP,
+                onValueChange = { xtcP = it },
+                valueRange = 0.0f..1.0f,
+                steps = 100,
+            )
+            Text(
+                text = "%.1f".format(xtcP),
+                style = MaterialTheme.typography.labelSmall,
+            )
+
+            Spacer(modifier = Modifier.height(24.dp))
+
             Text(
                 stringResource(R.string.chat_settings_label_ctx_size),
                style = MaterialTheme.typography.titleMedium,
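
One detail of the top-k slider worth noting: Compose's Slider works in Float, and its steps parameter counts only the stops between the endpoints (the slider snaps to steps + 2 values), so exact integer snapping over 0f..128f needs steps = 127. A minimal sketch of an integer-snapping variant, assuming the same Material3 Slider used above; IntSlider is a hypothetical helper, not part of this commit:

    import androidx.compose.material3.Slider
    import androidx.compose.runtime.Composable
    import kotlin.math.roundToInt

    // Hypothetical helper: a Slider that snaps to whole numbers. Over 0..max
    // there are (max - 1) interior stops, giving increments of exactly 1.
    @Composable
    fun IntSlider(value: Int, max: Int, onChange: (Int) -> Unit) {
        Slider(
            value = value.toFloat(),
            onValueChange = { onChange(it.roundToInt()) },
            valueRange = 0f..max.toFloat(),
            steps = max - 1,
        )
    }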

app/src/main/res/values-zh-rCN/strings.xml

Lines changed: 8 additions & 0 deletions

@@ -53,6 +53,14 @@
     <string name="chat_settings_label_sys_prompt">系统提示</string>
     <string name="chat_settings_label_chat_name">聊天名称</string>
     <string name="chat_settings_take_from_gguf">从 GGUF 模型获取</string>
+    <string name="chat_settings_desc_topP">Top-p(核)采样选择累积概率超过某个阈值的最小最可能词元集合,从而输出更具动态性和多样性的内容。</string>
+    <string name="chat_settings_desc_topK">Top-k 采样将模型的选择范围限制在 $k$ 个最可能的下一个词元内,从而确保输出内容更集中且更具可预测性。</string>
+    <string name="chat_settings_label_topP">Top P</string>
+    <string name="chat_settings_label_topK">Top K</string>
+    <string name="chat_settings_desc_xtcP">从采样中移除所有令牌,但留下概率最低的一个,移除概率为 xtcP</string>
+    <string name="chat_settings_desc_xtcT">如果多个令牌的预测概率都达到或超过阈值 xtcT…</string>
+    <string name="chat_settings_label_xtcP">XTC 概率</string>
+    <string name="chat_settings_label_xtcT">XTC 阈值</string>
     <string name="context_size_taken_from_model">上下文大小取自模型</string>
     <string name="chat_settings_title_num_tokens">令牌数量</string>
    <string name="chat_settings_err_min_ctx_size">上下文大小应至少为 200 个令牌</string>

app/src/main/res/values/strings.xml

Lines changed: 8 additions & 0 deletions

@@ -32,11 +32,19 @@
     <string name="chat_settings_desc_temp">Temperature is a parameter that controls the randomness and creativity of LLM outputs, with lower temperatures producing more deterministic and focused responses, and higher temperatures leading to more diverse and creative outputs.</string>
     <string name="chat_settings_desc_ctx_length">The context length of a large language model (LLM) refers to the maximum number of tokens (words or subwords) it can process in a single input or output sequence. Larger context sizes need more memory.</string>
     <string name="chat_settings_desc_n_threads">The number of CPU threads to use for inference.</string>
+    <string name="chat_settings_desc_topP">Top-p sampling selects the smallest set of most probable tokens whose cumulative probability exceeds a threshold, allowing for a more dynamic and diverse output.</string>
+    <string name="chat_settings_desc_topK">Top-k sampling limits the model\'s choices to the k most likely next tokens, ensuring a more focused and predictable output.</string>
+    <string name="chat_settings_desc_xtcP">...remove all except the least probable one from sampling, with probability xtcP</string>
+    <string name="chat_settings_desc_xtcT">If there are multiple tokens with predicted probability at least the threshold xtcT...</string>
     <string name="chat_settings_label_ctx_size">Context Size</string>
     <string name="chat_settings_label_temp">Temperature</string>
     <string name="chat_settings_label_minp">min-p</string>
     <string name="chat_settings_label_sys_prompt">System Prompt</string>
     <string name="chat_settings_label_chat_name">Chat Name</string>
+    <string name="chat_settings_label_topP">top P</string>
+    <string name="chat_settings_label_topK">top K</string>
+    <string name="chat_settings_label_xtcP">XTC Probability</string>
+    <string name="chat_settings_label_xtcT">XTC Threshold</string>
     <string name="chat_settings_take_from_gguf">Take from GGUF Model</string>
     <string name="context_size_taken_from_model">Context size taken from model</string>
    <string name="chat_settings_title_num_tokens">No. of tokens</string>

smollm/build.gradle.kts

Lines changed: 2 additions & 2 deletions

@@ -34,11 +34,11 @@ android {
                 // allow compiling 16 KB page-aligned shared libraries
                 // https://developer.android.com/guide/practices/page-sizes#compile-r27
                 arguments += listOf("-DANDROID_SUPPORT_FLEXIBLE_PAGE_SIZES=ON")
-                arguments += "-DCMAKE_BUILD_TYPE=Release"
+                // arguments += "-DCMAKE_BUILD_TYPE=Release"
 
                 // (debugging) uncomment the following line to enable debug builds
                 // and attach hardware-assisted address sanitizer
-                // arguments += "-DCMAKE_BUILD_TYPE=Debug"
+                arguments += "-DCMAKE_BUILD_TYPE=Debug"
                 // arguments += listOf("-DANDROID_SANITIZE=hwaddress")
            }
        }
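
CMAKE_BUILD_TYPE is set globally in defaultConfig here, so switching between Release and Debug means editing the file by hand. A hedged alternative sketch, using the standard per-buildType externalNativeBuild options of the Android Gradle Kotlin DSL; this is not part of the commit:

    // Sketch: pass a CMake build type per Gradle buildType instead of hardcoding
    // one in defaultConfig, so debug native builds cannot leak into release APKs.
    android {
        buildTypes {
            release {
                externalNativeBuild { cmake { arguments += "-DCMAKE_BUILD_TYPE=Release" } }
            }
            debug {
                externalNativeBuild { cmake { arguments += "-DCMAKE_BUILD_TYPE=Debug" } }
            }
        }
    }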

smollm/src/main/cpp/LLMInference.cpp

Lines changed: 27 additions & 5 deletions

@@ -8,19 +8,24 @@
 #define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
 
 void
-LLMInference::loadModel(const char *model_path, float minP, float temperature, bool storeChats, long contextSize,
-                        const char *chatTemplate, int nThreads, bool useMmap, bool useMlock) {
+LLMInference::loadModel(const char* model_path, float minP, float temperature, bool storeChats, long contextSize,
+                        const char* chatTemplate, int nThreads, bool useMmap, bool useMlock, float topP, int topK,
+                        float xtcP, float xtcT) {
     LOGi("loading model with"
          "\n\tmodel_path = %s"
          "\n\tminP = %f"
          "\n\ttemperature = %f"
          "\n\tstoreChats = %d"
-         "\n\tcontextSize = %li"
+         "\n\tcontextSize = %d"
          "\n\tchatTemplate = %s"
          "\n\tnThreads = %d"
          "\n\tuseMmap = %d"
-         "\n\tuseMlock = %d",
-         model_path, minP, temperature, storeChats, contextSize, chatTemplate, nThreads, useMmap, useMlock);
+         "\n\tuseMlock = %d"
+         "\n\ttopP = %f"
+         "\n\ttopK = %i"
+         "\n\txtcP = %f"
+         "\n\txtcT = %f",
+         model_path, minP, temperature, (int)storeChats, (int)contextSize, chatTemplate, nThreads, useMmap, useMlock, topP, topK, xtcP, xtcT);
 
     // load dynamic backends
     ggml_backend_load_all();

@@ -53,6 +58,23 @@ LLMInference::loadModel(const char *model_path, float minP, float temperature, b
     _sampler = llama_sampler_chain_init(sampler_params);
     llama_sampler_chain_add(_sampler, llama_sampler_init_temp(temperature));
     llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+    if (minP >= 0.01f) {
+        // minP = 0.0 (disabled)
+        // minP can be adjusted across 100 steps between [0.0,1.0], the smallest step being 0.01
+        llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(minP, 1));
+    }
+    if (topK > 0) {
+        LOGi("Enabled top-k sampling with k=%d", topK);
+        llama_sampler_chain_add(_sampler, llama_sampler_init_top_k(topK));
+    }
+    if (topP <= 0.99) {
+        LOGi("Enabled top-p sampling with p=%f", topP);
+        llama_sampler_chain_add(_sampler, llama_sampler_init_top_p(topP, 1));
+    }
+    if (xtcT <= 0.99 || xtcP >= 0.01) {
+        LOGi("Enabled XTC sampling with p=%f, t=%f", xtcP, xtcT);
+        llama_sampler_chain_add(_sampler, llama_sampler_init_xtc(xtcP, xtcT, 1, LLAMA_DEFAULT_SEED));
+    }
 
     _formattedMessages = std::vector<char>(llama_n_ctx(_ctx));
    _messages.clear();
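
The guards above define the effective on/off contract for each sampler: min-p below 0.01 is off, a topK of 0 is off, a topP at or near 1.0 is off, and XTC stays off until the threshold drops to 0.99 or below or the probability reaches at least 0.01. Against the Chat entity defaults this leaves top-p and XTC disabled while top-k starts enabled at 50. A Kotlin restatement of the same conditions; enabledSamplers is a hypothetical helper, useful e.g. for surfacing the active chain in the UI:

    // Illustrative restatement of the native guards in LLMInference.cpp;
    // not part of this commit.
    fun enabledSamplers(minP: Float, topK: Int, topP: Float, xtcP: Float, xtcT: Float): List<String> =
        buildList {
            if (minP >= 0.01f) add("min-p($minP)")
            if (topK > 0) add("top-k($topK)")
            if (topP <= 0.99f) add("top-p($topP)")
            if (xtcT <= 0.99f || xtcP >= 0.01f) add("xtc(p=$xtcP, t=$xtcT)")
        }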

smollm/src/main/cpp/LLMInference.h

Lines changed: 2 additions & 1 deletion

@@ -40,7 +40,8 @@ class LLMInference {
 
   public:
     void loadModel(const char* modelPath, float minP, float temperature, bool storeChats, long contextSize,
-                   const char* chatTemplate, int nThreads, bool useMmap, bool useMlock);
+                   const char* chatTemplate, int nThreads, bool useMmap, bool useMlock, float topP, int topK,
+                   float xtcP, float xtcT);
 
     void addChatMessage(const char* message, const char* role);
 
smollm/src/main/cpp/smollm.cpp

Lines changed: 3 additions & 2 deletions

@@ -4,15 +4,16 @@
 extern "C" JNIEXPORT jlong JNICALL
 Java_io_shubham0204_smollm_SmolLM_loadModel(JNIEnv* env, jobject thiz, jstring modelPath, jfloat minP,
                                             jfloat temperature, jboolean storeChats, jlong contextSize,
-                                            jstring chatTemplate, jint nThreads, jboolean useMmap, jboolean useMlock) {
+                                            jstring chatTemplate, jint nThreads, jboolean useMmap, jboolean useMlock,
+                                            jfloat topP, jint topK, jfloat xtcP, jfloat xtcT) {
     jboolean isCopy = true;
     const char* modelPathCstr = env->GetStringUTFChars(modelPath, &isCopy);
     auto* llmInference = new LLMInference();
     const char* chatTemplateCstr = env->GetStringUTFChars(chatTemplate, &isCopy);
 
     try {
         llmInference->loadModel(modelPathCstr, minP, temperature, storeChats, contextSize, chatTemplateCstr, nThreads,
-                                useMmap, useMlock);
+                                useMmap, useMlock, topP, topK, xtcP, xtcT);
     } catch (std::runtime_error& error) {
         env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), error.what());
    }
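
The JNI symbol Java_io_shubham0204_smollm_SmolLM_loadModel pins down the Kotlin-side declaration it must match: class io.shubham0204.smollm.SmolLM, method loadModel, returning a Long handle to the native LLMInference object. A sketch of that declaration follows; the actual SmolLM class is not shown in this diff, and the library name passed to loadLibrary is an assumption:

    package io.shubham0204.smollm

    class SmolLM {
        companion object {
            init {
                System.loadLibrary("smollm") // assumed native library name
            }
        }

        // Must match Java_io_shubham0204_smollm_SmolLM_loadModel in smollm.cpp;
        // the returned Long is the pointer to the native LLMInference instance.
        private external fun loadModel(
            modelPath: String,
            minP: Float,
            temperature: Float,
            storeChats: Boolean,
            contextSize: Long,
            chatTemplate: String,
            nThreads: Int,
            useMmap: Boolean,
            useMlock: Boolean,
            topP: Float,
            topK: Int,
            xtcP: Float,
            xtcT: Float,
        ): Long
    }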
