@@ -92,7 +92,14 @@ struct common_log_entry {
9292 // signals the worker thread to stop
9393 bool is_end;
9494
95- void print (FILE * file = nullptr ) const {
95+ void print (FILE * file = nullptr , ggml_log_callback callback = nullptr , void * callback_user_data = nullptr ) const {
96+ // if callback is provided, use it instead of printing
97+ if (callback != nullptr ) {
98+ callback (level, msg.data (), callback_user_data);
99+ return ;
100+ }
101+
102+
96103 FILE * fcur = file;
97104 if (!fcur) {
98105 // stderr displays DBG messages only when their verbosity level is not higher than the threshold
@@ -150,6 +157,8 @@ struct common_log {
150157 timestamps = false ;
151158 running = false ;
152159 t_start = t_us ();
160+ callback = nullptr ;
161+ callback_user_data = nullptr ;
153162
154163 // initial message size - will be expanded if longer messages arrive
155164 entries.resize (capacity);
@@ -191,6 +200,10 @@ struct common_log {
191200 // worker thread copies into this
192201 common_log_entry cur;
193202
203+ // custom callback for log messages
204+ ggml_log_callback callback;
205+ void * callback_user_data;
206+
194207public:
195208 void add (enum ggml_log_level level, const char * fmt, va_list args) {
196209 std::lock_guard<std::mutex> lock (mtx);
@@ -281,11 +294,15 @@ struct common_log {
281294
282295 thrd = std::thread ([this ]() {
283296 while (true ) {
297+ ggml_log_callback cb = nullptr ;
298+ void * cb_user_data = nullptr ;
284299 {
285300 std::unique_lock<std::mutex> lock (mtx);
286301 cv.wait (lock, [this ]() { return head != tail; });
287302
288303 cur = entries[head];
304+ cb = callback;
305+ cb_user_data = callback_user_data;
289306
290307 head = (head + 1 ) % entries.size ();
291308 }
@@ -294,7 +311,7 @@ struct common_log {
294311 break ;
295312 }
296313
297- cur.print (); // stdout and stderr
314+ cur.print (nullptr , cb, cb_user_data ); // stdout and stderr or callback
298315
299316 if (file) {
300317 cur.print (file);
@@ -376,6 +393,15 @@ struct common_log {
376393
377394 this ->timestamps = timestamps;
378395 }
396+
397+ void set_callback (ggml_log_callback cb, void * user_data) {
398+ pause ();
399+
400+ this ->callback = cb;
401+ this ->callback_user_data = user_data;
402+
403+ resume ();
404+ }
379405};
380406
381407//
@@ -442,3 +468,7 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
442468void common_log_set_timestamps (struct common_log * log, bool timestamps) {
443469 log->set_timestamps (timestamps);
444470}
471+
472+ void common_log_set_callback (struct common_log * log, ggml_log_callback callback, void * user_data) {
473+ log->set_callback (callback, user_data);
474+ }
0 commit comments