Commit 1466f7e

Merge pull request #510 from brittlewis12/context-and-model-enhancements
2 parents 0ebae0b + 7d1b2d5

File tree

5 files changed: +189 −22 lines

examples/simple/src/main.rs

Lines changed: 2 additions & 2 deletions

@@ -247,15 +247,15 @@ either reduce n_len or increase n_ctx"
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
+            let candidates = ctx.candidates();
 
             let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
 
             // sample the most likely token
             let new_token_id = ctx.sample_token_greedy(candidates_p);
 
             // is it an end of stream?
-            if new_token_id == model.token_eos() {
+            if model.is_eog_token(new_token_id) {
                 eprintln!();
                 break;
             }
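Both changes in this hunk track the new crate API introduced later in this commit: the candidate list now comes straight from the context rather than from an explicit batch index, and end-of-stream detection broadens from the single `token_eos()` id to any end-of-generation token (EOS, plus the end-of-turn tokens used by chat-tuned models). A minimal sketch of the resulting loop, assuming `ctx`, `model`, `n_cur`, and `n_len` are set up as earlier in this example:

```rust
while n_cur <= n_len {
    // Candidates for the most recent logits; no batch index needed anymore.
    let candidates = ctx.candidates();
    let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);

    // Greedy sampling: take the single most likely token.
    let new_token_id = ctx.sample_token_greedy(candidates_p);

    // Stop on EOS *or* an end-of-turn token; the old equality check against
    // `model.token_eos()` would miss chat models' end-of-turn ids.
    if model.is_eog_token(new_token_id) {
        break;
    }

    // ... detokenize/print `new_token_id`, decode the next batch,
    // and advance `n_cur` as in the full example ...
}
```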

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 88 additions & 20 deletions

@@ -2,7 +2,21 @@
 
 use crate::context::LlamaContext;
 use std::ffi::c_int;
-use std::num::NonZeroU8;
+use std::num::{NonZeroU8, TryFromIntError};
+
+/// Errors that can occur when attempting to prepare values for the kv cache
+#[derive(Debug, Eq, PartialEq, thiserror::Error)]
+pub enum KvCacheConversionError {
+    /// Sequence id conversion to i32 failed
+    #[error("Provided sequence id is too large for an i32")]
+    SeqIdTooLarge(#[source] TryFromIntError),
+    /// Position 0 conversion to i32 failed
+    #[error("Provided start position is too large for an i32")]
+    P0TooLarge(#[source] TryFromIntError),
+    /// Position 1 conversion to i32 failed
+    #[error("Provided end position is too large for an i32")]
+    P1TooLarge(#[source] TryFromIntError),
+}
 
 impl LlamaContext<'_> {
     /// Copy the cache from one sequence to another.
@@ -18,33 +32,63 @@ impl LlamaContext<'_> {
 
     /// Copy the cache from one sequence to another.
     ///
+    /// # Returns
+    /// A `Result` indicating whether the operation was successful. If either position exceeds
+    /// the maximum i32 value, no copy is attempted and an `Err` is returned.
+    ///
     /// # Parameters
     ///
     /// * `src` - The sequence id to copy the cache from.
     /// * `dest` - The sequence id to copy the cache to.
     /// * `p0` - The start position of the cache to copy. If `None`, the entire cache is copied up to `p1`.
     /// * `p1` - The end position of the cache to copy. If `None`, the entire cache is copied starting from `p0`.
-    pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
-        let p0 = p0.map_or(-1, i32::from);
-        let p1 = p1.map_or(-1, i32::from);
+    pub fn copy_kv_cache_seq(
+        &mut self,
+        src: i32,
+        dest: i32,
+        p0: Option<u32>,
+        p1: Option<u32>,
+    ) -> Result<(), KvCacheConversionError> {
+        let p0 = p0
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+        let p1 = p1
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
         unsafe {
             llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
         }
+        Ok(())
     }
 
-    /// Clear the kv cache for the given sequence.
+    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`.
+    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
+    ///
+    /// # Returns
+    /// A `Result` indicating whether the operation was successful. If the sequence id or
+    /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
     ///
     /// # Parameters
     ///
-    /// * `src` - The sequence id to clear the cache for.
+    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences.
     /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
     /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
-    pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
-        let p0 = p0.map_or(-1, i32::from);
-        let p1 = p1.map_or(-1, i32::from);
-        unsafe {
-            llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
-        }
+    pub fn clear_kv_cache_seq(
+        &mut self,
+        src: Option<u32>,
+        p0: Option<u32>,
+        p1: Option<u32>,
+    ) -> Result<bool, KvCacheConversionError> {
+        let src = src
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?;
+        let p0 = p0
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+        let p1 = p1
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
+        Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) })
    }
 
     /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
@@ -73,25 +117,44 @@ impl LlamaContext<'_> {
     /// - lazily on next [`LlamaContext::decode`]
     /// - explicitly with [`Self::kv_cache_update`]
     ///
+    /// # Returns
+    /// A `Result` indicating whether the operation was successful. If either position
+    /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
+    ///
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
     /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
     /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
     /// * `delta` - The relative position to add to the tokens
-    pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
-        let p0 = p0.map_or(-1, i32::from);
-        let p1 = p1.map_or(-1, i32::from);
+    pub fn kv_cache_seq_add(
+        &mut self,
+        seq_id: i32,
+        p0: Option<u32>,
+        p1: Option<u32>,
+        delta: i32,
+    ) -> Result<(), KvCacheConversionError> {
+        let p0 = p0
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+        let p1 = p1
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
         unsafe {
             llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
         }
+        Ok(())
     }
 
     /// Integer division of the positions by factor of `d > 1`.
     /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
     /// - lazily on next [`LlamaContext::decode`]
     /// - explicitly with [`Self::kv_cache_update`]
     ///
+    /// # Returns
+    /// A `Result` indicating whether the operation was successful. If either position
+    /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
+    ///
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
@@ -101,14 +164,19 @@ impl LlamaContext<'_> {
     pub fn kv_cache_seq_div(
         &mut self,
         seq_id: i32,
-        p0: Option<u16>,
-        p1: Option<u16>,
+        p0: Option<u32>,
+        p1: Option<u32>,
         d: NonZeroU8,
-    ) {
-        let p0 = p0.map_or(-1, i32::from);
-        let p1 = p1.map_or(-1, i32::from);
+    ) -> Result<(), KvCacheConversionError> {
+        let p0 = p0
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+        let p1 = p1
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
         let d = c_int::from(d.get());
         unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
+        Ok(())
     }
 
     /// Returns the largest position present in the KV cache for the specified sequence
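All four methods now validate their inputs before touching the cache, so positions beyond `i32::MAX` surface as typed errors rather than being unrepresentable, as the old `Option<u16>` signatures forced. A minimal caller sketch, assuming `KvCacheConversionError` is reachable at the path below (its re-export location is not part of this diff) and using a hypothetical helper name:

```rust
use llama_cpp_2::context::kv_cache::KvCacheConversionError; // assumed path
use llama_cpp_2::context::LlamaContext;

/// Hypothetical helper: drop sequence 0 from `keep_until` onward, then
/// mirror the surviving prefix into sequence 1.
fn trim_and_fork(
    ctx: &mut LlamaContext<'_>,
    keep_until: u32,
) -> Result<(), KvCacheConversionError> {
    // `Err` means a value did not fit in an i32 and nothing was attempted;
    // `Ok(false)` means llama.cpp rejected a partial removal.
    let removed = ctx.clear_kv_cache_seq(Some(0), Some(keep_until), None)?;
    if !removed {
        eprintln!("partial KV cache removal failed");
    }
    // Copy positions [0, keep_until) of sequence 0 into sequence 1.
    ctx.copy_kv_cache_seq(0, 1, None, Some(keep_until))?;
    Ok(())
}
```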

llama-cpp-2/src/context/params.rs

Lines changed: 60 additions & 0 deletions

@@ -197,6 +197,66 @@ impl LlamaContextParams {
         self.context_params.n_ubatch
     }
 
+    /// Set the `flash_attention` parameter
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_flash_attention(true);
+    /// assert_eq!(params.flash_attention(), true);
+    /// ```
+    #[must_use]
+    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
+        self.context_params.flash_attn = enabled;
+        self
+    }
+
+    /// Get the `flash_attention` parameter
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.flash_attention(), false);
+    /// ```
+    #[must_use]
+    pub fn flash_attention(&self) -> bool {
+        self.context_params.flash_attn
+    }
+
+    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_offload_kqv(false);
+    /// assert_eq!(params.offload_kqv(), false);
+    /// ```
+    #[must_use]
+    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
+        self.context_params.offload_kqv = enabled;
+        self
+    }
+
+    /// Get the `offload_kqv` parameter
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.offload_kqv(), true);
+    /// ```
+    #[must_use]
+    pub fn offload_kqv(&self) -> bool {
+        self.context_params.offload_kqv
+    }
+
     /// Set the type of rope scaling.
     ///
     /// # Examples
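Both new setters follow the crate's consuming-builder convention, so they chain with the existing `with_*` methods; the getters mirror the doctests above (flash attention defaults to off, KQV offload to on). A small usage sketch:

```rust
use llama_cpp_2::context::params::LlamaContextParams;

let params = LlamaContextParams::default()
    .with_flash_attention(true) // opt in to flash attention kernels
    .with_offload_kqv(false);   // keep the KV cache & KQV ops off the GPU
assert!(params.flash_attention());
assert!(!params.offload_kqv());
```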

llama-cpp-2/src/model.rs

Lines changed: 6 additions & 0 deletions

@@ -118,6 +118,12 @@ impl LlamaModel {
         LlamaToken(token)
     }
 
+    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
+    #[must_use]
+    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
+        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) }
+    }
+
     /// Get the decoder start token.
     #[must_use]
     pub fn decode_start_token(&self) -> LlamaToken {
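Unlike `token_eos()`, which returns the model's single end-of-sequence id, `is_eog_token` asks llama.cpp whether a token ends generation at all, which also covers the end-of-turn tokens of chat-tuned models. A one-line sketch of the broader check, assuming `model: LlamaModel` and a sampled `token: LlamaToken` are in scope:

```rust
// Narrow check: exactly the EOS id (misses end-of-turn tokens).
let eos_only = token == model.token_eos();
// Broad check: EOS, end-of-turn, and any other end-of-generation token.
let finished = model.is_eog_token(token);
assert!(!eos_only || finished); // EOS itself always counts as end-of-generation
```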

llama-cpp-2/src/token/data_array.rs

Lines changed: 33 additions & 0 deletions

@@ -374,4 +374,37 @@ impl LlamaTokenDataArray {
         *mu = unsafe { *mu_ptr };
         LlamaToken(token)
     }
+
+    /// Mirostat 1.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words.
+    ///
+    /// # Parameters
+    ///
+    /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+    /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+    /// * `m` - The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
+    /// * `mu` - Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+    pub fn sample_token_mirostat_v1(
+        &mut self,
+        ctx: &mut LlamaContext,
+        tau: f32,
+        eta: f32,
+        m: i32,
+        mu: &mut f32,
+    ) -> LlamaToken {
+        let mu_ptr = ptr::from_mut(mu);
+        let token = unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_token_mirostat(
+                    ctx.context.as_ptr(),
+                    c_llama_token_data_array,
+                    tau,
+                    eta,
+                    m,
+                    mu_ptr,
+                )
+            })
+        };
+        *mu = unsafe { *mu_ptr };
+        LlamaToken(token)
+    }
 }
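The `mu` argument is carried sampler state: it is initialized to `2 * tau` once and then fed back in on every step, which is why the method takes it by `&mut`. A short sketch of a Mirostat v1 sampling step, assuming a mutable `ctx: LlamaContext` and a freshly built, mutable `candidates: LlamaTokenDataArray` as elsewhere in this file:

```rust
// Mirostat v1 state, set up once before the decode loop.
let tau = 5.0_f32;      // target surprise (cross-entropy)
let eta = 0.1_f32;      // learning rate for `mu` updates
let m = 100;            // tokens used for the s_hat estimate (the paper's value)
let mut mu = 2.0 * tau; // per the docs above: start at twice the target

// Inside the decode loop, reusing the same `mu` on every iteration:
let token = candidates.sample_token_mirostat_v1(&mut ctx, tau, eta, m, &mut mu);
```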
