Skip to content

Commit aeb76dc

Browse files
committed
Add LlamaTokenDataArray::with_sampler; use Borrow instead of AsRef for LlamaToken
1 parent 7aa4367 commit aeb76dc

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

llama-cpp-2/src/sampling.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Safe wrapper around `llama_sampler`.
22
3+
use std::borrow::Borrow;
34
use std::ffi::CString;
45
use std::fmt::{Debug, Formatter};
56

@@ -43,16 +44,16 @@ impl LlamaSampler {
4344

4445
/// Accepts several tokens from the sampler or context, possibly updating the internal state of
4546
/// certain samplers (e.g. grammar, repetition, etc.)
46-
pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) {
47+
pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
4748
for token in tokens {
48-
unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.as_ref().0) }
49+
unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) }
4950
}
5051
}
5152

5253
/// Accepts several tokens from the sampler or context, possibly updating the internal state of
5354
/// certain samplers (e.g. grammar, repetition, etc.)
5455
#[must_use]
55-
pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) -> Self {
56+
pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) -> Self {
5657
self.accept_many(tokens);
5758
self
5859
}

llama-cpp-2/src/token/data_array.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,27 @@ impl LlamaTokenDataArray {
132132
}
133133
}
134134

135+
/// Modifies the data array by applying a sampler to it
136+
#[must_use]
137+
pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
138+
self.apply_sampler(sampler);
139+
self
140+
}
141+
135142
/// Randomly selects a token from the candidates based on their probabilities.
143+
///
144+
/// # Panics
145+
/// If the internal llama.cpp sampler fails to select a token.
136146
pub fn sample_token(&mut self, seed: u32) -> LlamaToken {
137147
self.apply_sampler(&mut LlamaSampler::dist(seed));
138148
self.selected_token()
139149
.expect("Dist sampler failed to select a token!")
140150
}
141151

142152
/// Selects the token with the highest probability.
153+
///
154+
/// # Panics
155+
/// If the internal llama.cpp sampler fails to select a token.
143156
pub fn sample_token_greedy(&mut self) -> LlamaToken {
144157
self.apply_sampler(&mut LlamaSampler::greedy());
145158
self.selected_token()

0 commit comments

Comments
 (0)