
Commit 42aaeeb

Merge pull request #580 from nobodywho-ooo/bump-llama-cpp
Implement new sampler API and bump llama.cpp
2 parents: 77af620 + 1f41e6e

13 files changed: +343 / -1226 lines


examples/simple/src/main.rs

Lines changed: 14 additions & 12 deletions
@@ -17,7 +17,9 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;
+
 use std::ffi::CString;
 use std::io::Write;
 use std::num::NonZeroU32;
@@ -174,9 +176,9 @@ fn main() -> Result<()> {
         .with_context(|| "unable to load model")?;

     // initialize the context
-    let mut ctx_params = LlamaContextParams::default()
-        .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
-        .with_seed(seed.unwrap_or(1234));
+    let mut ctx_params =
+        LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())));
+
     if let Some(threads) = threads {
         ctx_params = ctx_params.with_n_threads(threads);
     }
@@ -244,31 +246,31 @@ either reduce n_len or increase n_ctx"
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();

+    let sampler_params = LlamaSamplerChainParams::default();
+    let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234));
+
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates();
-
-            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+            let token = sampler.sample(&ctx, batch.n_tokens() - 1);

-            // sample the most likely token
-            let new_token_id = ctx.sample_token_greedy(candidates_p);
+            sampler.accept(token);

             // is it an end of stream?
-            if model.is_eog_token(new_token_id) {
+            if model.is_eog_token(token) {
                 eprintln!();
                 break;
             }

-            let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?;
+            let output_bytes = model.token_to_bytes(token, Special::Tokenize)?;
             // use `Decoder.decode_to_string()` to avoid the intermediate buffer
             let mut output_string = String::with_capacity(32);
             let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
             print!("{output_string}");
             std::io::stdout().flush()?;

             batch.clear();
-            batch.add(new_token_id, n_cur, &[0], true)?;
+            batch.add(token, n_cur, &[0], true)?;
         }

         n_cur += 1;
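Taken together, the change to the simple example amounts to: build a sampler chain once before the decode loop, then call `sample`/`accept` per token. A minimal sketch of that pattern as a standalone helper (the `generate` function and the `anyhow` error handling are illustrative assumptions, not part of the crate; the prompt/decode setup is assumed to match the example above):

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{LlamaModel, Special};
use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;

/// Hypothetical helper: run the new sampler-chain decode loop.
/// Assumes `batch` already holds the decoded prompt and `n_cur` points one past it.
fn generate(
    ctx: &mut LlamaContext<'_>,
    model: &LlamaModel,
    batch: &mut LlamaBatch,
    mut n_cur: i32,
    n_len: i32,
    seed: Option<u32>,
) -> anyhow::Result<String> {
    // Build the sampler chain once; the RNG seed now lives here, not on the context params.
    let sampler_params = LlamaSamplerChainParams::default();
    let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234));

    let mut out = String::new();
    while n_cur <= n_len {
        // Sample from the logits of the last position in the batch...
        let token = sampler.sample(ctx, batch.n_tokens() - 1);
        // ...and feed the choice back so stateful samplers can track it.
        sampler.accept(token);

        if model.is_eog_token(token) {
            break;
        }

        // The real example streams bytes through an encoding_rs decoder; this is simplified.
        let bytes = model.token_to_bytes(token, Special::Tokenize)?;
        out.push_str(&String::from_utf8_lossy(&bytes));

        batch.clear();
        batch.add(token, n_cur, &[0], true)?;
        n_cur += 1;
        ctx.decode(batch)?;
    }
    Ok(out)
}
```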

examples/usage.rs

Lines changed: 12 additions & 10 deletions
@@ -14,6 +14,8 @@ use llama_cpp_2::llama_batch::LlamaBatch;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;
 use llama_cpp_2::token::data_array::LlamaTokenDataArray;
 use std::io::Write;

@@ -54,33 +56,33 @@ fn main() {
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();

+    let sampler_params = LlamaSamplerChainParams::default();
+    let mut sampler = LlamaSampler::new(sampler_params)
+        .expect("Failed to create sampler")
+        .add_greedy();
+
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
-
-            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+            let token = sampler.sample(&ctx, batch.n_tokens() - 1);

-            // sample the most likely token
-            let new_token_id = ctx.sample_token_greedy(candidates_p);
+            sampler.accept(token);

             // is it an end of stream?
-            if new_token_id == model.token_eos() {
+            if token == model.token_eos() {
                 eprintln!();
                 break;
             }

-            let output_bytes = model
-                .token_to_bytes(new_token_id, Special::Tokenize)
-                .unwrap();
+            let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
             // use `Decoder.decode_to_string()` to avoid the intermediate buffer
             let mut output_string = String::with_capacity(32);
             let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
             print!("{output_string}");
             std::io::stdout().flush().unwrap();

             batch.clear();
-            batch.add(new_token_id, n_cur, &[0], true).unwrap();
+            batch.add(token, n_cur, &[0], true).unwrap();
         }

         n_cur += 1;
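The usage example makes the same switch, but with a greedy chain. As a small hedged sketch of how the two variants seen in these examples could be chosen in one place (the `build_sampler` helper is hypothetical, not part of the crate):

```rust
use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;

/// Hypothetical helper: pick a sampler chain up front.
/// With a seed it matches the simple example's stochastic chain;
/// without one it matches the usage example's greedy chain.
fn build_sampler(seed: Option<u32>) -> LlamaSampler {
    let params = LlamaSamplerChainParams::default();
    let chain = LlamaSampler::new(params).expect("Failed to create sampler chain");
    match seed {
        Some(seed) => chain.add_dist(seed), // seeded sampling from the full distribution
        None => chain.add_greedy(),         // deterministic arg-max sampling
    }
}
```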

llama-cpp-2/src/context.rs

Lines changed: 2 additions & 3 deletions
@@ -17,7 +17,6 @@ use crate::{

 pub mod kv_cache;
 pub mod params;
-pub mod sample;
 pub mod session;

 /// Safe wrapper around `llama_context`.
@@ -267,12 +266,12 @@ impl<'model> LlamaContext<'model> {

     /// Reset the timings for the context.
     pub fn reset_timings(&mut self) {
-        unsafe { llama_cpp_sys_2::llama_reset_timings(self.context.as_ptr()) }
+        unsafe { llama_cpp_sys_2::llama_perf_context_reset(self.context.as_ptr()) }
     }

     /// Returns the timings for the context.
     pub fn timings(&mut self) -> LlamaTimings {
-        let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) };
+        let timings = unsafe { llama_cpp_sys_2::llama_perf_context(self.context.as_ptr()) };
         LlamaTimings { timings }
     }
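The safe wrapper keeps its signatures; only the FFI symbols it calls change with the llama.cpp bump. A minimal caller-side sketch (the `snapshot_and_reset` helper is illustrative, not part of the crate):

```rust
use llama_cpp_2::context::LlamaContext;

/// Hypothetical helper: caller-facing timing methods are unchanged by this commit.
fn snapshot_and_reset(ctx: &mut LlamaContext<'_>) {
    // Take a snapshot of the timings, then reset the counters.
    let _timings = ctx.timings(); // now backed by llama_perf_context()
    ctx.reset_timings();          // now backed by llama_perf_context_reset()
}
```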

llama-cpp-2/src/context/params.rs

Lines changed: 2 additions & 35 deletions
@@ -47,7 +47,7 @@ impl From<RopeScalingType> for i32 {
 pub enum LlamaPoolingType {
     /// The pooling type is unspecified
     Unspecified = -1,
-    /// No pooling
+    /// No pooling
     None = 0,
     /// Mean pooling
     Mean = 1,
@@ -95,10 +95,8 @@ impl From<LlamaPoolingType> for i32 {
 /// use llama_cpp_2::context::params::LlamaContextParams;
 ///
 ///let ctx_params = LlamaContextParams::default()
-///     .with_n_ctx(NonZeroU32::new(2048))
-///     .with_seed(1234);
+///     .with_n_ctx(NonZeroU32::new(2048));
 ///
-/// assert_eq!(ctx_params.seed(), 1234);
 /// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
 /// ```
 #[derive(Debug, Clone)]
@@ -116,37 +114,6 @@ unsafe impl Send for LlamaContextParams {}
 unsafe impl Sync for LlamaContextParams {}

 impl LlamaContextParams {
-    /// Set the seed of the context
-    ///
-    /// # Examples
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::params::LlamaContextParams;
-    /// let params = LlamaContextParams::default();
-    /// let params = params.with_seed(1234);
-    /// assert_eq!(params.seed(), 1234);
-    /// ```
-    #[must_use]
-    pub fn with_seed(mut self, seed: u32) -> Self {
-        self.context_params.seed = seed;
-        self
-    }
-
-    /// Get the seed of the context
-    ///
-    /// # Examples
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::params::LlamaContextParams;
-    /// let params = LlamaContextParams::default()
-    ///     .with_seed(1234);
-    /// assert_eq!(params.seed(), 1234);
-    /// ```
-    #[must_use]
-    pub fn seed(&self) -> u32 {
-        self.context_params.seed
-    }
-
     /// Set the side of the context
     ///
     /// # Examples
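In practice this means callers that previously seeded the RNG through `LlamaContextParams` now pass the seed to the sampler chain. A hedged before/after sketch of that migration, using only calls that appear elsewhere in this commit (`configure` and the `anyhow` result type are illustrative assumptions):

```rust
use std::num::NonZeroU32;

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;

fn configure() -> anyhow::Result<(LlamaContextParams, LlamaSampler)> {
    // Before this commit the seed lived on the context parameters:
    //     LlamaContextParams::default()
    //         .with_n_ctx(NonZeroU32::new(2048))
    //         .with_seed(1234)

    // After this commit the context parameters carry no seed...
    let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));

    // ...and the seed moves onto the sampler chain instead.
    let sampler_params = LlamaSamplerChainParams::default();
    let sampler = LlamaSampler::new(sampler_params)?.add_dist(1234);

    Ok((ctx_params, sampler))
}
```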

llama-cpp-2/src/context/sample.rs

Lines changed: 0 additions & 141 deletions
This file was deleted.
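The deleted module backed the old context-level sampling path. For reference, a sketch of the pattern it supported, reconstructed only from the removed lines in the example diffs above (the `old_sample_greedy` wrapper is illustrative; it no longer compiles against the bumped crate, and the module's own 141 lines are not reproduced here):

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::token::LlamaToken;

// Old, now-removed API as the examples used it before this commit:
// candidate logits came straight off the context, and greedy sampling was a context method.
fn old_sample_greedy(ctx: &mut LlamaContext<'_>, batch: &LlamaBatch) -> LlamaToken {
    let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
    let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
    ctx.sample_token_greedy(candidates_p)
}
```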
