
Commit 057da42

Merge branch 'main' into 8-metal-on-mac
2 parents 62ac6b5 + 2e05e66

File tree: 10 files changed, +155 −44 lines

.github/workflows/llama-cpp-rs-check.yml

Lines changed: 13 additions & 1 deletion
@@ -67,4 +67,16 @@ jobs:
       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
       - name: Build
-        run: cargo build
+        run: cargo build
+  windows:
+    name: Check that it builds on windows
+    runs-on: windows-latest
+    steps:
+      - name: checkout
+        uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11
+        with:
+          submodules: recursive
+      - name: Setup Rust
+        uses: dtolnay/rust-toolchain@stable
+      - name: Build
+        run: cargo build

Cargo.lock

Lines changed: 8 additions & 8 deletions
(generated file; diff not rendered)

llama-cpp-2/Cargo.toml

Lines changed: 4 additions & 4 deletions
@@ -1,15 +1,15 @@
 [package]
 name = "llama-cpp-2"
 description = "llama.cpp bindings for Rust"
-version = "0.1.25"
+version = "0.1.28"
 edition = "2021"
 license = "MIT OR Apache-2.0"
 repository = "https://github.com/utilityai/llama-cpp-rs"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.25" }
+llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.28" }
 thiserror = { workspace = true }
 tracing = { workspace = true }
 
@@ -19,8 +19,8 @@ criterion = { workspace = true }
 pprof = { workspace = true, features = ["criterion", "flamegraph"] }
 
 # used in examples
-clap = { version = "4.5.0", features = ["derive"] }
-anyhow = "1.0.79"
+clap = { version = "4.5.1", features = ["derive"] }
+anyhow = "1.0.80"
 
 [[bench]]
 name = "grammar_bias"

llama-cpp-2/src/grammar.rs

Lines changed: 12 additions & 12 deletions
@@ -269,7 +269,7 @@ impl ParseState {
             rest = r;
             rule.push(llama_grammar_element {
                 type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR,
-                value: c,
+                value: c as _,
             });
         }
         rest = Self::consume_whitespace_and_comments(&rest[1..], nested);
@@ -292,14 +292,14 @@ impl ParseState {
             };
             rule.push(llama_grammar_element {
                 type_: gre_type,
-                value: c,
+                value: c as _,
             });
             if rest.starts_with("-]") {
                 let (c, r) = Self::parse_char(rest)?;
                 rest = r;
                 rule.push(llama_grammar_element {
                     type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_RNG_UPPER,
-                    value: c,
+                    value: c as _,
                 });
             }
         }
@@ -386,7 +386,7 @@ impl ParseState {
             error,
         })?;
 
-        Ok((value, rest))
+        Ok((value as llama_gretype, rest))
     }
 
     fn parse_char(rest: &str) -> Result<(llama_gretype, &str), GrammarParseError> {
@@ -401,17 +401,17 @@ impl ParseState {
                 'x' => Self::parse_hex(rest, 2),
                 'u' => Self::parse_hex(rest, 4),
                 'U' => Self::parse_hex(rest, 8),
-                't' => Ok((u32::from('\t'), rest)),
-                'r' => Ok((u32::from('\r'), rest)),
-                'n' => Ok((u32::from('\n'), rest)),
-                '\\' => Ok((u32::from('\\'), rest)),
-                '"' => Ok((u32::from('"'), rest)),
-                '[' => Ok((u32::from('['), rest)),
-                ']' => Ok((u32::from(']'), rest)),
+                't' => Ok((u32::from('\t') as llama_gretype, rest)),
+                'r' => Ok((u32::from('\r') as llama_gretype, rest)),
+                'n' => Ok((u32::from('\n') as llama_gretype, rest)),
+                '\\' => Ok((u32::from('\\') as llama_gretype, rest)),
+                '"' => Ok((u32::from('"') as llama_gretype, rest)),
+                '[' => Ok((u32::from('[') as llama_gretype, rest)),
+                ']' => Ok((u32::from(']') as llama_gretype, rest)),
                 c => Err(GrammarParseError::UnknownEscape { escape: c }),
             }
         } else if let Some(c) = rest.chars().next() {
-            Ok((u32::from(c), &rest[c.len_utf8()..]))
+            Ok((u32::from(c) as llama_gretype, &rest[c.len_utf8()..]))
        } else {
             Err(GrammarParseError::UnexpectedEndOfInput {
                 parse_stage: "char",
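
Why the casts: the new Windows job builds these files with MSVC, where bindgen appears to give C enums such as `llama_gretype` a signed underlying type instead of the `u32` seen on gcc/clang targets. Writing `value: c as _` and `u32::from(c) as llama_gretype` lets the same source coerce to whichever integer type the generated bindings declare. A minimal, self-contained sketch of the pattern, with an illustrative stand-in alias rather than the real binding:

#[allow(non_camel_case_types)]
type llama_gretype = i32; // stand-in: signed, as MSVC targets appear to get

/// Turn a parsed character into a grammar-element value.
fn char_value(c: char) -> llama_gretype {
    // `as _` lets the compiler infer the cast target from the return type,
    // so this line compiles whether `llama_gretype` is `u32` or `i32`.
    u32::from(c) as _
}

fn main() {
    assert_eq!(char_value('\t'), 9);
}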

llama-cpp-2/src/llama_backend.rs

Lines changed: 91 additions & 4 deletions
@@ -43,29 +43,86 @@ impl LlamaBackend {
     #[tracing::instrument(skip_all)]
     pub fn init() -> crate::Result<LlamaBackend> {
         Self::mark_init()?;
-        unsafe { llama_cpp_sys_2::llama_backend_init(false) }
+        unsafe { llama_cpp_sys_2::llama_backend_init() }
         Ok(LlamaBackend {})
     }
 
     /// Initialize the llama backend (with numa).
     /// ```
     ///# use llama_cpp_2::llama_backend::LlamaBackend;
     ///# use std::error::Error;
+    ///# use llama_cpp_2::llama_backend::NumaStrategy;
     ///
     ///# fn main() -> Result<(), Box<dyn Error>> {
-    /// let llama_backend = LlamaBackend::init_numa()?;
+    ///
+    /// let llama_backend = LlamaBackend::init_numa(NumaStrategy::MIRROR)?;
     ///
     ///# Ok(())
     ///# }
     /// ```
     #[tracing::instrument(skip_all)]
-    pub fn init_numa() -> crate::Result<LlamaBackend> {
+    pub fn init_numa(strategy: NumaStrategy) -> crate::Result<LlamaBackend> {
         Self::mark_init()?;
-        unsafe { llama_cpp_sys_2::llama_backend_init(true) }
+        unsafe {
+            llama_cpp_sys_2::llama_numa_init(llama_cpp_sys_2::ggml_numa_strategy::from(strategy))
+        }
         Ok(LlamaBackend {})
     }
 }
 
+/// A rusty wrapper around `numa_strategy`.
+#[derive(Debug, Eq, PartialEq, Copy, Clone)]
+pub enum NumaStrategy {
+    /// The numa strategy is disabled.
+    DISABLED,
+    /// help wanted: what does this do?
+    DISTRIBUTE,
+    /// help wanted: what does this do?
+    ISOLATE,
+    /// help wanted: what does this do?
+    NUMACTL,
+    /// help wanted: what does this do?
+    MIRROR,
+    /// help wanted: what does this do?
+    COUNT,
+}
+
+/// An invalid numa strategy was provided.
+#[derive(Debug, Eq, PartialEq, Copy, Clone)]
+pub struct InvalidNumaStrategy(
+    /// The invalid numa strategy that was provided.
+    pub llama_cpp_sys_2::ggml_numa_strategy,
+);
+
+impl TryFrom<llama_cpp_sys_2::ggml_numa_strategy> for NumaStrategy {
+    type Error = InvalidNumaStrategy;
+
+    fn try_from(value: llama_cpp_sys_2::ggml_numa_strategy) -> Result<Self, Self::Error> {
+        match value {
+            llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED => Ok(Self::DISABLED),
+            llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE => Ok(Self::DISTRIBUTE),
+            llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE => Ok(Self::ISOLATE),
+            llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL => Ok(Self::NUMACTL),
+            llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR => Ok(Self::MIRROR),
+            llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT => Ok(Self::COUNT),
+            value => Err(InvalidNumaStrategy(value)),
+        }
+    }
+}
+
+impl From<NumaStrategy> for llama_cpp_sys_2::ggml_numa_strategy {
+    fn from(value: NumaStrategy) -> Self {
+        match value {
+            NumaStrategy::DISABLED => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED,
+            NumaStrategy::DISTRIBUTE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE,
+            NumaStrategy::ISOLATE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE,
+            NumaStrategy::NUMACTL => llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL,
+            NumaStrategy::MIRROR => llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR,
+            NumaStrategy::COUNT => llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT,
+        }
+    }
+}
+
 /// Drops the llama backend.
 /// ```
 ///
@@ -92,3 +149,33 @@ impl Drop for LlamaBackend {
         unsafe { llama_cpp_sys_2::llama_backend_free() }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn numa_from_and_to() {
+        let numas = [
+            NumaStrategy::DISABLED,
+            NumaStrategy::DISTRIBUTE,
+            NumaStrategy::ISOLATE,
+            NumaStrategy::NUMACTL,
+            NumaStrategy::MIRROR,
+            NumaStrategy::COUNT,
+        ];
+
+        for numa in &numas {
+            let from = llama_cpp_sys_2::ggml_numa_strategy::from(*numa);
+            let to = NumaStrategy::try_from(from).expect("Failed to convert from and to");
+            assert_eq!(*numa, to);
+        }
+    }
+
+    #[test]
+    fn check_invalid_numa() {
+        let invalid = 800;
+        let invalid = NumaStrategy::try_from(invalid);
+        assert_eq!(invalid, Err(InvalidNumaStrategy(invalid.unwrap_err().0)));
+    }
+}
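
Context: upstream llama.cpp replaced `llama_backend_init(bool numa)` with a zero-argument `llama_backend_init()` plus a separate `llama_numa_init(ggml_numa_strategy)`, and the wrapper mirrors that split by making `init_numa` take a `NumaStrategy`. A migration sketch for downstream callers, assuming nothing beyond the API in this diff (the `use_numa` flag and the choice of `DISTRIBUTE` are illustrative):

use llama_cpp_2::llama_backend::{LlamaBackend, NumaStrategy};

fn init_backend(use_numa: bool) -> Result<LlamaBackend, Box<dyn std::error::Error>> {
    // Pre-0.1.28 callers wrote `LlamaBackend::init_numa()?`; the strategy
    // is now explicit instead of an implied `numa = true`.
    let backend = if use_numa {
        LlamaBackend::init_numa(NumaStrategy::DISTRIBUTE)?
    } else {
        LlamaBackend::init()?
    };
    Ok(backend)
}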

llama-cpp-2/src/model.rs

Lines changed: 3 additions & 4 deletions
@@ -210,7 +210,7 @@ impl LlamaModel {
         }
 
         match self.token_type(token) {
-            LlamaTokenType::Normal => {}
+            LlamaTokenType::Normal | LlamaTokenType::UserDefined => {}
             LlamaTokenType::Control => {
                 if token == self.token_bos() || token == self.token_eos() {
                     return Ok(String::new());
@@ -219,7 +219,6 @@ impl LlamaModel {
             LlamaTokenType::Unknown
             | LlamaTokenType::Undefined
             | LlamaTokenType::Byte
-            | LlamaTokenType::UserDefined
             | LlamaTokenType::Unused => {
                 return Ok(String::new());
             }
@@ -332,9 +331,9 @@ impl Drop for LlamaModel {
 #[derive(Debug, Eq, Copy, Clone, PartialEq)]
 pub enum VocabType {
     /// Byte Pair Encoding
-    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE,
+    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
     /// Sentence Piece Tokenizer
-    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM,
+    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
 }
 
 /// There was an error converting a `llama_vocab_type` to a `VocabType`.
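
The user-visible change: tokens typed `UserDefined` (typically special tokens added to a GGUF vocabulary) now decode to their text instead of the empty string. A hedged caller-side sketch, assuming the match above lives in the crate's `token_to_str` (it is the method returning `Ok(String::new())`):

use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::token::LlamaToken;

/// Render one token, treating conversion failures as empty output.
fn render(model: &LlamaModel, token: LlamaToken) -> String {
    // Before this commit a UserDefined token hit the `Ok(String::new())`
    // branch and vanished; now it decodes like a Normal token.
    model.token_to_str(token).unwrap_or_default()
}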

llama-cpp-2/src/token_type.rs

Lines changed: 8 additions & 8 deletions
@@ -6,19 +6,19 @@
 #[allow(clippy::module_name_repetitions)]
 pub enum LlamaTokenType {
     /// An undefined token type.
-    Undefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNDEFINED,
+    Undefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNDEFINED as _,
     /// A normal token type.
-    Normal = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_NORMAL,
+    Normal = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_NORMAL as _,
     /// An unknown token type.
-    Unknown = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNKNOWN,
+    Unknown = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNKNOWN as _,
     /// A control token type.
-    Control = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_CONTROL,
+    Control = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_CONTROL as _,
     /// A user defined token type.
-    UserDefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_USER_DEFINED,
+    UserDefined = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_USER_DEFINED as _,
     /// An unused token type.
-    Unused = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNUSED,
+    Unused = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNUSED as _,
     /// A byte token type.
-    Byte = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_BYTE,
+    Byte = llama_cpp_sys_2::LLAMA_TOKEN_TYPE_BYTE as _,
 }
 
 /// A safe wrapper for converting potentially deceptive `llama_token_type` values into
@@ -52,7 +52,7 @@ impl TryFrom<llama_cpp_sys_2::llama_token_type> for LlamaTokenType {
             llama_cpp_sys_2::LLAMA_TOKEN_TYPE_USER_DEFINED => Ok(LlamaTokenType::UserDefined),
             llama_cpp_sys_2::LLAMA_TOKEN_TYPE_UNUSED => Ok(LlamaTokenType::Unused),
             llama_cpp_sys_2::LLAMA_TOKEN_TYPE_BYTE => Ok(LlamaTokenType::Byte),
-            _ => Err(LlamaTokenTypeFromIntError::UnknownValue(value)),
+            _ => Err(LlamaTokenTypeFromIntError::UnknownValue(value as _)),
         }
     }
 }
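
Call sites are unchanged: `try_from` still rejects values llama.cpp does not define rather than transmuting them into a bogus variant. A small usage sketch (the direct `llama_cpp_sys_2` dependency and the `Debug` formatting are assumptions):

use llama_cpp_2::token_type::LlamaTokenType;

fn classify(raw: llama_cpp_sys_2::llama_token_type) {
    match LlamaTokenType::try_from(raw) {
        Ok(LlamaTokenType::Control) => println!("control token"),
        Ok(other) => println!("token type: {other:?}"),
        Err(err) => eprintln!("unrecognised token type: {err:?}"),
    }
}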

llama-cpp-sys-2/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 [package]
 name = "llama-cpp-sys-2"
 description = "Low Level Bindings to llama.cpp"
-version = "0.1.25"
+version = "0.1.28"
 edition = "2021"
 license = "MIT OR Apache-2.0"
 repository = "https://github.com/utilityai/llama-cpp-rs"

llama-cpp-sys-2/build.rs

Lines changed: 14 additions & 1 deletion
@@ -24,6 +24,11 @@ fn main() {
         ] {
             println!("cargo:rustc-link-lib={}", lib);
         }
+        if !ggml_cuda.get_compiler().is_like_msvc() {
+            for lib in ["culibos", "pthread", "dl", "rt"] {
+                println!("cargo:rustc-link-lib={}", lib);
+            }
+        }
 
         println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
 
@@ -37,11 +42,19 @@ fn main() {
             .flag_if_supported("-mno-unaligned-access");
     }
 
-    ggml.cuda(true)
+
+    ggml
+        .cuda(true)
         .std("c++17")
         .flag("-arch=all")
         .file("llama.cpp/ggml-cuda.cu");
 
+    if ggml_cuda.get_compiler().is_like_msvc() {
+        ggml_cuda.std("c++14");
+    } else {
+        ggml_cuda.std("c++17");
+    }
+
     ggml.define("GGML_USE_CUBLAS", None);
     ggml.define("GGML_USE_CUBLAS", None);
     llama_cpp.define("GGML_USE_CUBLAS", None);
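
Both hunks branch on the detected host compiler: the CUDA support libraries (`culibos`, `pthread`, `dl`, `rt`) are only linkable on non-MSVC toolchains, and on Windows nvcc drives MSVC as the host compiler, which this commit caps at C++14 rather than C++17. A minimal build-script sketch of the detection pattern; `cc::Build::get_compiler` and `Tool::is_like_msvc` are the real cc-crate calls used above, everything else is illustrative:

// build.rs sketch: choose linker directives by compiler family.
fn main() {
    let ggml_cuda = cc::Build::new();
    if ggml_cuda.get_compiler().is_like_msvc() {
        // MSVC host: no POSIX support libraries to link, and the nvcc
        // host toolchain here tops out at C++14.
    } else {
        for lib in ["culibos", "pthread", "dl", "rt"] {
            println!("cargo:rustc-link-lib={lib}");
        }
    }
}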

llama-cpp-sys-2/llama.cpp (submodule)
