 #include <signal.h>
 #endif
 
-static bool g_is_generating = false;
+// volatile, because of signal being an interrupt
+static volatile bool g_is_generating = false;
+static volatile bool g_is_interrupted = false;
 
 /**
  * Please note that this is NOT a production-ready stuff.
@@ -50,8 +52,10 @@ static void sigint_handler(int signo) {
             g_is_generating = false;
         } else {
             console::cleanup();
-            LOG("\nInterrupted by user\n");
-            _exit(130);
+            if (g_is_interrupted) {
+                _exit(1);
+            }
+            g_is_interrupted = true;
         }
     }
 }
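The handler above implements a staged Ctrl+C: while a response is streaming, SIGINT only stops generation; when idle, it now requests an orderly shutdown instead of calling _exit(130) directly, and a repeat SIGINT during that shutdown force-quits. A minimal standalone sketch of the same pattern (not the llama.cpp code: console cleanup is omitted, names are hypothetical, and it uses volatile std::sig_atomic_t, the type the standard guarantees writable from a signal handler; the patch's volatile bool works on common platforms):

    #include <csignal>
    #include <cstdio>
    #include <cstdlib>

    static volatile std::sig_atomic_t g_generating  = 0;
    static volatile std::sig_atomic_t g_interrupted = 0;

    static void on_sigint(int) {
        if (g_generating) {
            g_generating = 0;   // 1st Ctrl+C while streaming: just stop generating
        } else if (g_interrupted) {
            std::_Exit(1);      // repeated Ctrl+C while shutting down: force quit
        } else {
            g_interrupted = 1;  // Ctrl+C while idle: request an orderly shutdown
        }
    }

    int main() {
        std::signal(SIGINT, on_sigint);
        while (!g_interrupted) {
            // read input, generate tokens, ...
        }
        std::puts("Interrupted by user");
        return 130;             // 128 + SIGINT (2): conventional exit status
    }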
@@ -167,7 +171,7 @@ struct decode_embd_batch {
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
-        if (i > n_predict || !g_is_generating) {
+        if (i > n_predict || !g_is_generating || g_is_interrupted) {
             printf("\n");
             break;
         }
@@ -184,6 +188,11 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
         printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
         fflush(stdout);
 
+        if (g_is_interrupted) {
+            printf("\n");
+            break;
+        }
+
         // eval the token
         common_batch_clear(ctx.batch);
         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
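Note where this new check sits: after the sampled token has been printed and flushed, but before that token is decoded, so an interrupt takes effect without waiting out one more forward pass. A runnable toy version of that loop shape (sample_next and decode are hypothetical stand-ins, not llama.cpp API):

    #include <csignal>
    #include <cstdio>

    static volatile std::sig_atomic_t g_interrupted = 0;
    static void on_sigint(int) { g_interrupted = 1; }

    static int  sample_next()     { return 42; }    // stand-in for the sampler
    static void decode(int token) { (void)token; }  // stand-in for the expensive eval

    int main() {
        std::signal(SIGINT, on_sigint);
        for (int i = 0; i < 1000000; i++) {
            if (g_interrupted) {      // pre-check, mirrors the loop's entry guard
                std::printf("\n");
                break;
            }
            int tok = sample_next();
            std::printf("%d ", tok);
            std::fflush(stdout);      // partial output is already visible
            if (g_interrupted) {      // post-print check: skip the decode entirely
                std::printf("\n");
                break;
            }
            decode(tok);
        }
        return g_interrupted ? 130 : 0;
    }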
@@ -219,6 +228,9 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
     text.add_special = add_bos;
     text.parse_special = true;
     mtmd_input_chunks chunks;
+
+    if (g_is_interrupted) return 0;
+
     int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
@@ -276,6 +288,8 @@ int main(int argc, char ** argv) {
 #endif
     }
 
+    if (g_is_interrupted) return 130;
+
     if (is_single_turn) {
         g_is_generating = true;
         if (params.prompt.find("<__image__>") == std::string::npos) {
@@ -287,7 +301,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
-        if (generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
             return 1;
         }
 
@@ -302,12 +316,13 @@ int main(int argc, char ** argv) {
         std::vector<std::string> images_fname;
         std::string content;
 
-        while (true) {
+        while (!g_is_interrupted) {
             g_is_generating = false;
             LOG("\n> ");
             console::set_display(console::user_input);
             std::string line;
             console::readline(line, false);
+            if (g_is_interrupted) break;
             console::set_display(console::reset);
             line = string_strip(line);
             if (line.empty()) {
@@ -335,6 +350,7 @@ int main(int argc, char ** argv) {
             msg.role = "user";
             msg.content = content;
             int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (g_is_interrupted) break;
             if (ret == 2) {
                 // non-fatal error
                 images_fname.clear();
@@ -352,6 +368,7 @@ int main(int argc, char ** argv) {
             is_first_msg = false;
         }
     }
+    if (g_is_interrupted) LOG("\nInterrupted by user\n");
     llama_perf_context_print(ctx.lctx);
-    return 0;
+    return g_is_interrupted ? 130 : 0;
 }
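Returning 130 here follows the shell convention of 128 + signal number (SIGINT is 2), so scripts wrapping the CLI can tell a user abort apart from a clean run (exit 0) or an error (exit 1); the _exit(1) path in the handler is reserved for a forced quit while a shutdown is already in progress.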