@@ -43,29 +43,86 @@ impl LlamaBackend {
4343 #[ tracing:: instrument( skip_all) ]
4444 pub fn init ( ) -> crate :: Result < LlamaBackend > {
4545 Self :: mark_init ( ) ?;
46- unsafe { llama_cpp_sys_2:: llama_backend_init ( false ) }
46+ unsafe { llama_cpp_sys_2:: llama_backend_init ( ) }
4747 Ok ( LlamaBackend { } )
4848 }
4949
5050 /// Initialize the llama backend (with numa).
5151 /// ```
5252 ///# use llama_cpp_2::llama_backend::LlamaBackend;
5353 ///# use std::error::Error;
54+ ///# use llama_cpp_2::llama_backend::NumaStrategy;
5455 ///
5556 ///# fn main() -> Result<(), Box<dyn Error>> {
56- /// let llama_backend = LlamaBackend::init_numa()?;
57+ ///
58+ /// let llama_backend = LlamaBackend::init_numa(NumaStrategy::MIRROR)?;
5759 ///
5860 ///# Ok(())
5961 ///# }
6062 /// ```
6163 #[ tracing:: instrument( skip_all) ]
62- pub fn init_numa ( ) -> crate :: Result < LlamaBackend > {
64+ pub fn init_numa ( strategy : NumaStrategy ) -> crate :: Result < LlamaBackend > {
6365 Self :: mark_init ( ) ?;
64- unsafe { llama_cpp_sys_2:: llama_backend_init ( true ) }
66+ unsafe {
67+ llama_cpp_sys_2:: llama_numa_init ( llama_cpp_sys_2:: ggml_numa_strategy:: from ( strategy) )
68+ }
6569 Ok ( LlamaBackend { } )
6670 }
6771}
6872
73+ /// A rusty wrapper around `numa_strategy`.
74+ #[ derive( Debug , Eq , PartialEq , Copy , Clone ) ]
75+ pub enum NumaStrategy {
76+ /// The numa strategy is disabled.
77+ DISABLED ,
78+ /// help wanted: what does this do?
79+ DISTRIBUTE ,
80+ /// help wanted: what does this do?
81+ ISOLATE ,
82+ /// help wanted: what does this do?
83+ NUMACTL ,
84+ /// help wanted: what does this do?
85+ MIRROR ,
86+ /// help wanted: what does this do?
87+ COUNT ,
88+ }
89+
90+ /// An invalid numa strategy was provided.
91+ #[ derive( Debug , Eq , PartialEq , Copy , Clone ) ]
92+ pub struct InvalidNumaStrategy (
93+ /// The invalid numa strategy that was provided.
94+ pub llama_cpp_sys_2:: ggml_numa_strategy ,
95+ ) ;
96+
97+ impl TryFrom < llama_cpp_sys_2:: ggml_numa_strategy > for NumaStrategy {
98+ type Error = InvalidNumaStrategy ;
99+
100+ fn try_from ( value : llama_cpp_sys_2:: ggml_numa_strategy ) -> Result < Self , Self :: Error > {
101+ match value {
102+ llama_cpp_sys_2:: GGML_NUMA_STRATEGY_DISABLED => Ok ( Self :: DISABLED ) ,
103+ llama_cpp_sys_2:: GGML_NUMA_STRATEGY_DISTRIBUTE => Ok ( Self :: DISTRIBUTE ) ,
104+ llama_cpp_sys_2:: GGML_NUMA_STRATEGY_ISOLATE => Ok ( Self :: ISOLATE ) ,
105+ llama_cpp_sys_2:: GGML_NUMA_STRATEGY_NUMACTL => Ok ( Self :: NUMACTL ) ,
106+ llama_cpp_sys_2:: GGML_NUMA_STRATEGY_MIRROR => Ok ( Self :: MIRROR ) ,
107+ llama_cpp_sys_2:: GGML_NUMA_STRATEGY_COUNT => Ok ( Self :: COUNT ) ,
108+ value => Err ( InvalidNumaStrategy ( value) ) ,
109+ }
110+ }
111+ }
112+
113+ impl From < NumaStrategy > for llama_cpp_sys_2:: ggml_numa_strategy {
114+ fn from ( value : NumaStrategy ) -> Self {
115+ match value {
116+ NumaStrategy :: DISABLED => llama_cpp_sys_2:: GGML_NUMA_STRATEGY_DISABLED ,
117+ NumaStrategy :: DISTRIBUTE => llama_cpp_sys_2:: GGML_NUMA_STRATEGY_DISTRIBUTE ,
118+ NumaStrategy :: ISOLATE => llama_cpp_sys_2:: GGML_NUMA_STRATEGY_ISOLATE ,
119+ NumaStrategy :: NUMACTL => llama_cpp_sys_2:: GGML_NUMA_STRATEGY_NUMACTL ,
120+ NumaStrategy :: MIRROR => llama_cpp_sys_2:: GGML_NUMA_STRATEGY_MIRROR ,
121+ NumaStrategy :: COUNT => llama_cpp_sys_2:: GGML_NUMA_STRATEGY_COUNT ,
122+ }
123+ }
124+ }
125+
69126/// Drops the llama backend.
70127/// ```
71128///
@@ -92,3 +149,33 @@ impl Drop for LlamaBackend {
92149 unsafe { llama_cpp_sys_2:: llama_backend_free ( ) }
93150 }
94151}
152+
153+ #[ cfg( test) ]
154+ mod tests {
155+ use super :: * ;
156+
157+ #[ test]
158+ fn numa_from_and_to ( ) {
159+ let numas = [
160+ NumaStrategy :: DISABLED ,
161+ NumaStrategy :: DISTRIBUTE ,
162+ NumaStrategy :: ISOLATE ,
163+ NumaStrategy :: NUMACTL ,
164+ NumaStrategy :: MIRROR ,
165+ NumaStrategy :: COUNT ,
166+ ] ;
167+
168+ for numa in & numas {
169+ let from = llama_cpp_sys_2:: ggml_numa_strategy:: from ( * numa) ;
170+ let to = NumaStrategy :: try_from ( from) . expect ( "Failed to convert from and to" ) ;
171+ assert_eq ! ( * numa, to) ;
172+ }
173+ }
174+
175+ #[ test]
176+ fn check_invalid_numa ( ) {
177+ let invalid = 800 ;
178+ let invalid = NumaStrategy :: try_from ( invalid) ;
179+ assert_eq ! ( invalid, Err ( InvalidNumaStrategy ( invalid. unwrap_err( ) . 0 ) ) ) ;
180+ }
181+ }
0 commit comments