Commit cf69db5

Merge pull request #594 from nkoppel/sampler_api
Add sampling API back to LlamaTokenDataArray; Add DRY and XTC Samplers
2 parents 3d29dbf + 67ea688 commit cf69db5
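
The headline change shows up in the example diffs below: sampler construction moves from a fallible builder (LlamaSampler::new(params) followed by add_* calls) to plain constructors composed with chain_simple. The DRY and XTC samplers named in the title do not appear in these hunks, so the sketch below is hedged: chain_simple, dist, and greedy are taken from this diff, while the xtc and dry constructor names and parameter lists are assumptions extrapolated from llama.cpp's XTC (probability, threshold, min-keep, seed) and DRY (multiplier, base, allowed length, penalty window, sequence breakers) parameters.

    use llama_cpp_2::sampling::LlamaSampler;

    // Hypothetical constructors: names and argument order for xtc/dry are
    // extrapolated from llama.cpp's parameters, not shown in this diff.
    let mut sampler = LlamaSampler::chain_simple([
        // XTC: with probability 0.5, drop every candidate above the 0.1
        // threshold except the least probable of them (p, t, min_keep, seed).
        LlamaSampler::xtc(0.5, 0.1, 1, 1234),
        // DRY: penalize verbatim repetition; would slot in here.
        // LlamaSampler::dry(&model, 0.8, 1.75, 2, 64, ["\n"]),
        LlamaSampler::dist(1234), // seeded final selection, as in this diff
    ]);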

File tree

8 files changed: +478, -220 lines

examples/simple/src/main.rs

Lines changed: 4 additions & 5 deletions

@@ -17,7 +17,6 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
 use llama_cpp_2::sampling::LlamaSampler;
 
 use std::ffi::CString;
@@ -246,10 +245,10 @@ either reduce n_len or increase n_ctx"
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();
 
-    let sampler_params = LlamaSamplerChainParams::default();
-    let mut sampler = LlamaSampler::new(sampler_params)?
-        .add_dist(seed.unwrap_or(1234))
-        .add_greedy();
+    let mut sampler = LlamaSampler::chain_simple([
+        LlamaSampler::dist(seed.unwrap_or(1234)),
+        LlamaSampler::greedy(),
+    ]);
 
     while n_cur <= n_len {
         // sample the next token
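
For orientation, a minimal sketch of the new chain inside the example's decode loop. chain_simple, dist, and greedy come from this hunk; the sample and accept calls are assumptions modeled on llama.cpp's sampler-chain flow, since the loop body is not part of the diff.

    let mut sampler = LlamaSampler::chain_simple([
        LlamaSampler::dist(seed.unwrap_or(1234)), // seeded sampling step
        LlamaSampler::greedy(),                   // final argmax selection
    ]);

    // Assumed per-iteration usage (not shown in this hunk):
    // let token = sampler.sample(&ctx, batch.n_tokens() - 1);
    // sampler.accept(token);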

examples/usage.rs

Lines changed: 1 addition & 7 deletions

@@ -14,9 +14,7 @@ use llama_cpp_2::llama_batch::LlamaBatch;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
 use llama_cpp_2::sampling::LlamaSampler;
-use llama_cpp_2::token::data_array::LlamaTokenDataArray;
 use std::io::Write;
 
 #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
@@ -55,11 +53,7 @@ fn main() {
 
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();
-
-    let sampler_params = LlamaSamplerChainParams::default();
-    let mut sampler = LlamaSampler::new(sampler_params)
-        .expect("Failed to create sampler")
-        .add_greedy();
+    let mut sampler = LlamaSampler::greedy();
 
     while n_cur <= n_len {
         // sample the next token
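
usage.rs shows the smallest form of the change: a single sampler needs no chain wrapper and no error handling. A sketch of how it feeds the loop; the sample call is an assumption mirroring the example's surrounding code, which this hunk does not include.

    // Construction is now infallible: no params struct, no expect().
    let mut sampler = LlamaSampler::greedy();

    // Assumed loop body, following the example's structure:
    // while n_cur <= n_len {
    //     let token = sampler.sample(&ctx, batch.n_tokens() - 1);
    //     ...
    // }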

llama-cpp-2/src/context.rs

Lines changed: 32 additions & 0 deletions

@@ -9,6 +9,7 @@ use crate::llama_batch::LlamaBatch;
 use crate::model::{LlamaLoraAdapter, LlamaModel};
 use crate::timing::LlamaTimings;
 use crate::token::data::LlamaTokenData;
+use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;
 use crate::{
     DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
@@ -202,6 +203,21 @@ impl<'model> LlamaContext<'model> {
         })
     }
 
+    /// Get the token data array for the last token in the context.
+    ///
+    /// This is a convenience method that implements:
+    /// ```ignore
+    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
+    /// ```
+    ///
+    /// # Panics
+    ///
+    /// - underlying logits data is null
+    #[must_use]
+    pub fn token_data_array(&self) -> LlamaTokenDataArray {
+        LlamaTokenDataArray::from_iter(self.candidates(), false)
+    }
+
     /// Token logits obtained from the last call to `decode()`.
     /// The logits for which `batch.logits[i] != 0` are stored contiguously
     /// in the order they have appeared in the batch.
@@ -217,6 +233,7 @@
     ///
     /// - `n_vocab` does not fit into a usize
     /// - token data returned is null
+    #[must_use]
     pub fn get_logits(&self) -> &[f32] {
         let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
         assert!(!data.is_null(), "logits data for last token is null");
@@ -237,6 +254,21 @@
         })
     }
 
+    /// Get the token data array for the ith token in the context.
+    ///
+    /// This is a convenience method that implements:
+    /// ```ignore
+    /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false)
+    /// ```
+    ///
+    /// # Panics
+    ///
+    /// - logit `i` is not initialized.
+    #[must_use]
+    pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
+        LlamaTokenDataArray::from_iter(self.candidates_ith(i), false)
+    }
+
     /// Get the logits for the ith token in the context.
     ///
     /// # Panics
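
Together with the PR title, these helpers connect LlamaContext to the sampling API restored on LlamaTokenDataArray. A hedged sketch: token_data_array_ith is confirmed by this diff, while the apply_sampler call on the array is an assumption based on the PR title and does not appear in these hunks.

    // Confirmed by this diff: candidate array for the logits at index i.
    let mut candidates = ctx.token_data_array_ith(batch.n_tokens() - 1);

    // Assumed method name (the restored sampling API on the array is
    // named in the PR title but not shown in these hunks):
    // candidates.apply_sampler(&mut sampler);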

llama-cpp-2/src/lib.rs

Lines changed: 0 additions & 8 deletions

@@ -195,14 +195,6 @@ pub enum LlamaLoraAdapterRemoveError {
     ErrorResult(i32),
 }
 
-/// An error that can occur when initializing a sampler.
-#[derive(Debug, Eq, PartialEq, thiserror::Error)]
-pub enum LlamaSamplerError {
-    /// llama.cpp returned null
-    #[error("null reference from llama.cpp")]
-    NullReturn,
-}
-
 /// get the time (in microseconds) according to llama.cpp
 /// ```
 /// # use llama_cpp_2::llama_time_us;
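
With LlamaSamplerError gone, sampler construction has no failure path, which is why the example diffs above drop the ? and expect() calls:

    // Before: let mut sampler = LlamaSampler::new(sampler_params)?.add_greedy();
    // After: plain, infallible constructors.
    let mut sampler = LlamaSampler::greedy();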
