Skip to content

Commit 2849fcd

Browse files
authored
Merge pull request #119 from utilityai/update-llama-cpp
updated llama cpp and removed cast to mut
2 parents 446d16d + 035cb57 commit 2849fcd

File tree

3 files changed

+106
-78
lines changed

3 files changed

+106
-78
lines changed

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 74 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
//! utilities for working with the kv cache
22
3-
use std::num::NonZeroU8;
43
use crate::context::LlamaContext;
4+
use std::ffi::c_int;
5+
use std::num::NonZeroU8;
56

67
impl LlamaContext<'_> {
78
/// Copy the cache from one sequence to another.
@@ -24,14 +25,10 @@ impl LlamaContext<'_> {
2425
/// * `p0` - The start position of the cache to copy. If `None`, the entire cache is copied up to [p1].
2526
/// * `p1` - The end position of the cache to copy. If `None`, the entire cache is copied starting from [p0].
2627
pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
28+
let p0 = p0.map_or(-1, i32::from);
29+
let p1 = p1.map_or(-1, i32::from);
2730
unsafe {
28-
llama_cpp_sys_2::llama_kv_cache_seq_cp(
29-
self.context.as_ptr(),
30-
src,
31-
dest,
32-
p0.map_or(-1, i32::from),
33-
p1.map_or(-1, i32::from),
34-
)
31+
llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
3532
}
3633
}
3734

@@ -43,17 +40,15 @@ impl LlamaContext<'_> {
4340
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
4441
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
4542
pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
43+
let p0 = p0.map_or(-1, i32::from);
44+
let p1 = p1.map_or(-1, i32::from);
4645
unsafe {
47-
llama_cpp_sys_2::llama_kv_cache_seq_rm(
48-
self.context.as_ptr(),
49-
src,
50-
p0.map_or(-1, i32::from),
51-
p1.map_or(-1, i32::from),
52-
);
46+
llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
5347
}
5448
}
5549

5650
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
51+
#[must_use]
5752
pub fn get_kv_cache_used_cells(&self) -> i32 {
5853
unsafe { llama_cpp_sys_2::llama_get_kv_cache_used_cells(self.context.as_ptr()) }
5954
}
@@ -74,8 +69,8 @@ impl LlamaContext<'_> {
7469

7570
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
7671
/// If the KV cache is RoPEd, the KV data is updated accordingly:
77-
/// - lazily on next llama_decode()
78-
/// - explicitly with llama_kv_cache_update()
72+
/// - lazily on next [`LlamaContext::decode`]
73+
/// - explicitly with [`Self::kv_cache_update`]
7974
///
8075
/// # Parameters
8176
///
@@ -84,53 +79,51 @@ impl LlamaContext<'_> {
8479
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
8580
/// * `delta` - The relative position to add to the tokens
8681
pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
82+
let p0 = p0.map_or(-1, i32::from);
83+
let p1 = p1.map_or(-1, i32::from);
8784
unsafe {
88-
llama_cpp_sys_2::llama_kv_cache_seq_add(
89-
self.context.as_ptr(),
90-
seq_id,
91-
p0.map_or(-1, i32::from),
92-
p1.map_or(-1, i32::from),
93-
delta,
94-
)
85+
llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
9586
}
9687
}
9788

9889
/// Integer division of the positions by factor of `d > 1`
99-
/// If the KV cache is RoPEd, the KV data is updated accordingly:
100-
/// - lazily on next llama_decode()
101-
/// - explicitly with llama_kv_cache_update()
90+
/// If the KV cache is `RoPEd`, the KV data is updated accordingly:
91+
/// - lazily on next [`LlamaContext::decode`]
92+
/// - explicitly with [`Self::kv_cache_update`]
10293
///
10394
/// # Parameters
10495
///
10596
/// * `seq_id` - The sequence id to update
10697
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
10798
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
10899
/// * `d` - The factor to divide the positions by
109-
pub fn kv_cache_seq_div(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, d: NonZeroU8) {
110-
unsafe {
111-
llama_cpp_sys_2::llama_kv_cache_seq_div(
112-
self.context.as_ptr(),
113-
seq_id,
114-
p0.map_or(-1, i32::from),
115-
p1.map_or(-1, i32::from),
116-
d.get().try_into().expect("d does not fit into a i32"),
117-
)
118-
}
100+
pub fn kv_cache_seq_div(
101+
&mut self,
102+
seq_id: i32,
103+
p0: Option<u16>,
104+
p1: Option<u16>,
105+
d: NonZeroU8,
106+
) {
107+
let p0 = p0.map_or(-1, i32::from);
108+
let p1 = p1.map_or(-1, i32::from);
109+
let d = c_int::from(d.get());
110+
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
119111
}
120112

121113
/// Returns the largest position present in the KV cache for the specified sequence
122114
///
123115
/// # Parameters
124116
///
125117
/// * `seq_id` - The sequence id to get the max position for
118+
#[must_use]
126119
pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
127120
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_pos_max(self.context.as_ptr(), seq_id) }
128121
}
129122

