1
1
//! utilities for working with the kv cache
2
2
3
- use std:: num:: NonZeroU8 ;
4
3
use crate :: context:: LlamaContext ;
4
+ use std:: ffi:: c_int;
5
+ use std:: num:: NonZeroU8 ;
5
6
6
7
impl LlamaContext < ' _ > {
7
8
/// Copy the cache from one sequence to another.
@@ -24,14 +25,10 @@ impl LlamaContext<'_> {
24
25
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
25
26
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
26
27
pub fn copy_kv_cache_seq ( & mut self , src : i32 , dest : i32 , p0 : Option < u16 > , p1 : Option < u16 > ) {
28
+ let p0 = p0. map_or ( -1 , i32:: from) ;
29
+ let p1 = p1. map_or ( -1 , i32:: from) ;
27
30
unsafe {
28
- llama_cpp_sys_2:: llama_kv_cache_seq_cp (
29
- self . context . as_ptr ( ) ,
30
- src,
31
- dest,
32
- p0. map_or ( -1 , i32:: from) ,
33
- p1. map_or ( -1 , i32:: from) ,
34
- )
31
+ llama_cpp_sys_2:: llama_kv_cache_seq_cp ( self . context . as_ptr ( ) , src, dest, p0, p1) ;
35
32
}
36
33
}
37
34
@@ -43,17 +40,15 @@ impl LlamaContext<'_> {
43
40
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
44
41
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
45
42
pub fn clear_kv_cache_seq ( & mut self , src : i32 , p0 : Option < u16 > , p1 : Option < u16 > ) {
43
+ let p0 = p0. map_or ( -1 , i32:: from) ;
44
+ let p1 = p1. map_or ( -1 , i32:: from) ;
46
45
unsafe {
47
- llama_cpp_sys_2:: llama_kv_cache_seq_rm (
48
- self . context . as_ptr ( ) ,
49
- src,
50
- p0. map_or ( -1 , i32:: from) ,
51
- p1. map_or ( -1 , i32:: from) ,
52
- ) ;
46
+ llama_cpp_sys_2:: llama_kv_cache_seq_rm ( self . context . as_ptr ( ) , src, p0, p1) ;
53
47
}
54
48
}
55
49
56
50
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
51
+ #[ must_use]
57
52
pub fn get_kv_cache_used_cells ( & self ) -> i32 {
58
53
unsafe { llama_cpp_sys_2:: llama_get_kv_cache_used_cells ( self . context . as_ptr ( ) ) }
59
54
}
@@ -74,8 +69,8 @@ impl LlamaContext<'_> {
74
69
75
70
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
76
71
/// If the KV cache is RoPEd, the KV data is updated accordingly:
77
- /// - lazily on next llama_decode()
78
- /// - explicitly with llama_kv_cache_update()
72
+ /// - lazily on next [`LlamaContext::decode`]
73
+ /// - explicitly with [`Self::kv_cache_update`]
79
74
///
80
75
/// # Parameters
81
76
///
@@ -84,53 +79,51 @@ impl LlamaContext<'_> {
84
79
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
85
80
/// * `delta` - The relative position to add to the tokens
86
81
pub fn kv_cache_seq_add ( & mut self , seq_id : i32 , p0 : Option < u16 > , p1 : Option < u16 > , delta : i32 ) {
82
+ let p0 = p0. map_or ( -1 , i32:: from) ;
83
+ let p1 = p1. map_or ( -1 , i32:: from) ;
87
84
unsafe {
88
- llama_cpp_sys_2:: llama_kv_cache_seq_add (
89
- self . context . as_ptr ( ) ,
90
- seq_id,
91
- p0. map_or ( -1 , i32:: from) ,
92
- p1. map_or ( -1 , i32:: from) ,
93
- delta,
94
- )
85
+ llama_cpp_sys_2:: llama_kv_cache_seq_add ( self . context . as_ptr ( ) , seq_id, p0, p1, delta) ;
95
86
}
96
87
}
97
88
98
89
/// Integer division of the positions by factor of `d > 1`
99
- /// If the KV cache is RoPEd, the KV data is updated accordingly:
100
- /// - lazily on next llama_decode()
101
- /// - explicitly with llama_kv_cache_update()
90
+ /// If the KV cache is ` RoPEd` , the KV data is updated accordingly:
91
+ /// - lazily on next [`LlamaContext::decode`]
92
+ /// - explicitly with [`Self::kv_cache_update`]
102
93
///
103
94
/// # Parameters
104
95
///
105
96
/// * `seq_id` - The sequence id to update
106
97
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
107
98
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
108
99
/// * `d` - The factor to divide the positions by
109
- pub fn kv_cache_seq_div ( & mut self , seq_id : i32 , p0 : Option < u16 > , p1 : Option < u16 > , d : NonZeroU8 ) {
110
- unsafe {
111
- llama_cpp_sys_2:: llama_kv_cache_seq_div (
112
- self . context . as_ptr ( ) ,
113
- seq_id,
114
- p0. map_or ( -1 , i32:: from) ,
115
- p1. map_or ( -1 , i32:: from) ,
116
- d. get ( ) . try_into ( ) . expect ( "d does not fit into a i32" ) ,
117
- )
118
- }
100
+ pub fn kv_cache_seq_div (
101
+ & mut self ,
102
+ seq_id : i32 ,
103
+ p0 : Option < u16 > ,
104
+ p1 : Option < u16 > ,
105
+ d : NonZeroU8 ,
106
+ ) {
107
+ let p0 = p0. map_or ( -1 , i32:: from) ;
108
+ let p1 = p1. map_or ( -1 , i32:: from) ;
109
+ let d = c_int:: from ( d. get ( ) ) ;
110
+ unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_div ( self . context . as_ptr ( ) , seq_id, p0, p1, d) }
119
111
}
120
112
121
113
/// Returns the largest position present in the KV cache for the specified sequence
122
114
///
123
115
/// # Parameters
124
116
///
125
117
/// * `seq_id` - The sequence id to get the max position for
118
+ #[ must_use]
126
119
pub fn kv_cache_seq_pos_max ( & self , seq_id : i32 ) -> i32 {
127
120
unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_pos_max ( self . context . as_ptr ( ) , seq_id) }
128
121
}
129
122
130
123
/// Defragment the KV cache
131
124
/// This will be applied:
132
- /// - lazily on next llama_decode()
133
- /// - explicitly with llama_kv_cache_update()
125
+ /// - lazily on next [`LlamaContext::decode`]
126
+ /// - explicitly with [`Self::kv_cache_update`]
134
127
pub fn kv_cache_defrag ( & mut self ) {
135
128
unsafe { llama_cpp_sys_2:: llama_kv_cache_defrag ( self . context . as_ptr ( ) ) }
136
129
}
@@ -142,6 +135,7 @@ impl LlamaContext<'_> {
142
135
143
136
/// Returns the number of tokens in the KV cache (slow, use only for debug)
144
137
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
138
+ #[ must_use]
145
139
pub fn get_kv_cache_token_count ( & self ) -> i32 {
146
140
unsafe { llama_cpp_sys_2:: llama_get_kv_cache_token_count ( self . context . as_ptr ( ) ) }
147
141
}
@@ -152,14 +146,15 @@ impl LlamaContext<'_> {
152
146
///
153
147
/// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error
154
148
/// if there are more sequences in a cell than this value, however they will
155
- /// not be visible in the view cells_sequences.
149
+ /// not be visible in the view `cells_sequences`.
150
+ #[ must_use]
156
151
pub fn new_kv_cache_view ( & self , n_max_seq : i32 ) -> KVCacheView {
157
- let view = unsafe { llama_cpp_sys_2:: llama_kv_cache_view_init ( self . context . as_ptr ( ) , n_max_seq) } ;
152
+ let view =
153
+ unsafe { llama_cpp_sys_2:: llama_kv_cache_view_init ( self . context . as_ptr ( ) , n_max_seq) } ;
158
154
KVCacheView { view, ctx : self }
159
155
}
160
156
}
161
157
162
-
163
158
/// Information associated with an individual cell in the KV cache view.
164
159
#[ derive( Debug ) ]
165
160
pub struct KVCacheViewCell {
@@ -178,48 +173,75 @@ pub struct KVCacheView<'a> {
178
173
impl < ' a > KVCacheView < ' a > {
179
174
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
180
175
pub fn update ( & mut self ) {
181
- unsafe { llama_cpp_sys_2:: llama_kv_cache_view_update ( self . ctx . context . as_ptr ( ) , & mut self . view ) }
176
+ unsafe {
177
+ llama_cpp_sys_2:: llama_kv_cache_view_update ( self . ctx . context . as_ptr ( ) , & mut self . view ) ;
178
+ }
182
179
}
183
180
184
181
/// Number of KV cache cells. This will be the same as the context size.
182
+ #[ must_use]
185
183
pub fn n_cells ( & self ) -> i32 {
186
184
self . view . n_cells
187
185
}
188
186
189
187
/// Number of tokens in the cache. For example, if there are two populated
190
188
/// cells, the first with 1 sequence id in it and the second with 2 sequence
191
189
/// ids then you'll have 3 tokens.
190
+ #[ must_use]
192
191
pub fn token_count ( & self ) -> i32 {
193
192
self . view . token_count
194
193
}
195
194
196
195
/// Number of populated cache cells.
196
+ #[ must_use]
197
197
pub fn used_cells ( & self ) -> i32 {
198
198
self . view . used_cells
199
199
}
200
200
201
201
/// Maximum contiguous empty slots in the cache.
202
+ #[ must_use]
202
203
pub fn max_contiguous ( & self ) -> i32 {
203
204
self . view . max_contiguous
204
205
}
205
206
206
- /// Index to the start of the max_contiguous slot range. Can be negative
207
+ /// Index to the start of the ` max_contiguous` slot range. Can be negative
207
208
/// when cache is full.
209
+ #[ must_use]
208
210
pub fn max_contiguous_idx ( & self ) -> i32 {
209
211
self . view . max_contiguous_idx
210
212
}
211
213
212
214
/// Information for individual cells.
213
- pub fn cells ( & self ) -> impl Iterator < Item =KVCacheViewCell > {
214
- unsafe { std:: slice:: from_raw_parts ( self . view . cells , self . view . n_cells . try_into ( ) . unwrap ( ) ) }
215
- . iter ( )
216
- . map ( |& cell| KVCacheViewCell { pos : cell. pos } )
215
+ ///
216
+ /// # Panics
217
+ ///
218
+ /// - if `n_cells` does not fit into usize.
219
+ pub fn cells ( & self ) -> impl Iterator < Item = KVCacheViewCell > {
220
+ unsafe {
221
+ std:: slice:: from_raw_parts (
222
+ self . view . cells ,
223
+ usize:: try_from ( self . view . n_cells ) . expect ( "failed to fit n_cells into usize" ) ,
224
+ )
225
+ }
226
+ . iter ( )
227
+ . map ( |& cell| KVCacheViewCell { pos : cell. pos } )
217
228
}
218
229
219
- /// The sequences for each cell. There will be n_max_seq items per cell.
220
- pub fn cells_sequences ( & self ) -> impl Iterator < Item =& [ llama_cpp_sys_2:: llama_seq_id ] > {
221
- unsafe { std:: slice:: from_raw_parts ( self . view . cells_sequences , ( self . view . n_cells * self . view . n_max_seq ) . try_into ( ) . unwrap ( ) ) }
222
- . chunks ( self . view . n_max_seq . try_into ( ) . unwrap ( ) )
230
+ /// The sequences for each cell. There will be `n_max_seq` items per cell.
231
+ ///
232
+ /// # Panics
233
+ ///
234
+ /// - if `n_cells * n_max_seq` does not fit into usize.
235
+ /// - if `n_max_seq` does not fit into usize.
236
+ pub fn cells_sequences ( & self ) -> impl Iterator < Item = & [ llama_cpp_sys_2:: llama_seq_id ] > {
237
+ unsafe {
238
+ std:: slice:: from_raw_parts (
239
+ self . view . cells_sequences ,
240
+ usize:: try_from ( self . view . n_cells * self . view . n_max_seq )
241
+ . expect ( "failed to fit n_cells * n_max_seq into usize" ) ,
242
+ )
243
+ }
244
+ . chunks ( usize:: try_from ( self . view . n_max_seq ) . expect ( "failed to fit n_max_seq into usize" ) )
223
245
}
224
246
}
225
247
@@ -229,4 +251,4 @@ impl<'a> Drop for KVCacheView<'a> {
229
251
llama_cpp_sys_2:: llama_kv_cache_view_free ( & mut self . view ) ;
230
252
}
231
253
}
232
- }
254
+ }
0 commit comments