Commit 0b2867c

added some safety comments and decreased unsafe scope
1 parent 07ab0eb commit 0b2867c
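
The kv_cache.rs hunks below all apply the same shape: convert the Option<u16> positions to llama.cpp's -1 "open-ended" sentinel in safe code first, so the unsafe block wraps nothing but the foreign call itself. A minimal, self-contained illustration of that pattern (ffi_clear_range is a hypothetical stand-in, not a llama.cpp symbol; the real calls are the llama_cpp_sys_2::llama_kv_cache_* functions in the diff below):

// Stand-in for an FFI entry point (hypothetical; not a llama.cpp symbol).
unsafe fn ffi_clear_range(_seq: i32, _start: i32, _end: i32) {}

// None maps to -1, the open-ended marker llama.cpp expects.
pub fn clear_range(seq: i32, p0: Option<u16>, p1: Option<u16>) {
    // Argument conversion happens in safe code...
    let p0 = p0.map_or(-1, i32::from);
    let p1 = p1.map_or(-1, i32::from);
    // ...so only the foreign call itself needs an unsafe block.
    unsafe { ffi_clear_range(seq, p0, p1) }
}

fn main() {
    clear_range(0, Some(128), None);
}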

2 files changed: +43 -51 lines changed


llama-cpp-2/src/context/kv_cache.rs

Lines changed: 13 additions & 29 deletions
@@ -25,14 +25,10 @@ impl LlamaContext<'_> {
     /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
     /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
     pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
+        let p0 = p0.map_or(-1, i32::from);
+        let p1 = p1.map_or(-1, i32::from);
         unsafe {
-            llama_cpp_sys_2::llama_kv_cache_seq_cp(
-                self.context.as_ptr(),
-                src,
-                dest,
-                p0.map_or(-1, i32::from),
-                p1.map_or(-1, i32::from),
-            )
+            llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1)
         }
     }

@@ -44,13 +40,10 @@ impl LlamaContext<'_> {
     /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
     /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
     pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
+        let p0 = p0.map_or(-1, i32::from);
+        let p1 = p1.map_or(-1, i32::from);
         unsafe {
-            llama_cpp_sys_2::llama_kv_cache_seq_rm(
-                self.context.as_ptr(),
-                src,
-                p0.map_or(-1, i32::from),
-                p1.map_or(-1, i32::from),
-            );
+            llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
         }
     }

@@ -85,14 +78,10 @@ impl LlamaContext<'_> {
     /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
     /// * `delta` - The relative position to add to the tokens
     pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
+        let p0 = p0.map_or(-1, i32::from);
+        let p1 = p1.map_or(-1, i32::from);
         unsafe {
-            llama_cpp_sys_2::llama_kv_cache_seq_add(
-                self.context.as_ptr(),
-                seq_id,
-                p0.map_or(-1, i32::from),
-                p1.map_or(-1, i32::from),
-                delta,
-            )
+            llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta)
         }
     }

@@ -114,15 +103,10 @@ impl LlamaContext<'_> {
         p1: Option<u16>,
         d: NonZeroU8,
     ) {
-        unsafe {
-            llama_cpp_sys_2::llama_kv_cache_seq_div(
-                self.context.as_ptr(),
-                seq_id,
-                p0.map_or(-1, i32::from),
-                p1.map_or(-1, i32::from),
-                c_int::from(d.get()),
-            )
-        }
+        let p0 = p0.map_or(-1, i32::from);
+        let p1 = p1.map_or(-1, i32::from);
+        let d = c_int::from(d.get());
+        unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
     }

     /// Returns the largest position present in the KV cache for the specified sequence
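
For orientation, a caller-side sketch of the wrappers touched above; this is illustrative only and not part of the commit. It assumes the context type is exposed as llama_cpp_2::context::LlamaContext; the method signatures are the ones visible in this diff, and None positions become -1 (open-ended) inside the wrappers.

use llama_cpp_2::context::LlamaContext;

// Illustrative only: fork a sequence's KV cache entries, then trim and shift the copy.
fn fork_sequence(ctx: &mut LlamaContext<'_>) {
    // Copy everything cached for sequence 0 into sequence 1 (both bounds open).
    ctx.copy_kv_cache_seq(0, 1, None, None);
    // Discard sequence 1's entries from position 128 onwards.
    ctx.clear_kv_cache_seq(1, Some(128), None);
    // Shift sequence 1's remaining entries 16 positions later.
    ctx.kv_cache_seq_add(1, None, Some(128), 16);
}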

llama-cpp-2/src/context/session.rs

Lines changed: 30 additions & 22 deletions
@@ -105,29 +105,33 @@ impl LlamaContext<'_> {
             .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
 
         let cstr = CString::new(path)?;
-        let mut tokens = Vec::with_capacity(max_tokens);
+        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
         let mut n_out = 0;
 
-        unsafe {
-            if llama_cpp_sys_2::llama_load_session_file(
+        // SAFETY: cast is valid as LlamaToken is repr(transparent)
+        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
+
+        let load_session_success = unsafe {
+            llama_cpp_sys_2::llama_load_session_file(
                 self.context.as_ptr(),
                 cstr.as_ptr(),
-                // cast is valid as LlamaToken is repr(transparent)
-                Vec::<LlamaToken>::as_mut_ptr(&mut tokens).cast::<llama_cpp_sys_2::llama_token>(),
+                tokens_out,
                 max_tokens,
                 &mut n_out,
-            ) {
-                if n_out > max_tokens {
-                    return Err(LoadSessionError::InsufficientMaxLength {
-                        n_out,
-                        max_tokens,
-                    });
-                }
-                tokens.set_len(n_out);
-                Ok(tokens)
+            )
+        };
+        if load_session_success {
+            if n_out > max_tokens {
+                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
             } else {
-                Err(LoadSessionError::FailedToLoad)
+                // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
+                unsafe {
+                    tokens.set_len(n_out);
+                }
             }
+            Ok(tokens)
+        } else {
+            Err(LoadSessionError::FailedToLoad)
         }
     }
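
The two SAFETY comments added in this hunk rest on two facts: #[repr(transparent)] makes LlamaToken layout-compatible with llama_cpp_sys_2::llama_token, so the pointer cast is sound, and Vec::set_len is only called once the foreign function has reported how many elements it wrote and that count has been checked against the capacity. A self-contained sketch of the same pattern, with a plain Rust function standing in for llama_load_session_file:

#[repr(transparent)]
struct Token(i32);

// Stand-in for the C loader: writes up to `cap` raw token values and reports how many.
unsafe fn fake_load(out: *mut i32, cap: usize, n_out: &mut usize) -> bool {
    let n = cap.min(3);
    for i in 0..n {
        // Caller contract: `out` points to at least `cap` writable i32 slots.
        unsafe { out.add(i).write(i as i32) };
    }
    *n_out = n;
    true
}

fn main() {
    let max_tokens = 8;
    let mut tokens: Vec<Token> = Vec::with_capacity(max_tokens);
    let mut n_out = 0usize;

    // Cast is valid because Token is repr(transparent) over i32.
    let out = tokens.as_mut_ptr().cast::<i32>();

    let ok = unsafe { fake_load(out, max_tokens, &mut n_out) };
    if ok && n_out <= max_tokens {
        // The loader initialized exactly `n_out` elements, and n_out <= capacity.
        unsafe { tokens.set_len(n_out) };
    }
    assert_eq!(tokens.len(), 3);
    assert_eq!(tokens[2].0, 2);
}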

@@ -138,19 +142,23 @@ impl LlamaContext<'_> {
     }
 
     /// Copies the state to the specified destination address.
-    /// Destination needs to have allocated enough memory.
+    ///
     /// Returns the number of bytes copied
+    ///
+    /// # Safety
+    ///
+    /// Destination needs to have allocated enough memory.
     pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
-        unsafe {
-            llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest)
-        }
+        unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
     }
 
     /// Set the state reading from the specified address
     /// Returns the number of bytes read
+    ///
+    /// # Safety
+    ///
+    /// help wanted: not entirely sure what the safety requirements are here.
     pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
-        unsafe {
-            llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr())
-        }
+        unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
     }
 }