130123
/// Defragment the KV cache
131124
/// This will be applied:
132-
/// - lazily on next llama_decode()
133-
/// - explicitly with llama_kv_cache_update()
125+
/// - lazily on next [`LlamaContext::decode`]
126+
/// - explicitly with [`Self::kv_cache_update`]
134127
pub fn kv_cache_defrag(&mut self) {
135128
unsafe { llama_cpp_sys_2::llama_kv_cache_defrag(self.context.as_ptr()) }
136129
}
@@ -142,6 +135,7 @@ impl LlamaContext<'_> {
142135

143136
/// Returns the number of tokens in the KV cache (slow, use only for debug)
144137
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
138+
#[must_use]
145139
pub fn get_kv_cache_token_count(&self) -> i32 {
146140
unsafe { llama_cpp_sys_2::llama_get_kv_cache_token_count(self.context.as_ptr()) }
147141
}
@@ -152,14 +146,15 @@ impl LlamaContext<'_> {
152146
///
153147
/// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error
154148
/// if there are more sequences in a cell than this value, however they will
155-
/// not be visible in the view cells_sequences.
149+
/// not be visible in the view `cells_sequences`.
150+
#[must_use]
156151
pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView {
157-
let view = unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) };
152+
let view =
153+
unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) };
158154
KVCacheView { view, ctx: self }
159155
}
160156
}
161157

162-
163158
/// Information associated with an individual cell in the KV cache view.
164159
#[derive(Debug)]
165160
pub struct KVCacheViewCell {
@@ -178,48 +173,75 @@ pub struct KVCacheView<'a> {
178173
impl<'a> KVCacheView<'a> {
179174
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
180175
pub fn update(&mut self) {
181-
unsafe { llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view) }
176+
unsafe {
177+
llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view);
178+
}
182179
}
183180

184181
/// Number of KV cache cells. This will be the same as the context size.
182+
#[must_use]
185183
pub fn n_cells(&self) -> i32 {
186184
self.view.n_cells
187185
}
188186

189187
/// Number of tokens in the cache. For example, if there are two populated
190188
/// cells, the first with 1 sequence id in it and the second with 2 sequence
191189
/// ids then you'll have 3 tokens.
190+
#[must_use]
192191
pub fn token_count(&self) -> i32 {
193192
self.view.token_count
194193
}
195194

196195
/// Number of populated cache cells.
196+
#[must_use]
197197
pub fn used_cells(&self) -> i32 {
198198
self.view.used_cells
199199
}
200200

201201
/// Maximum contiguous empty slots in the cache.
202+
#[must_use]
202203
pub fn max_contiguous(&self) -> i32 {
203204
self.view.max_contiguous
204205
}
205206

206-
/// Index to the start of the max_contiguous slot range. Can be negative
207+
/// Index to the start of the `max_contiguous` slot range. Can be negative
207208
/// when cache is full.
209+
#[must_use]
208210
pub fn max_contiguous_idx(&self) -> i32 {
209211
self.view.max_contiguous_idx
210212
}
211213

