@@ -252,16 +252,140 @@ impl State {
252252 }
253253 }
254254
255+ pub ( super ) fn update_previous_level_for_disabled_log (
256+ & self ,
257+ level : llama_cpp_sys_2:: ggml_log_level ,
258+ ) {
259+ if level != llama_cpp_sys_2:: GGML_LOG_LEVEL_CONT {
260+ self . previous_level
261+ . store ( level as i32 , std:: sync:: atomic:: Ordering :: Release ) ;
262+ }
263+ }
264+
255265 /// Checks whether the given log level is enabled by the current tracing subscriber.
256266 pub ( super ) fn is_enabled_for_level ( & self , level : llama_cpp_sys_2:: ggml_log_level ) -> bool {
257267 // CONT logs do not need to check if they are enabled.
258- if level == llama_cpp_sys_2:: GGML_LOG_LEVEL_CONT {
259- return true ;
260- }
268+ let level = if level == llama_cpp_sys_2:: GGML_LOG_LEVEL_CONT {
269+ self . previous_level
270+ . load ( std:: sync:: atomic:: Ordering :: Relaxed )
271+ as llama_cpp_sys_2:: ggml_log_level
272+ } else {
273+ level
274+ } ;
261275 let ( meta, _) = meta_for_level ( level) ;
262276 tracing:: dispatcher:: get_default ( |dispatcher| dispatcher. enabled ( meta) )
263277 }
264278}
265279
266280pub ( super ) static LLAMA_STATE : OnceLock < Box < State > > = OnceLock :: new ( ) ;
267281pub ( super ) static GGML_STATE : OnceLock < Box < State > > = OnceLock :: new ( ) ;
282+
283+ #[ cfg( test) ]
284+ mod tests {
285+ use crate :: logs_to_trace;
286+ use std:: sync:: { Arc , Mutex } ;
287+ use tracing:: subscriber:: DefaultGuard ;
288+ use tracing_subscriber:: util:: SubscriberInitExt ;
289+
290+ use super :: * ;
291+
292+ struct Logger {
293+ #[ allow( unused) ]
294+ guard : DefaultGuard ,
295+ logs : Arc < Mutex < Vec < String > > > ,
296+ }
297+
298+ #[ derive( Clone ) ]
299+ struct VecWriter ( Arc < Mutex < Vec < String > > > ) ;
300+
301+ impl std:: io:: Write for VecWriter {
302+ fn write ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < usize > {
303+ let log_line = String :: from_utf8 ( buf. to_vec ( ) ) . map_err ( |_| {
304+ std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , "Invalid UTF-8" )
305+ } ) ?;
306+ self . 0 . lock ( ) . unwrap ( ) . push ( log_line) ;
307+ Ok ( buf. len ( ) )
308+ }
309+
310+ fn flush ( & mut self ) -> std:: io:: Result < ( ) > {
311+ Ok ( ( ) )
312+ }
313+ }
314+
315+ fn create_logger ( max_level : tracing:: Level ) -> Logger {
316+ let logs = Arc :: new ( Mutex :: new ( vec ! [ ] ) ) ;
317+ let writer = VecWriter ( logs. clone ( ) ) ;
318+
319+ Logger {
320+ guard : tracing_subscriber:: fmt ( )
321+ . with_max_level ( max_level)
322+ . with_ansi ( false )
323+ . without_time ( )
324+ . with_file ( false )
325+ . with_line_number ( false )
326+ . with_level ( false )
327+ . with_target ( false )
328+ . with_writer ( move || writer. clone ( ) )
329+ . finish ( )
330+ . set_default ( ) ,
331+ logs,
332+ }
333+ }
334+
335+ #[ test]
336+ fn cont_disabled_log ( ) {
337+ let logger = create_logger ( tracing:: Level :: INFO ) ;
338+ let mut log_state = Box :: new ( State :: new ( Module :: LlamaCpp , LogOptions :: default ( ) ) ) ;
339+ let log_ptr = log_state. as_mut ( ) as * mut State as * mut std:: os:: raw:: c_void ;
340+
341+ logs_to_trace (
342+ llama_cpp_sys_2:: GGML_LOG_LEVEL_DEBUG ,
343+ c"Hello " . as_ptr ( ) ,
344+ log_ptr,
345+ ) ;
346+ logs_to_trace (
347+ llama_cpp_sys_2:: GGML_LOG_LEVEL_CONT ,
348+ c"world\n " . as_ptr ( ) ,
349+ log_ptr,
350+ ) ;
351+
352+ assert ! ( logger. logs. lock( ) . unwrap( ) . is_empty( ) ) ;
353+
354+ logs_to_trace (
355+ llama_cpp_sys_2:: GGML_LOG_LEVEL_DEBUG ,
356+ c"Hello " . as_ptr ( ) ,
357+ log_ptr,
358+ ) ;
359+ logs_to_trace (
360+ llama_cpp_sys_2:: GGML_LOG_LEVEL_CONT ,
361+ c"world" . as_ptr ( ) ,
362+ log_ptr,
363+ ) ;
364+ logs_to_trace (
365+ llama_cpp_sys_2:: GGML_LOG_LEVEL_CONT ,
366+ c"\n " . as_ptr ( ) ,
367+ log_ptr,
368+ ) ;
369+ }
370+
371+ #[ test]
372+ fn cont_enabled_log ( ) {
373+ let logger = create_logger ( tracing:: Level :: INFO ) ;
374+ let mut log_state = Box :: new ( State :: new ( Module :: LlamaCpp , LogOptions :: default ( ) ) ) ;
375+ let log_ptr = log_state. as_mut ( ) as * mut State as * mut std:: os:: raw:: c_void ;
376+
377+ logs_to_trace (
378+ llama_cpp_sys_2:: GGML_LOG_LEVEL_INFO ,
379+ c"Hello " . as_ptr ( ) ,
380+ log_ptr,
381+ ) ;
382+ logs_to_trace (
383+ llama_cpp_sys_2:: GGML_LOG_LEVEL_CONT ,
384+ c"world\n " . as_ptr ( ) ,
385+ log_ptr,
386+ ) ;
387+
388+ // Not sure where the extra \n comes from.
389+ assert_eq ! ( * logger. logs. lock( ) . unwrap( ) , vec![ "Hello world\n \n " ] ) ;
390+ }
391+ }
0 commit comments