11//! utilities for working with the kv cache
22
3- use std:: num:: NonZeroU8 ;
43use crate :: context:: LlamaContext ;
4+ use std:: ffi:: c_int;
5+ use std:: num:: NonZeroU8 ;
56
67impl LlamaContext < ' _ > {
78 /// Copy the cache from one sequence to another.
@@ -106,14 +107,20 @@ impl LlamaContext<'_> {
106107 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
107108 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
108109 /// * `d` - The factor to divide the positions by
109- pub fn kv_cache_seq_div ( & mut self , seq_id : i32 , p0 : Option < u16 > , p1 : Option < u16 > , d : NonZeroU8 ) {
110+ pub fn kv_cache_seq_div (
111+ & mut self ,
112+ seq_id : i32 ,
113+ p0 : Option < u16 > ,
114+ p1 : Option < u16 > ,
115+ d : NonZeroU8 ,
116+ ) {
110117 unsafe {
111118 llama_cpp_sys_2:: llama_kv_cache_seq_div (
112119 self . context . as_ptr ( ) ,
113120 seq_id,
114121 p0. map_or ( -1 , i32:: from) ,
115122 p1. map_or ( -1 , i32:: from) ,
116- d. get ( ) . try_into ( ) . expect ( "d does not fit into a i32" ) ,
123+ c_int :: from ( d. get ( ) ) ,
117124 )
118125 }
119126 }
@@ -154,12 +161,12 @@ impl LlamaContext<'_> {
154161 /// if there are more sequences in a cell than this value, however they will
155162 /// not be visible in the view cells_sequences.
156163 pub fn new_kv_cache_view ( & self , n_max_seq : i32 ) -> KVCacheView {
157- let view = unsafe { llama_cpp_sys_2:: llama_kv_cache_view_init ( self . context . as_ptr ( ) , n_max_seq) } ;
164+ let view =
165+ unsafe { llama_cpp_sys_2:: llama_kv_cache_view_init ( self . context . as_ptr ( ) , n_max_seq) } ;
158166 KVCacheView { view, ctx : self }
159167 }
160168}
161169
162-
163170/// Information associated with an individual cell in the KV cache view.
164171#[ derive( Debug ) ]
165172pub struct KVCacheViewCell {
@@ -178,7 +185,9 @@ pub struct KVCacheView<'a> {
178185impl < ' a > KVCacheView < ' a > {
179186 /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
180187 pub fn update ( & mut self ) {
181- unsafe { llama_cpp_sys_2:: llama_kv_cache_view_update ( self . ctx . context . as_ptr ( ) , & mut self . view ) }
188+ unsafe {
189+ llama_cpp_sys_2:: llama_kv_cache_view_update ( self . ctx . context . as_ptr ( ) , & mut self . view )
190+ }
182191 }
183192
184193 /// Number of KV cache cells. This will be the same as the context size.
@@ -210,16 +219,27 @@ impl<'a> KVCacheView<'a> {
210219 }
211220
212221 /// Information for individual cells.
213- pub fn cells ( & self ) -> impl Iterator < Item =KVCacheViewCell > {
214- unsafe { std:: slice:: from_raw_parts ( self . view . cells , self . view . n_cells . try_into ( ) . unwrap ( ) ) }
215- . iter ( )
216- . map ( |& cell| KVCacheViewCell { pos : cell. pos } )
222+ pub fn cells ( & self ) -> impl Iterator < Item = KVCacheViewCell > {
223+ unsafe {
224+ std:: slice:: from_raw_parts (
225+ self . view . cells ,
226+ usize:: try_from ( self . view . n_cells ) . expect ( "failed to fit n_cells into usize" ) ,
227+ )
228+ }
229+ . iter ( )
230+ . map ( |& cell| KVCacheViewCell { pos : cell. pos } )
217231 }
218232
219233 /// The sequences for each cell. There will be n_max_seq items per cell.
220- pub fn cells_sequences ( & self ) -> impl Iterator < Item =& [ llama_cpp_sys_2:: llama_seq_id ] > {
221- unsafe { std:: slice:: from_raw_parts ( self . view . cells_sequences , ( self . view . n_cells * self . view . n_max_seq ) . try_into ( ) . unwrap ( ) ) }
222- . chunks ( self . view . n_max_seq . try_into ( ) . unwrap ( ) )
234+ pub fn cells_sequences ( & self ) -> impl Iterator < Item = & [ llama_cpp_sys_2:: llama_seq_id ] > {
235+ unsafe {
236+ std:: slice:: from_raw_parts (
237+ self . view . cells_sequences ,
238+ usize:: try_from ( self . view . n_cells * self . view . n_max_seq )
239+ . expect ( "failed to fit n_cells * n_max_seq into usize" ) ,
240+ )
241+ }
242+ . chunks ( usize:: try_from ( self . view . n_max_seq ) . expect ( "failed to fit n_max_seq into usize" ) )
223243 }
224244}
225245
@@ -229,4 +249,4 @@ impl<'a> Drop for KVCacheView<'a> {
229249 llama_cpp_sys_2:: llama_kv_cache_view_free ( & mut self . view ) ;
230250 }
231251 }
232- }
252+ }
0 commit comments