Commit 5c27009

Expose safe wrappers around the new sampling API types

1 parent: baea38c

File tree

3 files changed: +85, -0 lines

llama-cpp-2/src/lib.rs

Lines changed: 10 additions & 0 deletions
@@ -26,6 +26,7 @@ pub mod context;
 pub mod llama_backend;
 pub mod llama_batch;
 pub mod model;
+pub mod sampling;
 pub mod timing;
 pub mod token;
 pub mod token_type;
@@ -61,6 +62,7 @@ pub enum LLamaCppError {
     /// see [`EmbeddingsError`]
     #[error(transparent)]
     EmbeddingError(#[from] EmbeddingsError),
+    // See [`LlamaSamplerError`]
 }

 /// There was an error while getting the chat template from a model.
@@ -193,6 +195,14 @@ pub enum LlamaLoraAdapterRemoveError {
     ErrorResult(i32),
 }

+/// An error that can occur when initializing a sampler.
+#[derive(Debug, Eq, PartialEq, thiserror::Error)]
+pub enum LlamaSamplerError {
+    /// llama.cpp returned null
+    #[error("null reference from llama.cpp")]
+    NullReturn,
+}
+
 /// get the time (in microseconds) according to llama.cpp
 /// ```
 /// # use llama_cpp_2::llama_time_us;

llama-cpp-2/src/sampling.rs

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+//! Safe wrapper around `llama_sampler`.
+pub mod params;
+
+use std::fmt::{Debug, Formatter};
+use std::ptr::NonNull;
+
+use crate::LlamaSamplerError;
+
+/// A safe wrapper around `llama_sampler`.
+pub struct LlamaSampler {
+    pub(crate) sampler: NonNull<llama_cpp_sys_2::llama_sampler>,
+}
+
+impl Debug for LlamaSampler {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("LlamaSamplerChain").finish()
+    }
+}
+
+impl LlamaSampler {
+    pub fn new(params: params::LlamaSamplerChainParams) -> Result<Self, LlamaSamplerError> {
+        let sampler = unsafe {
+            NonNull::new(llama_cpp_sys_2::llama_sampler_chain_init(
+                params.sampler_chain_params,
+            ))
+            .ok_or(LlamaSamplerError::NullReturn)
+        }?;
+
+        Ok(Self { sampler })
+    }
+}
+
+impl Drop for LlamaSampler {
+    fn drop(&mut self) {
+        unsafe {
+            llama_cpp_sys_2::llama_sampler_free(self.sampler.as_ptr());
+        }
+    }
+}
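
Taken together with the params module below, the new types are intended to be used roughly like this sketch (untested; the helper name build_default_sampler is hypothetical, everything else comes from this commit):

    use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
    use llama_cpp_2::sampling::LlamaSampler;
    use llama_cpp_2::LlamaSamplerError;

    fn build_default_sampler() -> Result<LlamaSampler, LlamaSamplerError> {
        // Default chain parameters from llama.cpp, then a new sampler chain.
        // `new` returns LlamaSamplerError::NullReturn if llama.cpp hands back
        // a null pointer; the underlying chain is freed automatically on Drop.
        let params = LlamaSamplerChainParams::default();
        LlamaSampler::new(params)
    }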

llama-cpp-2/src/sampling/params.rs

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+//! Safe wrapper around `llama_sampler`.
+
+use std::fmt::{Debug, Formatter};
+use std::ptr::NonNull;
+
+/// A safe wrapper around `llama_sampler`.
+pub struct LlamaSamplerChainParams {
+    pub(crate) sampler_chain_params: llama_cpp_sys_2::llama_sampler_chain_params,
+}
+
+impl Debug for LlamaSamplerChainParams {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("LlamaSamplerChainParams").finish()
+    }
+}
+
+impl Default for LlamaSamplerChainParams {
+    fn default() -> Self {
+        let sampler_chain_params = unsafe { llama_cpp_sys_2::llama_sampler_chain_default_params() };
+
+        Self {
+            sampler_chain_params,
+        }
+    }
+}
+
+impl LlamaSamplerChainParams {
+    pub fn with_no_perf(&mut self, no_perf: bool) -> &mut Self {
+        self.sampler_chain_params.no_perf = no_perf;
+        self
+    }
+
+    pub fn no_perf(&self) -> bool {
+        self.sampler_chain_params.no_perf
+    }
+}
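
A minimal sketch of how the builder might be used (assuming only the API above; make_quiet_params is a hypothetical helper):

    use llama_cpp_2::sampling::params::LlamaSamplerChainParams;

    fn make_quiet_params() -> LlamaSamplerChainParams {
        // with_no_perf mutates in place and returns &mut Self, so bind the
        // params to a variable rather than chaining off the Default temporary.
        let mut params = LlamaSamplerChainParams::default();
        params.with_no_perf(true);
        assert!(params.no_perf());
        params
    }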
