
Commit e78ddc1

Merge pull request #188 from utilityai/sampler

Sampler

2 parents f34c6dc + 411a679

11 files changed: +390 −213 lines changed

.github/workflows/llama-cpp-rs-check.yml
Lines changed: 3 additions & 3 deletions

@@ -34,7 +34,7 @@ jobs:
       - name: Fmt
         run: cargo fmt
       - name: Test
-        run: cargo test
+        run: cargo test --features sampler
   arm64:
     name: Check that it builds on various targets
     runs-on: ubuntu-latest
@@ -67,7 +67,7 @@ jobs:
       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
       - name: Build
-        run: cargo build
+        run: cargo build --features sampler
   windows:
     name: Check that it builds on windows
     runs-on: windows-latest
@@ -79,4 +79,4 @@ jobs:
       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
       - name: Build
-        run: cargo build
+        run: cargo build --features sampler

llama-cpp-2/Cargo.toml
Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ harness = false

 [features]
 cublas = ["llama-cpp-sys-2/cublas"]
+sampler = []

 [lints]
 workspace = true
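
The new `sampler` feature is an empty flag used purely for conditional compilation: the CI jobs above opt in with `--features sampler`, and `sample.rs` below gates the new module behind it. As a hedged, standalone illustration of that `#[cfg]` pattern (not crate code; it assumes a crate that declares `sampler = []` in its own `[features]` table):

    // Compiled only when the crate is built with `--features sampler`.
    #[cfg(feature = "sampler")]
    fn sampler_status() -> &'static str {
        "sampler feature enabled"
    }

    // Fallback so callers compile either way.
    #[cfg(not(feature = "sampler"))]
    fn sampler_status() -> &'static str {
        "sampler feature disabled"
    }

    fn main() {
        println!("{}", sampler_status());
    }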

llama-cpp-2/src/context.rs
Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ impl<'model> LlamaContext<'model> {
     ///
     /// # Panics
     ///
-    /// - the returned [`c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
+    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
     pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
         let result =
             unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) };

llama-cpp-2/src/context/kv_cache.rs
Lines changed: 9 additions & 9 deletions

@@ -22,8 +22,8 @@ impl LlamaContext<'_> {
     ///
     /// * `src` - The sequence id to copy the cache from.
     /// * `dest` - The sequence id to copy the cache to.
-    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
-    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
+    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
+    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
     pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
         let p0 = p0.map_or(-1, i32::from);
         let p1 = p1.map_or(-1, i32::from);
@@ -37,8 +37,8 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `src` - The sequence id to clear the cache for.
-    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
-    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
+    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
+    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
     pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
         let p0 = p0.map_or(-1, i32::from);
         let p1 = p1.map_or(-1, i32::from);
@@ -68,16 +68,16 @@ impl LlamaContext<'_> {
     }

     #[allow(clippy::doc_markdown)]
-    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
     /// If the KV cache is RoPEd, the KV data is updated accordingly:
     /// - lazily on next [`LlamaContext::decode`]
     /// - explicitly with [`Self::kv_cache_update`]
     ///
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
-    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
-    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
+    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
+    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
     /// * `delta` - The relative position to add to the tokens
     pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
         let p0 = p0.map_or(-1, i32::from);
@@ -95,8 +95,8 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
-    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
-    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
+    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
+    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
     /// * `d` - The factor to divide the positions by
     pub fn kv_cache_seq_div(
         &mut self,
llama-cpp-2/src/context/sample.rs
Lines changed: 36 additions & 189 deletions

@@ -5,130 +5,10 @@ use crate::grammar::LlamaGrammar;
 use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;

-/// struct to hold params for sampling
-#[derive(Debug)]
-#[deprecated(
-    since = "0.1.32",
-    note = "this does not scale well with many params and does not allow for changing of orders."
-)]
-pub struct Sampler<'grammar> {
-    token_data_array: LlamaTokenDataArray,
-    grammar: Option<&'grammar mut LlamaGrammar>,
-    temperature: Option<f32>,
-}
-
-impl<'grammar> Sampler<'grammar> {
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    fn sample(self, llama_context: &mut LlamaContext) -> LlamaToken {
-        match self {
-            Sampler {
-                token_data_array,
-                grammar: None,
-                temperature: None,
-            } => llama_context.sample_token_greedy(token_data_array),
-            Sampler {
-                mut token_data_array,
-                grammar: Some(grammar),
-                temperature: None,
-            } => {
-                llama_context.sample_grammar(&mut token_data_array, grammar);
-                let token = llama_context.sample_token_greedy(token_data_array);
-                llama_context.grammar_accept_token(grammar, token);
-                token
-            }
-            Sampler {
-                mut token_data_array,
-                grammar: None,
-                temperature: Some(temp),
-            } => {
-                llama_context.sample_temp(&mut token_data_array, temp);
-                llama_context.sample_token_softmax(&mut token_data_array);
-                token_data_array.data[0].id()
-            }
-            Sampler {
-                mut token_data_array,
-                grammar: Some(grammar),
-                temperature: Some(temperature),
-            } => {
-                llama_context.sample_grammar(&mut token_data_array, grammar);
-                llama_context.sample_temp(&mut token_data_array, temperature);
-                llama_context.sample_token_softmax(&mut token_data_array);
-                let token = llama_context.sample_token_greedy(token_data_array);
-                llama_context.grammar_accept_token(grammar, token);
-                token
-            }
-        }
-    }
-
-    /// Create a new sampler.
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn new(llama_token_data_array: LlamaTokenDataArray) -> Self {
-        Self {
-            token_data_array: llama_token_data_array,
-            grammar: None,
-            temperature: None,
-        }
-    }
-
-    /// Set the grammar for sampling.
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn with_grammar(mut self, grammar: &'grammar mut LlamaGrammar) -> Self {
-        self.grammar = Some(grammar);
-        self
-    }
-
-    /// Set the temperature for sampling.
-    ///
-    /// ```
-    /// # use llama_cpp_2::context::LlamaContext;
-    /// # use llama_cpp_2::context::sample::Sampler;
-    /// # use llama_cpp_2::grammar::LlamaGrammar;
-    /// # use llama_cpp_2::token::data::LlamaTokenData;
-    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
-    /// # use llama_cpp_2::token::LlamaToken;
-    ///
-    /// let _sampler = Sampler::new(LlamaTokenDataArray::new(vec![LlamaTokenData::new(LlamaToken(0), 0.0, 0.0)], false))
-    ///     .with_temperature(0.5);
-    /// ```
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn with_temperature(mut self, temperature: f32) -> Self {
-        if temperature == 0.0 {
-            return self;
-        }
-        self.temperature = Some(temperature);
-        self
-    }
-}
+#[cfg(feature = "sampler")]
+pub mod sampler;

 impl LlamaContext<'_> {
-    /// Sample a token.
-    ///
-    /// # Panics
-    ///
-    /// - sampler contains no tokens
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn sample(&mut self, sampler: Sampler) -> LlamaToken {
-        sampler.sample(self)
-    }
-
     /// Accept a token into the grammar.
     pub fn grammar_accept_token(&mut self, grammar: &mut LlamaGrammar, token: LlamaToken) {
         unsafe {
@@ -157,38 +37,20 @@ impl LlamaContext<'_> {
         }
     }

-    /// Modify [`token_data`] in place using temperature sampling.
-    ///
-    /// # Panics
-    ///
-    /// - [`temperature`] is not between 0.0 and 1.0
-    pub fn sample_temp(&self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
-        assert!(
-            temperature >= 0.0,
-            "temperature must be positive (was {temperature})"
-        );
-        assert!(
-            temperature <= 1.0,
-            "temperature must be less than or equal to 1.0 (was {temperature})"
-        );
-        if temperature == 0.0 {
-            return;
-        }
-        let ctx: *mut llama_cpp_sys_2::llama_context = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_temp`]
+    pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
+        token_data.sample_temp(Some(self), temperature);
     }

-    /// Sample a token greedily.
+    /// Sample a token greedily. Note that this *does not* take into account anything that has modified the probabilities - it only looks at logits.
+    ///
+    /// Most of the time [`LlamaTokenDataArray::sample_softmax`] or [`LlamaTokenDataArray::sample_token`] should be used instead.
     ///
     /// # Panics
     ///
-    /// - [`token_data`] is empty
+    /// - if `token_data` is empty
     #[must_use]
-    pub fn sample_token_greedy(&self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
+    pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
         assert!(!token_data.data.is_empty(), "no tokens");
         let mut data_arr = llama_cpp_sys_2::llama_token_data_array {
             data: token_data
@@ -207,39 +69,34 @@ impl LlamaContext<'_> {
         LlamaToken(token)
     }

-    /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/).
-    pub fn sample_tail_free(&self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_tail_free`]
+    pub fn sample_tail_free(
+        &mut self,
+        token_data: &mut LlamaTokenDataArray,
+        z: f32,
+        min_keep: usize,
+    ) {
+        token_data.sample_tail_free(Some(self), z, min_keep);
     }

-    /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666).
-    pub fn sample_typical(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_typical`]
+    pub fn sample_typical(
+        &mut self,
+        token_data: &mut LlamaTokenDataArray,
+        p: f32,
+        min_keep: usize,
+    ) {
+        token_data.sample_typical(Some(self), p, min_keep);
     }

-    /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)"
-    pub fn sample_top_p(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_top_p`]
+    pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
+        token_data.sample_top_p(Some(self), p, min_keep);
     }

     /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841)
     pub fn sample_min_p(
-        &self,
+        &mut self,
         llama_token_data: &mut LlamaTokenDataArray,
         p: f32,
         min_keep: usize,
@@ -252,24 +109,14 @@ impl LlamaContext<'_> {
         }
     }

-    /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
-    pub fn sample_top_k(&self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_top_k`]
+    pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
+        token_data.sample_top_k(Some(self), k, min_keep);
     }

-    /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    pub fn sample_token_softmax(&self, token_data: &mut LlamaTokenDataArray) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_softmax`]
+    pub fn sample_token_softmax(&mut self, token_data: &mut LlamaTokenDataArray) {
+        token_data.sample_softmax(Some(self));
     }

     /// See [`LlamaTokenDataArray::sample_repetition_penalty`]
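
With the deprecated `Sampler` builder gone, the `LlamaContext` methods above are thin wrappers over `LlamaTokenDataArray`, so a sampling chain is written as a sequence of in-place modifications followed by picking a token. A hedged sketch using only calls visible in this diff; the cutoff values are arbitrary examples, and a real program would build `token_data` from the logits of the last decode:

    use llama_cpp_2::context::LlamaContext;
    use llama_cpp_2::token::data_array::LlamaTokenDataArray;
    use llama_cpp_2::token::LlamaToken;

    fn pick_token(ctx: &mut LlamaContext, mut token_data: LlamaTokenDataArray) -> LlamaToken {
        ctx.sample_top_k(&mut token_data, 40, 1);   // keep the 40 strongest candidates
        ctx.sample_top_p(&mut token_data, 0.95, 1); // nucleus (top-p) filtering
        ctx.sample_temp(&mut token_data, 0.8);      // temperature scaling
        ctx.sample_token_softmax(&mut token_data);  // sort candidates and fill in probabilities
        token_data.data[0].id()                     // highest-probability candidate, as the old Sampler did
    }

The same chain can equivalently be written on the array itself (for example `token_data.sample_top_k(Some(ctx), 40, 1)`), since the context methods now simply forward `Some(self)`.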
