Skip to content

Commit 07ab0eb

Browse files
committed
updated llama cpp and removed cast to mut
1 parent 446d16d commit 07ab0eb

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
//! utilities for working with the kv cache
22
3-
use std::num::NonZeroU8;
43
use crate::context::LlamaContext;
4+
use std::ffi::c_int;
5+
use std::num::NonZeroU8;
56

67
impl LlamaContext<'_> {
78
/// Copy the cache from one sequence to another.
@@ -106,14 +107,20 @@ impl LlamaContext<'_> {
106107
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
107108
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
108109
/// * `d` - The factor to divide the positions by
109-
pub fn kv_cache_seq_div(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, d: NonZeroU8) {
110+
pub fn kv_cache_seq_div(
111+
&mut self,
112+
seq_id: i32,
113+
p0: Option<u16>,
114+
p1: Option<u16>,
115+
d: NonZeroU8,
116+
) {
110117
unsafe {
111118
llama_cpp_sys_2::llama_kv_cache_seq_div(
112119
self.context.as_ptr(),
113120
seq_id,
114121
p0.map_or(-1, i32::from),
115122
p1.map_or(-1, i32::from),
116-
d.get().try_into().expect("d does not fit into a i32"),
123+
c_int::from(d.get()),
117124
)
118125
}
119126
}
@@ -154,12 +161,12 @@ impl LlamaContext<'_> {
154161
/// if there are more sequences in a cell than this value, however they will
155162
/// not be visible in the view cells_sequences.
156163
pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView {
157-
let view = unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) };
164+
let view =
165+
unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) };
158166
KVCacheView { view, ctx: self }
159167
}
160168
}
161169

162-
163170
/// Information associated with an individual cell in the KV cache view.
164171
#[derive(Debug)]
165172
pub struct KVCacheViewCell {
@@ -178,7 +185,9 @@ pub struct KVCacheView<'a> {
178185
impl<'a> KVCacheView<'a> {
179186
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
180187
pub fn update(&mut self) {
181-
unsafe { llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view) }
188+
unsafe {
189+
llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view)
190+
}
182191
}
183192

184193
/// Number of KV cache cells. This will be the same as the context size.
@@ -210,16 +219,27 @@ impl<'a> KVCacheView<'a> {
210219
}
211220

212221
/// Information for individual cells.
213-
pub fn cells(&self) -> impl Iterator<Item=KVCacheViewCell> {
214-
unsafe { std::slice::from_raw_parts(self.view.cells, self.view.n_cells.try_into().unwrap()) }
215-
.iter()
216-
.map(|&cell| KVCacheViewCell { pos: cell.pos })
222+
pub fn cells(&self) -> impl Iterator<Item = KVCacheViewCell> {
223+
unsafe {
224+
std::slice::from_raw_parts(
225+
self.view.cells,
226+
usize::try_from(self.view.n_cells).expect("failed to fit n_cells into usize"),
227+
)
228+
}
229+
.iter()
230+
.map(|&cell| KVCacheViewCell { pos: cell.pos })
217231
}
218232

219233
/// The sequences for each cell. There will be n_max_seq items per cell.
220-
pub fn cells_sequences(&self) -> impl Iterator<Item=&[llama_cpp_sys_2::llama_seq_id]> {
221-
unsafe { std::slice::from_raw_parts(self.view.cells_sequences, (self.view.n_cells * self.view.n_max_seq).try_into().unwrap()) }
222-
.chunks(self.view.n_max_seq.try_into().unwrap())
234+
pub fn cells_sequences(&self) -> impl Iterator<Item = &[llama_cpp_sys_2::llama_seq_id]> {
235+
unsafe {
236+
std::slice::from_raw_parts(
237+
self.view.cells_sequences,
238+
usize::try_from(self.view.n_cells * self.view.n_max_seq)
239+
.expect("failed to fit n_cells * n_max_seq into usize"),
240+
)
241+
}
242+
.chunks(usize::try_from(self.view.n_max_seq).expect("failed to fit n_max_seq into usize"))
223243
}
224244
}
225245

@@ -229,4 +249,4 @@ impl<'a> Drop for KVCacheView<'a> {
229249
llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view);
230250
}
231251
}
232-
}
252+
}

llama-cpp-2/src/context/session.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,7 @@ impl LlamaContext<'_> {
150150
/// Returns the number of bytes read
151151
pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
152152
unsafe {
153-
// we don't really need a mutable pointer for `src` -- this is a llama-cpp lapse,
154-
// so we cast away the constness
155-
llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr() as *mut u8)
153+
llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr())
156154
}
157155
}
158156
}

llama-cpp-sys-2/llama.cpp

0 commit comments

Comments
 (0)