Skip to content

Commit 338cc79

Browse files
authored
Merge pull request #92 from utilityai/update-llama-cpp-2024-02-21
updated llama.cpp (includes breaking backend init changes)
2 parents 68534b6 + 1c6130c commit 338cc79

File tree

2 files changed

+92
-5
lines changed

2 files changed

+92
-5
lines changed

llama-cpp-2/src/llama_backend.rs

Lines changed: 91 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,29 +43,86 @@ impl LlamaBackend {
4343
#[tracing::instrument(skip_all)]
4444
pub fn init() -> crate::Result<LlamaBackend> {
// `?` returns early with the error if `mark_init` fails, before any FFI call is made.
4545
Self::mark_init()?;
// Removed side of the diff: the old binding took a single argument, passed as `false` here.
46-
unsafe { llama_cpp_sys_2::llama_backend_init(false) }
// Added side of the diff: `llama_backend_init` is now called with no arguments.
46+
unsafe { llama_cpp_sys_2::llama_backend_init() }
4747
Ok(LlamaBackend {})
4848
}
4949

5050
/// Initialize the llama backend (with numa).
///
/// After `mark_init` succeeds, `strategy` is converted to
/// `llama_cpp_sys_2::ggml_numa_strategy` and handed to `llama_numa_init`.
///
/// NOTE(review): on the added side of this diff this path no longer calls
/// `llama_backend_init` at all (the removed side called `llama_backend_init(true)`) —
/// confirm that `llama_numa_init` alone is sufficient to initialise the backend.
///
5151
/// ```
5252
///# use llama_cpp_2::llama_backend::LlamaBackend;
5353
///# use std::error::Error;
54+
///# use llama_cpp_2::llama_backend::NumaStrategy;
5455
///
5556
///# fn main() -> Result<(), Box<dyn Error>> {
56-
/// let llama_backend = LlamaBackend::init_numa()?;
57+
///
58+
/// let llama_backend = LlamaBackend::init_numa(NumaStrategy::MIRROR)?;
5759
///
5860
///# Ok(())
5961
///# }
6062
/// ```
6163
#[tracing::instrument(skip_all)]
62-
pub fn init_numa() -> crate::Result<LlamaBackend> {
64+
pub fn init_numa(strategy: NumaStrategy) -> crate::Result<LlamaBackend> {
// `?` returns early with the error if `mark_init` fails, before any FFI call is made.
6365
Self::mark_init()?;
// Removed side of the diff: NUMA used to be requested by passing `true` to the backend init.
64-
unsafe { llama_cpp_sys_2::llama_backend_init(true) }
// Added side of the diff: the strategy is converted with `From<NumaStrategy>` and passed
// to the dedicated `llama_numa_init` entry point.
66+
unsafe {
67+
llama_cpp_sys_2::llama_numa_init(llama_cpp_sys_2::ggml_numa_strategy::from(strategy))
68+
}
6569
Ok(LlamaBackend {})
6670
}
6771
}
6872

73+
/// A rusty wrapper around `llama_cpp_sys_2::ggml_numa_strategy`.
///
/// Each variant corresponds to one `GGML_NUMA_STRATEGY_*` constant; see the `From` and
/// `TryFrom` impls in this file for the exact mapping.
74+
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
75+
pub enum NumaStrategy {
76+
/// The numa strategy is disabled. Maps to `GGML_NUMA_STRATEGY_DISABLED`.
77+
DISABLED,
78+
/// Maps to `GGML_NUMA_STRATEGY_DISTRIBUTE`.
/// NOTE(review): presumably spreads threads across all NUMA nodes — confirm upstream.
79+
DISTRIBUTE,
80+
/// Maps to `GGML_NUMA_STRATEGY_ISOLATE`.
/// NOTE(review): presumably keeps threads on the node execution started on — confirm upstream.
81+
ISOLATE,
82+
/// Maps to `GGML_NUMA_STRATEGY_NUMACTL`.
/// NOTE(review): presumably uses the CPU map supplied by `numactl` — confirm upstream.
83+
NUMACTL,
84+
/// Maps to `GGML_NUMA_STRATEGY_MIRROR`.
/// Help wanted: the upstream semantics of this strategy are not documented here.
85+
MIRROR,
86+
/// Maps to `GGML_NUMA_STRATEGY_COUNT`.
/// NOTE(review): looks like the C enum's variant-count sentinel rather than a real
/// strategy — confirm before passing it to `init_numa`.
87+
COUNT,
88+
}
89+
90+
/// An invalid numa strategy was provided.
///
/// This is the error type of `NumaStrategy::try_from`, produced when the raw value matches
/// none of the `GGML_NUMA_STRATEGY_*` constants.
91+
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
92+
pub struct InvalidNumaStrategy(
93+
/// The invalid numa strategy that was provided (the raw value, unchanged).
94+
pub llama_cpp_sys_2::ggml_numa_strategy,
95+
);
96+
97+
/// Fallible conversion from the raw `ggml_numa_strategy` value to [`NumaStrategy`].
impl TryFrom<llama_cpp_sys_2::ggml_numa_strategy> for NumaStrategy {
98+
type Error = InvalidNumaStrategy;
99+
/// Maps each known `GGML_NUMA_STRATEGY_*` constant to its variant.
///
/// # Errors
///
/// Returns `InvalidNumaStrategy` wrapping the original value when it matches no constant.
100+
fn try_from(value: llama_cpp_sys_2::ggml_numa_strategy) -> Result<Self, Self::Error> {
101+
match value {
102+
llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED => Ok(Self::DISABLED),
103+
llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE => Ok(Self::DISTRIBUTE),
104+
llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE => Ok(Self::ISOLATE),
105+
llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL => Ok(Self::NUMACTL),
106+
llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR => Ok(Self::MIRROR),
107+
llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT => Ok(Self::COUNT),
// Catch-all arm: anything not listed above is handed back inside the error.
108+
value => Err(InvalidNumaStrategy(value)),
109+
}
110+
}
111+
}
112+
113+
/// Infallible conversion from [`NumaStrategy`] to the raw `ggml_numa_strategy` value.
impl From<NumaStrategy> for llama_cpp_sys_2::ggml_numa_strategy {
/// Returns the `GGML_NUMA_STRATEGY_*` constant that corresponds to `value`.
114+
fn from(value: NumaStrategy) -> Self {
// No wildcard arm, so adding a variant to `NumaStrategy` forces this match to be updated.
115+
match value {
116+
NumaStrategy::DISABLED => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED,
117+
NumaStrategy::DISTRIBUTE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE,
118+
NumaStrategy::ISOLATE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE,
119+
NumaStrategy::NUMACTL => llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL,
120+
NumaStrategy::MIRROR => llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR,
121+
NumaStrategy::COUNT => llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT,
122+
}
123+
}
124+
}
125+
69126
/// Drops the llama backend.
70127
/// ```
71128
///
@@ -92,3 +149,33 @@ impl Drop for LlamaBackend {
92149
unsafe { llama_cpp_sys_2::llama_backend_free() }
93150
}
94151
}
152+
153+
// Unit tests for the `NumaStrategy` <-> `ggml_numa_strategy` conversions.
#[cfg(test)]
154+
mod tests {
155+
use super::*;
156+
// Every variant must survive a round trip through the raw value and back.
157+
#[test]
158+
fn numa_from_and_to() {
159+
let numas = [
160+
NumaStrategy::DISABLED,
161+
NumaStrategy::DISTRIBUTE,
162+
NumaStrategy::ISOLATE,
163+
NumaStrategy::NUMACTL,
164+
NumaStrategy::MIRROR,
165+
NumaStrategy::COUNT,
166+
];
167+
168+
for numa in &numas {
169+
let from = llama_cpp_sys_2::ggml_numa_strategy::from(*numa);
170+
let to = NumaStrategy::try_from(from).expect("Failed to convert from and to");
171+
assert_eq!(*numa, to);
172+
}
173+
}
174+
// A raw value outside the known constants must be rejected.
175+
#[test]
176+
fn check_invalid_numa() {
// 800 is assumed not to match any `GGML_NUMA_STRATEGY_*` constant.
177+
let invalid = 800;
178+
let invalid = NumaStrategy::try_from(invalid);
// NOTE(review): the expected value is rebuilt from the actual result, so this only
// proves the result is `Err` (`unwrap_err` panics otherwise); comparing against
// `Err(InvalidNumaStrategy(800))` would also pin the carried value.
179+
assert_eq!(invalid, Err(InvalidNumaStrategy(invalid.unwrap_err().0)));
180+
}
181+
}

llama-cpp-sys-2/llama.cpp

0 commit comments

Comments
 (0)