Skip to content

Commit 7d1b2d5

Browse files
committed
Handle KV cache mutations for llama_pos values greater than u16
* return `Result`s to handle failed u32 -> i32 conversion * unify kv cache seq rm methods
1 parent b10bd0a commit 7d1b2d5

File tree

1 file changed

+88
-41
lines changed

1 file changed

+88
-41
lines changed

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 88 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,21 @@
22
33
use crate::context::LlamaContext;
44
use std::ffi::c_int;
5-
use std::num::NonZeroU8;
5+
use std::num::{NonZeroU8, TryFromIntError};
6+
7+
/// Errors that can occur when attempting to prepare values for the kv cache
8+
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
9+
pub enum KvCacheConversionError {
10+
/// Sequence id conversion to i32 failed
11+
#[error("Provided sequence id is too large for a i32")]
12+
SeqIdTooLarge(#[source] TryFromIntError),
13+
/// Position 0 conversion to i32 failed
14+
#[error("Provided start position is too large for a i32")]
15+
P0TooLarge(#[source] TryFromIntError),
16+
/// Position 1 conversion to i32 failed
17+
#[error("Provided end position is too large for a i32")]
18+
P1TooLarge(#[source] TryFromIntError),
19+
}
620

721
impl LlamaContext<'_> {
822
/// Copy the cache from one sequence to another.
@@ -18,33 +32,63 @@ impl LlamaContext<'_> {
1832

1933
/// Copy the cache from one sequence to another.
2034
///
35+
/// # Returns
36+
/// A `Result` indicating whether the operation was successful. If the either position exceeds
37+
/// the maximum i32 value, no copy is attempted and an `Err` is returned.
38+
///
2139
/// # Parameters
2240
///
2341
/// * `src` - The sequence id to copy the cache from.
2442
/// * `dest` - The sequence id to copy the cache to.
2543
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
2644
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
27-
pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
28-
let p0 = p0.map_or(-1, i32::from);
29-
let p1 = p1.map_or(-1, i32::from);
45+
pub fn copy_kv_cache_seq(
46+
&mut self,
47+
src: i32,
48+
dest: i32,
49+
p0: Option<u32>,
50+
p1: Option<u32>,
51+
) -> Result<(), KvCacheConversionError> {
52+
let p0 = p0
53+
.map_or(Ok(-1), i32::try_from)
54+
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
55+
let p1 = p1
56+
.map_or(Ok(-1), i32::try_from)
57+
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
3058
unsafe {
3159
llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
3260
}
61+
Ok(())
3362
}
3463

35-
/// Clear the kv cache for the given sequence.
64+
/// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
65+
/// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
66+
///
67+
/// # Returns
68+
/// A `Result` indicating whether the operation was successful. If the sequence id or
69+
/// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
3670
///
3771
/// # Parameters
3872
///
39-
/// * `src` - The sequence id to clear the cache for.
73+
/// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
4074
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
4175
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
42-
pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
43-
let p0 = p0.map_or(-1, i32::from);
44-
let p1 = p1.map_or(-1, i32::from);
45-
unsafe {
46-
llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
47-
}
76+
pub fn clear_kv_cache_seq(
77+
&mut self,
78+
src: Option<u32>,
79+
p0: Option<u32>,
80+
p1: Option<u32>,
81+
) -> Result<bool, KvCacheConversionError> {
82+
let src = src
83+
.map_or(Ok(-1), i32::try_from)
84+
.map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?;
85+
let p0 = p0
86+
.map_or(Ok(-1), i32::try_from)
87+
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
88+
let p1 = p1
89+
.map_or(Ok(-1), i32::try_from)
90+
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
91+
Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) })
4892
}
4993

5094
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
@@ -73,25 +117,44 @@ impl LlamaContext<'_> {
73117
/// - lazily on next [`LlamaContext::decode`]
74118
/// - explicitly with [`Self::kv_cache_update`]
75119
///
120+
/// # Returns
121+
/// A `Result` indicating whether the operation was successful. If either position
122+
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
123+
///
76124
/// # Parameters
77125
///
78126
/// * `seq_id` - The sequence id to update
79127
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
80128
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
81129
/// * `delta` - The relative position to add to the tokens
82-
pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
83-
let p0 = p0.map_or(-1, i32::from);
84-
let p1 = p1.map_or(-1, i32::from);
130+
pub fn kv_cache_seq_add(
131+
&mut self,
132+
seq_id: i32,
133+
p0: Option<u32>,
134+
p1: Option<u32>,
135+
delta: i32,
136+
) -> Result<(), KvCacheConversionError> {
137+
let p0 = p0
138+
.map_or(Ok(-1), i32::try_from)
139+
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
140+
let p1 = p1
141+
.map_or(Ok(-1), i32::try_from)
142+
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
85143
unsafe {
86144
llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
87145
}
146+
Ok(())
88147
}
89148

90149
/// Integer division of the positions by factor of `d > 1`
91150
/// If the KV cache is `RoPEd`, the KV data is updated accordingly:
92151
/// - lazily on next [`LlamaContext::decode`]
93152
/// - explicitly with [`Self::kv_cache_update`]
94153
///
154+
/// # Returns
155+
/// A `Result` indicating whether the operation was successful. If either position
156+
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
157+
///
95158
/// # Parameters
96159
///
97160
/// * `seq_id` - The sequence id to update
@@ -101,14 +164,19 @@ impl LlamaContext<'_> {
101164
pub fn kv_cache_seq_div(
102165
&mut self,
103166
seq_id: i32,
104-
p0: Option<u16>,
105-
p1: Option<u16>,
167+
p0: Option<u32>,
168+
p1: Option<u32>,
106169
d: NonZeroU8,
107-
) {
108-
let p0 = p0.map_or(-1, i32::from);
109-
let p1 = p1.map_or(-1, i32::from);
170+
) -> Result<(), KvCacheConversionError> {
171+
let p0 = p0
172+
.map_or(Ok(-1), i32::try_from)
173+
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
174+
let p1 = p1
175+
.map_or(Ok(-1), i32::try_from)
176+
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
110177
let d = c_int::from(d.get());
111178
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
179+
Ok(())
112180
}
113181

114182
/// Returns the largest position present in the KV cache for the specified sequence
@@ -121,27 +189,6 @@ impl LlamaContext<'_> {
121189
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_pos_max(self.context.as_ptr(), seq_id) }
122190
}
123191

124-
/// Remove all tokens within the specified range `[p0, p1)` from the KV cache
125-
/// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
126-
///
127-
/// # Parameters
128-
///
129-
/// * `seq_id` - The sequence id to remove the tokens from. If `None`, matches all sequences
130-
/// * `p0` - The start position of the cache to remove. If `None`, the entire cache is removed up to `p1`
131-
/// * `p1` - The end position of the cache to remove. If `None`, the entire cache is removed starting from `p0`
132-
#[must_use]
133-
pub fn kv_cache_seq_rm(
134-
&mut self,
135-
seq_id: Option<u16>,
136-
p0: Option<u16>,
137-
p1: Option<u16>,
138-
) -> bool {
139-
let seq_id = seq_id.map_or(-1, i32::from);
140-
let p0 = p0.map_or(-1, i32::from);
141-
let p1 = p1.map_or(-1, i32::from);
142-
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), seq_id, p0, p1) }
143-
}
144-
145192
/// Defragment the KV cache
146193
/// This will be applied:
147194
/// - lazily on next [`LlamaContext::decode`]

0 commit comments

Comments
 (0)