212214
/// Information for individual cells.
213-
pub fn cells(&self) -> impl Iterator<Item=KVCacheViewCell> {
214-
unsafe { std::slice::from_raw_parts(self.view.cells, self.view.n_cells.try_into().unwrap()) }
215-
.iter()
216-
.map(|&cell| KVCacheViewCell { pos: cell.pos })
215+
///
216+
/// # Panics
217+
///
218+
/// - if `n_cells` does not fit into usize.
219+
pub fn cells(&self) -> impl Iterator<Item = KVCacheViewCell> {
220+
unsafe {
221+
std::slice::from_raw_parts(
222+
self.view.cells,
223+
usize::try_from(self.view.n_cells).expect("failed to fit n_cells into usize"),
224+
)
225+
}
226+
.iter()
227+
.map(|&cell| KVCacheViewCell { pos: cell.pos })
217228
}
218229

219-
/// The sequences for each cell. There will be n_max_seq items per cell.
220-
pub fn cells_sequences(&self) -> impl Iterator<Item=&[llama_cpp_sys_2::llama_seq_id]> {
221-
unsafe { std::slice::from_raw_parts(self.view.cells_sequences, (self.view.n_cells * self.view.n_max_seq).try_into().unwrap()) }
222-
.chunks(self.view.n_max_seq.try_into().unwrap())
230+
/// The sequences for each cell. There will be `n_max_seq` items per cell.
231+
///
232+
/// # Panics
233+
///
234+
/// - if `n_cells * n_max_seq` does not fit into usize.
235+
/// - if `n_max_seq` does not fit into usize.
236+
pub fn cells_sequences(&self) -> impl Iterator<Item = &[llama_cpp_sys_2::llama_seq_id]> {
237+
unsafe {
238+
std::slice::from_raw_parts(
239+
self.view.cells_sequences,
240+
usize::try_from(self.view.n_cells * self.view.n_max_seq)
241+
.expect("failed to fit n_cells * n_max_seq into usize"),
242+
)
243+
}
244+
.chunks(usize::try_from(self.view.n_max_seq).expect("failed to fit n_max_seq into usize"))
223245
}
224246
}
225247

@@ -229,4 +251,4 @@ impl<'a> Drop for KVCacheView<'a> {
229251
llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view);
230252
}
231253
}
232-
}
254+
}

llama-cpp-2/src/context/session.rs

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -105,54 +105,60 @@ impl LlamaContext<'_> {
105105
.ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
106106

107107
let cstr = CString::new(path)?;
108-
let mut tokens = Vec::with_capacity(max_tokens);
108+
let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
109109
let mut n_out = 0;
110110

111-
unsafe {
112-
if llama_cpp_sys_2::llama_load_session_file(
111+
// SAFETY: cast is valid as LlamaToken is repr(transparent)
112+
let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
113+
114+
let load_session_success = unsafe {
115+
llama_cpp_sys_2::llama_load_session_file(
113116
self.context.as_ptr(),
114117
cstr.as_ptr(),
115-
// cast is valid as LlamaToken is repr(transparent)
116-
Vec::<LlamaToken>::as_mut_ptr(&mut tokens).cast::<llama_cpp_sys_2::llama_token>(),
118+
tokens_out,
117119
max_tokens,
118120
&mut n_out,
119-
) {
120-
if n_out > max_tokens {
121-
return Err(LoadSessionError::InsufficientMaxLength {
122-
n_out,
123-
max_tokens,
124-
});
125-
}
121+
)
122+
};
123+
if load_session_success {
124+
if n_out > max_tokens {
125+
return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
126+
}
127+
// SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
128+
unsafe {
126129
tokens.set_len(n_out);
127-
Ok(tokens)
128-
} else {
129-
Err(LoadSessionError::FailedToLoad)
130130
}
131+
Ok(tokens)
132+
} else {
133+
Err(LoadSessionError::FailedToLoad)
131134
}
132135
}
133136

134137
/// Returns the maximum size in bytes of the state (rng, logits, embedding
135-
/// and kv_cache) - will often be smaller after compacting tokens
138+
/// and `kv_cache`) - will often be smaller after compacting tokens
139+
#[must_use]
136140
pub fn get_state_size(&self) -> usize {
137141
unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
138142
}
139143

140144
/// Copies the state to the specified destination address.
141-
/// Destination needs to have allocated enough memory.
145+
///
142146
/// Returns the number of bytes copied
147+
///
148+
/// # Safety
149+
///
150+
/// Destination needs to have allocated enough memory.
143151
pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
144-
unsafe {
145-
llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest)
146-
}
152+
unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
147153
}
148154

149155
/// Set the state reading from the specified address
150156
/// Returns the number of bytes read
157+
///
158+
/// # Safety
159+
///
160+
/// help wanted: not entirely sure what the safety requirements are here.
151161
pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
152-
unsafe {
153-
// we don't really need a mutable pointer for `src` -- this is a llama-cpp lapse,
154-
// so we cast away the constness
155-
llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr() as *mut u8)
156-
}
162+
unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
157163
}
158164
}

llama-cpp-sys-2/llama.cpp

0 commit comments

Comments
 (0)