
Commit 1f41e6e

Implement new sampler API
This commit implements the new sampler API from `llama.cpp`, introduced in b3680, and removes the custom sampling logic. The new sampling API is exposed through a builder pattern. Tests have been updated to pass.
1 parent 5c27009 commit 1f41e6e
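
For orientation, the builder-style usage this commit introduces looks roughly like the sketch below. It is assembled from the example diffs in this commit, not a separate API reference: `ctx` and `batch` come from the surrounding example code and are elided here, and the seed value 1234 is just the examples' default.

    use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
    use llama_cpp_2::sampling::LlamaSampler;

    // build a sampler chain: start from default chain params, then add stages
    let sampler_params = LlamaSamplerChainParams::default();
    let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(1234); // seeded sampling
    // (examples/usage.rs chains `.add_greedy()` instead of `.add_dist(...)`)

    // inside the decode loop: sample from the last logits in the batch,
    // then feed the chosen token back into the sampler chain
    let token = sampler.sample(&ctx, batch.n_tokens() - 1);
    sampler.accept(token);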

6 files changed: +251 −141 lines changed


examples/simple/src/main.rs

Lines changed: 14 additions & 12 deletions
@@ -17,7 +17,9 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;
+
 use std::ffi::CString;
 use std::io::Write;
 use std::num::NonZeroU32;
@@ -174,9 +176,9 @@ fn main() -> Result<()> {
         .with_context(|| "unable to load model")?;

     // initialize the context
-    let mut ctx_params = LlamaContextParams::default()
-        .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
-        .with_seed(seed.unwrap_or(1234));
+    let mut ctx_params =
+        LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())));
+
     if let Some(threads) = threads {
         ctx_params = ctx_params.with_n_threads(threads);
     }
@@ -244,31 +246,31 @@ either reduce n_len or increase n_ctx"
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();

+    let sampler_params = LlamaSamplerChainParams::default();
+    let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234));
+
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates();
-
-            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+            let token = sampler.sample(&ctx, batch.n_tokens() - 1);

-            // sample the most likely token
-            let new_token_id = ctx.sample_token_greedy(candidates_p);
+            sampler.accept(token);

             // is it an end of stream?
-            if model.is_eog_token(new_token_id) {
+            if model.is_eog_token(token) {
                 eprintln!();
                 break;
             }

-            let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?;
+            let output_bytes = model.token_to_bytes(token, Special::Tokenize)?;
             // use `Decoder.decode_to_string()` to avoid the intermediate buffer
             let mut output_string = String::with_capacity(32);
             let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
             print!("{output_string}");
             std::io::stdout().flush()?;

             batch.clear();
-            batch.add(new_token_id, n_cur, &[0], true)?;
+            batch.add(token, n_cur, &[0], true)?;
         }

         n_cur += 1;

examples/usage.rs

Lines changed: 12 additions & 10 deletions
@@ -14,6 +14,8 @@ use llama_cpp_2::llama_batch::LlamaBatch;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;
 use llama_cpp_2::token::data_array::LlamaTokenDataArray;
 use std::io::Write;

@@ -54,33 +56,33 @@ fn main() {
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();

+    let sampler_params = LlamaSamplerChainParams::default();
+    let mut sampler = LlamaSampler::new(sampler_params)
+        .expect("Failed to create sampler")
+        .add_greedy();
+
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
-
-            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+            let token = sampler.sample(&ctx, batch.n_tokens() - 1);

-            // sample the most likely token
-            let new_token_id = ctx.sample_token_greedy(candidates_p);
+            sampler.accept(token);

             // is it an end of stream?
-            if new_token_id == model.token_eos() {
+            if token == model.token_eos() {
                 eprintln!();
                 break;
             }

-            let output_bytes = model
-                .token_to_bytes(new_token_id, Special::Tokenize)
-                .unwrap();
+            let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
             // use `Decoder.decode_to_string()` to avoid the intermediate buffer
             let mut output_string = String::with_capacity(32);
             let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
             print!("{output_string}");
             std::io::stdout().flush().unwrap();

             batch.clear();
-            batch.add(new_token_id, n_cur, &[0], true).unwrap();
+            batch.add(token, n_cur, &[0], true).unwrap();
         }

         n_cur += 1;

llama-cpp-2/src/context/params.rs

Lines changed: 2 additions & 4 deletions
@@ -47,7 +47,7 @@ impl From<RopeScalingType> for i32 {
 pub enum LlamaPoolingType {
     /// The pooling type is unspecified
     Unspecified = -1,
-    /// No pooling
+    /// No pooling
     None = 0,
     /// Mean pooling
     Mean = 1,
@@ -95,10 +95,8 @@ impl From<LlamaPoolingType> for i32 {
 /// use llama_cpp_2::context::params::LlamaContextParams;
 ///
 /// let ctx_params = LlamaContextParams::default()
-///     .with_n_ctx(NonZeroU32::new(2048))
-///     .with_seed(1234);
+///     .with_n_ctx(NonZeroU32::new(2048));
 ///
-/// assert_eq!(ctx_params.seed(), 1234);
 /// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
 /// ```
 #[derive(Debug, Clone)]
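
Note that `with_seed` / `seed()` disappear from the `LlamaContextParams` doc example here; in the updated examples the seed is handed to the sampler chain's dist stage instead. A minimal sketch of that split, using only calls that appear in the diffs above (1234 is the examples' default seed):

    // context params no longer carry a seed
    let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));

    // the seed goes to the sampler chain instead
    let sampler = LlamaSampler::new(LlamaSamplerChainParams::default())?.add_dist(1234);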

llama-cpp-2/src/context/sample/sampler.rs

Lines changed: 0 additions & 112 deletions
This file was deleted.
