22
33use crate :: context:: LlamaContext ;
44use std:: ffi:: c_int;
5- use std:: num:: NonZeroU8 ;
5+ use std:: num:: { NonZeroU8 , TryFromIntError } ;
6+
7+ /// Errors that can occur when attempting to prepare values for the kv cache
8+ #[ derive( Debug , Eq , PartialEq , thiserror:: Error ) ]
9+ pub enum KvCacheConversionError {
10+ /// Sequence id conversion to i32 failed
11+ #[ error( "Provided sequence id is too large for a i32" ) ]
12+ SeqIdTooLarge ( #[ source] TryFromIntError ) ,
13+ /// Position 0 conversion to i32 failed
14+ #[ error( "Provided start position is too large for a i32" ) ]
15+ P0TooLarge ( #[ source] TryFromIntError ) ,
16+ /// Position 1 conversion to i32 failed
17+ #[ error( "Provided end position is too large for a i32" ) ]
18+ P1TooLarge ( #[ source] TryFromIntError ) ,
19+ }
620
721impl LlamaContext < ' _ > {
822 /// Copy the cache from one sequence to another.
@@ -18,33 +32,63 @@ impl LlamaContext<'_> {
1832
1933 /// Copy the cache from one sequence to another.
2034 ///
35+ /// # Returns
36+ /// A `Result` indicating whether the operation was successful. If either position exceeds
37+ /// the maximum i32 value, no copy is attempted and an `Err` is returned.
38+ ///
2139 /// # Parameters
2240 ///
2341 /// * `src` - The sequence id to copy the cache from.
2442 /// * `dest` - The sequence id to copy the cache to.
2543 /// * `p0` - The start position of the cache to copy. If `None`, the entire cache is copied up to `p1`.
2644 /// * `p1` - The end position of the cache to copy. If `None`, the entire cache is copied starting from `p0`.
27- pub fn copy_kv_cache_seq ( & mut self , src : i32 , dest : i32 , p0 : Option < u16 > , p1 : Option < u16 > ) {
28- let p0 = p0. map_or ( -1 , i32:: from) ;
29- let p1 = p1. map_or ( -1 , i32:: from) ;
45+ pub fn copy_kv_cache_seq (
46+ & mut self ,
47+ src : i32 ,
48+ dest : i32 ,
49+ p0 : Option < u32 > ,
50+ p1 : Option < u32 > ,
51+ ) -> Result < ( ) , KvCacheConversionError > {
52+ let p0 = p0
53+ . map_or ( Ok ( -1 ) , i32:: try_from)
54+ . map_err ( |e| KvCacheConversionError :: P0TooLarge ( e) ) ?;
55+ let p1 = p1
56+ . map_or ( Ok ( -1 ) , i32:: try_from)
57+ . map_err ( |e| KvCacheConversionError :: P1TooLarge ( e) ) ?;
3058 unsafe {
3159 llama_cpp_sys_2:: llama_kv_cache_seq_cp ( self . context . as_ptr ( ) , src, dest, p0, p1) ;
3260 }
61+ Ok ( ( ) )
3362 }
3463
35- /// Clear the kv cache for the given sequence.
64+ /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`.
65+ /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
66+ ///
67+ /// # Returns
68+ /// A `Result` indicating whether the operation was successful. If the sequence id or
69+ /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
3670 ///
3771 /// # Parameters
3872 ///
39- /// * `src` - The sequence id to clear the cache for.
73+ /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences.
4074 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
4175 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
42- pub fn clear_kv_cache_seq ( & mut self , src : i32 , p0 : Option < u16 > , p1 : Option < u16 > ) {
43- let p0 = p0. map_or ( -1 , i32:: from) ;
44- let p1 = p1. map_or ( -1 , i32:: from) ;
45- unsafe {
46- llama_cpp_sys_2:: llama_kv_cache_seq_rm ( self . context . as_ptr ( ) , src, p0, p1) ;
47- }
76+ pub fn clear_kv_cache_seq (
77+ & mut self ,
78+ src : Option < u32 > ,
79+ p0 : Option < u32 > ,
80+ p1 : Option < u32 > ,
81+ ) -> Result < bool , KvCacheConversionError > {
82+ let src = src
83+ . map_or ( Ok ( -1 ) , i32:: try_from)
84+ . map_err ( |e| KvCacheConversionError :: SeqIdTooLarge ( e) ) ?;
85+ let p0 = p0
86+ . map_or ( Ok ( -1 ) , i32:: try_from)
87+ . map_err ( |e| KvCacheConversionError :: P0TooLarge ( e) ) ?;
88+ let p1 = p1
89+ . map_or ( Ok ( -1 ) , i32:: try_from)
90+ . map_err ( |e| KvCacheConversionError :: P1TooLarge ( e) ) ?;
91+ Ok ( unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_rm ( self . context . as_ptr ( ) , src, p0, p1) } )
4892 }
4993
5094 /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
@@ -73,25 +117,44 @@ impl LlamaContext<'_> {
73117 /// - lazily on next [`LlamaContext::decode`]
74118 /// - explicitly with [`Self::kv_cache_update`]
75119 ///
120+ /// # Returns
121+ /// A `Result` indicating whether the operation was successful. If either position
122+ /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
123+ ///
76124 /// # Parameters
77125 ///
78126 /// * `seq_id` - The sequence id to update
79127 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
80128 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
81129 /// * `delta` - The relative position to add to the tokens
82- pub fn kv_cache_seq_add ( & mut self , seq_id : i32 , p0 : Option < u16 > , p1 : Option < u16 > , delta : i32 ) {
83- let p0 = p0. map_or ( -1 , i32:: from) ;
84- let p1 = p1. map_or ( -1 , i32:: from) ;
130+ pub fn kv_cache_seq_add (
131+ & mut self ,
132+ seq_id : i32 ,
133+ p0 : Option < u32 > ,
134+ p1 : Option < u32 > ,
135+ delta : i32 ,
136+ ) -> Result < ( ) , KvCacheConversionError > {
137+ let p0 = p0
138+ . map_or ( Ok ( -1 ) , i32:: try_from)
139+ . map_err ( |e| KvCacheConversionError :: P0TooLarge ( e) ) ?;
140+ let p1 = p1
141+ . map_or ( Ok ( -1 ) , i32:: try_from)
142+ . map_err ( |e| KvCacheConversionError :: P1TooLarge ( e) ) ?;
85143 unsafe {
86144 llama_cpp_sys_2:: llama_kv_cache_seq_add ( self . context . as_ptr ( ) , seq_id, p0, p1, delta) ;
87145 }
146+ Ok ( ( ) )
88147 }
89148
90149 /// Integer division of the positions by a factor of `d > 1`.
91150 /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
92151 /// - lazily on next [`LlamaContext::decode`]
93152 /// - explicitly with [`Self::kv_cache_update`]
94153 ///
154+ /// # Returns
155+ /// A `Result` indicating whether the operation was successful. If either position
156+ /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
157+ ///
95158 /// # Parameters
96159 ///
97160 /// * `seq_id` - The sequence id to update
@@ -101,14 +164,19 @@ impl LlamaContext<'_> {
101164 pub fn kv_cache_seq_div (
102165 & mut self ,
103166 seq_id : i32 ,
104- p0 : Option < u16 > ,
105- p1 : Option < u16 > ,
167+ p0 : Option < u32 > ,
168+ p1 : Option < u32 > ,
106169 d : NonZeroU8 ,
107- ) {
108- let p0 = p0. map_or ( -1 , i32:: from) ;
109- let p1 = p1. map_or ( -1 , i32:: from) ;
170+ ) -> Result < ( ) , KvCacheConversionError > {
171+ let p0 = p0
172+ . map_or ( Ok ( -1 ) , i32:: try_from)
173+ . map_err ( |e| KvCacheConversionError :: P0TooLarge ( e) ) ?;
174+ let p1 = p1
175+ . map_or ( Ok ( -1 ) , i32:: try_from)
176+ . map_err ( |e| KvCacheConversionError :: P1TooLarge ( e) ) ?;
110177 let d = c_int:: from ( d. get ( ) ) ;
111178 unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_div ( self . context . as_ptr ( ) , seq_id, p0, p1, d) }
179+ Ok ( ( ) )
112180 }
113181
114182 /// Returns the largest position present in the KV cache for the specified sequence
0 commit comments