diff --git a/common/log.cpp b/common/log.cpp index 4ccdbd17cd7..6a831b94aa2 100644 --- a/common/log.cpp +++ b/common/log.cpp @@ -92,7 +92,14 @@ struct common_log_entry { // signals the worker thread to stop bool is_end; - void print(FILE * file = nullptr) const { + void print(FILE * file = nullptr, ggml_log_callback callback = nullptr, void * callback_user_data = nullptr) const { + // if callback is provided, use it instead of printing + if (callback != nullptr) { + callback(level, msg.data(), callback_user_data); + return; + } + + FILE * fcur = file; if (!fcur) { // stderr displays DBG messages only when their verbosity level is not higher than the threshold @@ -150,6 +157,8 @@ struct common_log { timestamps = false; running = false; t_start = t_us(); + callback = nullptr; + callback_user_data = nullptr; // initial message size - will be expanded if longer messages arrive entries.resize(capacity); @@ -191,6 +200,10 @@ struct common_log { // worker thread copies into this common_log_entry cur; + // custom callback for log messages + ggml_log_callback callback; + void * callback_user_data; + public: void add(enum ggml_log_level level, const char * fmt, va_list args) { std::lock_guard lock(mtx); @@ -281,11 +294,15 @@ struct common_log { thrd = std::thread([this]() { while (true) { + ggml_log_callback cb = nullptr; + void * cb_user_data = nullptr; { std::unique_lock lock(mtx); cv.wait(lock, [this]() { return head != tail; }); cur = entries[head]; + cb = callback; + cb_user_data = callback_user_data; head = (head + 1) % entries.size(); } @@ -294,7 +311,7 @@ struct common_log { break; } - cur.print(); // stdout and stderr + cur.print(nullptr, cb, cb_user_data); // stdout and stderr or callback if (file) { cur.print(file); @@ -376,6 +393,15 @@ struct common_log { this->timestamps = timestamps; } + + void set_callback(ggml_log_callback cb, void * user_data) { + pause(); + + this->callback = cb; + this->callback_user_data = user_data; + + resume(); + } }; // @@ -442,3 +468,7 @@ void common_log_set_prefix(struct common_log * log, bool prefix) { void common_log_set_timestamps(struct common_log * log, bool timestamps) { log->set_timestamps(timestamps); } + +void common_log_set_callback(struct common_log * log, ggml_log_callback callback, void * user_data) { + log->set_callback(callback, user_data); +} diff --git a/common/log.h b/common/log.h index f329b434c93..653f0662e68 100644 --- a/common/log.h +++ b/common/log.h @@ -76,6 +76,11 @@ void common_log_set_colors (struct common_log * log, log_colors colors); // n void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix +// set a custom callback to handle log messages instead of printing to stdout/stderr +// if callback is NULL, reverts to default printing behavior +// note: the callback will be called from the worker thread +void common_log_set_callback(struct common_log * log, ggml_log_callback callback, void * user_data); // not thread-safe + // helper macros for logging // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold //