Skip to content

Commit c1e17d7

Browse files
committed
Add top_n_sigma sampler
1 parent dabcb10 commit c1e17d7

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

llama-cpp-2/src/sampling.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,37 @@ impl LlamaSampler {
191191
Self { sampler }
192192
}
193193

194+
/// Top-nσ sampling as described in the academic paper "Top-nσ: Not All Logits Are You Need"
195+
/// <https://arxiv.org/pdf/2411.07641>
196+
///
197+
/// This method keeps only the logits within *n* standard deviations below the maximum logit.
198+
///
199+
/// # Parameters
200+
/// - `n`: Number of standard deviations below the maximum logit to include in sampling
201+
///
202+
/// # Example
203+
/// ```rust
204+
/// use llama_cpp_2::sampling::LlamaSampler;
205+
/// use llama_cpp_2::token::{
206+
/// LlamaToken,
207+
/// data::LlamaTokenData,
208+
/// data_array::LlamaTokenDataArray
209+
/// };
210+
///
211+
/// let mut data_array = LlamaTokenDataArray::new(vec![
212+
/// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
213+
/// LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
214+
/// LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
215+
/// ], false);
216+
///
217+
/// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
218+
/// ```
219+
#[must_use]
220+
pub fn top_n_sigma(n: f32) -> Self {
221+
// SAFETY: the FFI initializer is passed only the `f32` value `n` by value; no pointers
// or borrowed data cross the boundary in this call.
// NOTE(review): assumes the returned sampler is valid (non-null) and is released by
// `LlamaSampler` elsewhere in this file — confirm.
let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
222+
// Wrap the value returned by the initializer in the safe `LlamaSampler` type.
Self { sampler }
223+
}
224+
194225
/// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
195226
#[must_use]
196227
pub fn typical(p: f32, min_keep: usize) -> Self {

0 commit comments

Comments
 (0)