Skip to content

Commit bf4076b

Browse files
committed
chat: check if inference is stopped before unloading model in onStop() #52
1 parent 9b88303 commit bf4076b

File tree

3 files changed

+93
-75
lines changed

3 files changed

+93
-75
lines changed

app/src/main/java/io/shubham0204/smollmandroid/llm/SmolLMManager.kt

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ class SmolLMManager(
3838
) {
3939
private val instance = SmolLM()
4040
private var responseGenerationJob: Job? = null
41+
private var modelInitJob: Job? = null
4142
private var chat: Chat? = null
4243
var isInstanceLoaded = false
44+
var isInferenceOn = false
4345

4446
data class SmolLMInitParams(
4547
val chat: Chat,
@@ -63,39 +65,40 @@ class SmolLMManager(
6365
onSuccess: () -> Unit,
6466
) {
6567
try {
66-
CoroutineScope(Dispatchers.Default).launch {
67-
chat = initParams.chat
68-
if (isInstanceLoaded) {
69-
close()
70-
}
71-
instance.create(
72-
initParams.modelPath,
73-
initParams.minP,
74-
initParams.temperature,
75-
initParams.storeChats,
76-
initParams.contextSize,
77-
)
78-
LOGD("Model loaded")
79-
if (initParams.chat.systemPrompt.isNotEmpty()) {
80-
instance.addSystemPrompt(initParams.chat.systemPrompt)
81-
LOGD("System prompt added")
82-
}
83-
if (!initParams.chat.isTask) {
84-
messagesDB.getMessagesForModel(initParams.chat.id).forEach { message ->
85-
if (message.isUserMessage) {
86-
instance.addUserMessage(message.message)
87-
LOGD("User message added: ${message.message}")
88-
} else {
89-
instance.addAssistantMessage(message.message)
90-
LOGD("Assistant message added: ${message.message}")
68+
modelInitJob =
69+
CoroutineScope(Dispatchers.Default).launch {
70+
chat = initParams.chat
71+
if (isInstanceLoaded) {
72+
close()
73+
}
74+
instance.create(
75+
initParams.modelPath,
76+
initParams.minP,
77+
initParams.temperature,
78+
initParams.storeChats,
79+
initParams.contextSize,
80+
)
81+
LOGD("Model loaded")
82+
if (initParams.chat.systemPrompt.isNotEmpty()) {
83+
instance.addSystemPrompt(initParams.chat.systemPrompt)
84+
LOGD("System prompt added")
85+
}
86+
if (!initParams.chat.isTask) {
87+
messagesDB.getMessagesForModel(initParams.chat.id).forEach { message ->
88+
if (message.isUserMessage) {
89+
instance.addUserMessage(message.message)
90+
LOGD("User message added: ${message.message}")
91+
} else {
92+
instance.addAssistantMessage(message.message)
93+
LOGD("Assistant message added: ${message.message}")
94+
}
9195
}
9296
}
97+
withContext(Dispatchers.Main) {
98+
isInstanceLoaded = true
99+
onSuccess()
100+
}
93101
}
94-
withContext(Dispatchers.Main) {
95-
isInstanceLoaded = true
96-
onSuccess()
97-
}
98-
}
99102
} catch (e: Exception) {
100103
onError(e)
101104
}
@@ -113,6 +116,7 @@ class SmolLMManager(
113116
assert(chat != null) { "Please call SmolLMManager.create() first." }
114117
responseGenerationJob =
115118
CoroutineScope(Dispatchers.Default).launch {
119+
isInferenceOn = true
116120
var response = ""
117121
val duration =
118122
measureTime {
@@ -127,6 +131,7 @@ class SmolLMManager(
127131
// add it to the messages database
128132
messagesDB.addAssistantMessage(chat!!.id, response)
129133
withContext(Dispatchers.Main) {
134+
isInferenceOn = false
130135
onSuccess(
131136
SmolLMResponse(
132137
response = response,
@@ -138,22 +143,28 @@ class SmolLMManager(
138143
}
139144
}
140145
} catch (e: CancellationException) {
146+
isInferenceOn = false
141147
onCancelled()
142148
} catch (e: Exception) {
149+
isInferenceOn = false
143150
onError(e)
144151
}
145152
}
146153

147154
fun stopResponseGeneration() {
148-
responseGenerationJob?.let {
149-
if (it.isActive) {
150-
it.cancel()
151-
}
152-
}
155+
responseGenerationJob?.let { cancelJobIfActive(it) }
153156
}
154157

155158
fun close() {
159+
stopResponseGeneration()
160+
modelInitJob?.let { cancelJobIfActive(it) }
156161
instance.close()
157162
isInstanceLoaded = false
158163
}
164+
165+
private fun cancelJobIfActive(job: Job) {
166+
if (job.isActive) {
167+
job.cancel()
168+
}
169+
}
159170
}

app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatActivity.kt

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ private val LOGD: (String) -> Unit = { Log.d(LOGTAG, it) }
111111

112112
class ChatActivity : ComponentActivity() {
113113
private val viewModel: ChatScreenViewModel by inject()
114+
private var modelUnloaded = false
114115

115116
override fun onCreate(savedInstanceState: Bundle?) {
116117
super.onCreate(savedInstanceState)
@@ -145,14 +146,19 @@ class ChatActivity : ComponentActivity() {
145146
*/
146147
override fun onStart() {
147148
super.onStart()
148-
viewModel.loadModel()
149-
LOGD("onStart() called - model loaded")
149+
if (modelUnloaded) {
150+
viewModel.loadModel()
151+
LOGD("onStart() called - model loaded")
152+
}
150153
}
151154

152155
override fun onStop() {
153156
super.onStop()
154-
viewModel.unloadModel()
155-
LOGD("onStop() called - model unloaded")
157+
modelUnloaded = viewModel.unloadModel()
158+
if (modelUnloaded) {
159+
Toast.makeText(this, "Model unloaded from memory", Toast.LENGTH_SHORT).show()
160+
}
161+
LOGD("onStop() called - model unloaded result: $modelUnloaded")
156162
}
157163
}
158164

app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatScreenViewModel.kt

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -239,50 +239,51 @@ class ChatScreenViewModel(
239239
// clear resources occupied by the previous model
240240
smolLMManager.close()
241241
_currChatState.value?.let { chat ->
242-
if (chat.llmModelId == -1L) {
242+
val model = modelsRepository.getModelFromId(chat.llmModelId)
243+
if (chat.llmModelId == -1L || model == null) {
243244
_showSelectModelListDialogState.value = true
244245
} else {
245-
val model = modelsRepository.getModelFromId(chat.llmModelId)
246-
if (model != null) {
247-
_modelLoadState.value = ModelLoadingState.IN_PROGRESS
248-
smolLMManager.create(
249-
SmolLMManager.SmolLMInitParams(
250-
chat,
251-
model.path,
252-
chat.minP,
253-
chat.temperature,
254-
!chat.isTask,
255-
chat.contextSize.toLong(),
256-
),
257-
onError = { e ->
258-
_modelLoadState.value = ModelLoadingState.FAILURE
259-
createAlertDialog(
260-
dialogTitle = context.getString(R.string.dialog_err_title),
261-
dialogText = context.getString(R.string.dialog_err_text, e.message),
262-
dialogPositiveButtonText = context.getString(R.string.dialog_err_change_model),
263-
onPositiveButtonClick = { showSelectModelListDialog() },
264-
dialogNegativeButtonText = context.getString(R.string.dialog_err_close),
265-
onNegativeButtonClick = {},
266-
)
267-
},
268-
onSuccess = {
269-
_modelLoadState.value = ModelLoadingState.SUCCESS
270-
},
271-
)
272-
} else {
273-
_showSelectModelListDialogState.value = true
274-
}
246+
_modelLoadState.value = ModelLoadingState.IN_PROGRESS
247+
smolLMManager.create(
248+
SmolLMManager.SmolLMInitParams(
249+
chat,
250+
model.path,
251+
chat.minP,
252+
chat.temperature,
253+
!chat.isTask,
254+
chat.contextSize.toLong(),
255+
),
256+
onError = { e ->
257+
_modelLoadState.value = ModelLoadingState.FAILURE
258+
createAlertDialog(
259+
dialogTitle = context.getString(R.string.dialog_err_title),
260+
dialogText = context.getString(R.string.dialog_err_text, e.message),
261+
dialogPositiveButtonText = context.getString(R.string.dialog_err_change_model),
262+
onPositiveButtonClick = { showSelectModelListDialog() },
263+
dialogNegativeButtonText = context.getString(R.string.dialog_err_close),
264+
onNegativeButtonClick = {},
265+
)
266+
},
267+
onSuccess = {
268+
_modelLoadState.value = ModelLoadingState.SUCCESS
269+
},
270+
)
275271
}
276272
}
277273
}
278274

279275
/**
280-
* Clears the resources occupied by the model.
276+
* Clears the resources occupied by the model only
277+
* if the inference is not in progress.
281278
*/
282-
fun unloadModel() {
283-
smolLMManager.close()
284-
_modelLoadState.value = ModelLoadingState.NOT_LOADED
285-
}
279+
fun unloadModel(): Boolean =
280+
if (!smolLMManager.isInferenceOn) {
281+
smolLMManager.close()
282+
_modelLoadState.value = ModelLoadingState.NOT_LOADED
283+
true
284+
} else {
285+
false
286+
}
286287

287288
fun showContextLengthUsageDialog() {
288289
_currChatState.value?.let { chat ->

0 commit comments

Comments (0)