
Commit 99d8dd1

Merge branch 'main' into dependabot/cargo/thiserror-1.0.64
2 parents: bf65f31 + 4333caa

File tree: 8 files changed, +269 −26 lines


Cargo.lock

Lines changed: 2 additions & 2 deletions
Generated file; diff not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ hf-hub = { version = "0.3.2" }
 criterion = "0.5.1"
 pprof = "0.13.0"
 bindgen = "0.69.4"
-cc = "1.1.14"
+cc = "1.1.21"
 anyhow = "1.0.86"
 clap = "4.5.16"
 encoding_rs = "0.8.34"

examples/simple/src/main.rs

Lines changed: 2 additions & 2 deletions
@@ -247,15 +247,15 @@ either reduce n_len or increase n_ctx"
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
+            let candidates = ctx.candidates();
 
             let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
 
             // sample the most likely token
             let new_token_id = ctx.sample_token_greedy(candidates_p);
 
             // is it an end of stream?
-            if new_token_id == model.token_eos() {
+            if model.is_eog_token(new_token_id) {
                 eprintln!();
                 break;
             }

llama-cpp-2/src/context.rs

Lines changed: 46 additions & 1 deletion
@@ -52,12 +52,18 @@ impl<'model> LlamaContext<'model> {
         }
     }
 
-    /// Gets the max number of tokens in a batch.
+    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to n_ubatch.
     #[must_use]
     pub fn n_batch(&self) -> u32 {
         unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
     }
 
+    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to n_batch.
+    #[must_use]
+    pub fn n_ubatch(&self) -> u32 {
+        unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
+    }
+
     /// Gets the size of the context.
     #[must_use]
     pub fn n_ctx(&self) -> u32 {

@@ -181,6 +187,45 @@ impl<'model> LlamaContext<'model> {
         }
     }
 
+    /// Get the logits for the last token in the context.
+    ///
+    /// # Returns
+    /// An iterator over unsorted `LlamaTokenData` containing the
+    /// logits for the last token in the context.
+    ///
+    /// # Panics
+    ///
+    /// - underlying logits data is null
+    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
+        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
+            let token = LlamaToken::new(i);
+            LlamaTokenData::new(token, *logit, 0_f32)
+        })
+    }
+
+    /// Token logits obtained from the last call to `decode()`.
+    /// The logits for which `batch.logits[i] != 0` are stored contiguously
+    /// in the order they have appeared in the batch.
+    /// Rows: number of tokens for which `batch.logits[i] != 0`
+    /// Cols: `n_vocab`
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the logits for the last decoded token.
+    /// The size corresponds to the `n_vocab` parameter of the context's model.
+    ///
+    /// # Panics
+    ///
+    /// - `n_vocab` does not fit into a usize
+    /// - token data returned is null
+    pub fn get_logits(&self) -> &[f32] {
+        let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
+        assert!(!data.is_null(), "logits data for last token is null");
+        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
+
+        unsafe { slice::from_raw_parts(data, len) }
+    }
+
     /// Get the logits for the ith token in the context.
     ///
     /// # Panics
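
The new `candidates()` is a thin layer over `get_logits()`: it walks the `n_vocab`-long logits slice for the last decoded token and yields one `LlamaTokenData` per vocabulary entry, which is what the simple example above now feeds straight into greedy sampling. A minimal usage sketch (not part of this commit; it assumes a context that has already decoded a batch whose final token requested logits, and the crate's usual module paths for the token types):

    use llama_cpp_2::context::LlamaContext;
    use llama_cpp_2::token::data_array::LlamaTokenDataArray;
    use llama_cpp_2::token::LlamaToken;

    // Pick the most likely next token from the logits of the last decode.
    fn greedy_next_token(ctx: &mut LlamaContext<'_>) -> LlamaToken {
        // candidates() wraps each of the n_vocab logits in a LlamaTokenData
        // (probability fields start at 0.0); no batch index is needed any more.
        let candidates_p = LlamaTokenDataArray::from_iter(ctx.candidates(), false);
        ctx.sample_token_greedy(candidates_p)
    }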

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 88 additions & 20 deletions
@@ -2,7 +2,21 @@
 
 use crate::context::LlamaContext;
 use std::ffi::c_int;
-use std::num::NonZeroU8;
+use std::num::{NonZeroU8, TryFromIntError};
+
+/// Errors that can occur when attempting to prepare values for the kv cache
+#[derive(Debug, Eq, PartialEq, thiserror::Error)]
+pub enum KvCacheConversionError {
+    /// Sequence id conversion to i32 failed
+    #[error("Provided sequence id is too large for a i32")]
+    SeqIdTooLarge(#[source] TryFromIntError),
+    /// Position 0 conversion to i32 failed
+    #[error("Provided start position is too large for a i32")]
+    P0TooLarge(#[source] TryFromIntError),
+    /// Position 1 conversion to i32 failed
+    #[error("Provided end position is too large for a i32")]
+    P1TooLarge(#[source] TryFromIntError),
+}
 
 impl LlamaContext<'_> {
     /// Copy the cache from one sequence to another.

@@ -18,33 +32,63 @@ impl LlamaContext<'_> {
 
     /// Copy the cache from one sequence to another.
     ///
+    /// # Returns
+    /// A `Result` indicating whether the operation was successful. If the either position exceeds
+    /// the maximum i32 value, no copy is attempted and an `Err` is returned.
+    ///
     /// # Parameters
     ///
     /// * `src` - The sequence id to copy the cache from.
     /// * `dest` - The sequence id to copy the cache to.
     /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
     /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
-    pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
-        let p0 = p0.map_or(-1, i32::from);
-        let p1 = p1.map_or(-1, i32::from);
+    pub fn copy_kv_cache_seq(
+        &mut self,
+        src: i32,
+        dest: i32,
+        p0: Option<u32>,
+        p1: Option<u32>,
+    ) -> Result<(), KvCacheConversionError> {
+        let p0 = p0
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+        let p1 = p1
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
         unsafe {
             llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
         }
+        Ok(())
     }
 
-    /// Clear the kv cache for the given sequence.
+    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
+    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
+    ///
+    /// # Returns
+    /// A `Result` indicating whether the operation was successful. If the sequence id or
+    /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
     ///
     /// # Parameters
    ///
-    /// * `src` - The sequence id to clear the cache for.
+    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
     /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
     /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
-    pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
-        let p0 = p0.map_or(-1, i32::from);
-        let p1 = p1.map_or(-1, i32::from);
-        unsafe {
-            llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
-        }
+    pub fn clear_kv_cache_seq(
+        &mut self,
+        src: Option<u32>,
+        p0: Option<u32>,
+        p1: Option<u32>,
+    ) -> Result<bool, KvCacheConversionError> {
+        let src = src
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?;
+        let p0 = p0
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+        let p1 = p1
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
+        Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) })
     }
 
     /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)

@@ -73,25 +117,44 @@ impl LlamaContext<'_> {
     /// - lazily on next [`LlamaContext::decode`]
     /// - explicitly with [`Self::kv_cache_update`]
     ///
+    /// # Returns
+    /// A `Result` indicating whether the operation was successful. If either position
+    /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
+    ///
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
     /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
     /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
     /// * `delta` - The relative position to add to the tokens
-    pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
-        let p0 = p0.map_or(-1, i32::from);
-        let p1 = p1.map_or(-1, i32::from);
+    pub fn kv_cache_seq_add(
+        &mut self,
+        seq_id: i32,
+        p0: Option<u32>,
+        p1: Option<u32>,
+        delta: i32,
+    ) -> Result<(), KvCacheConversionError> {
+        let p0 = p0
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+        let p1 = p1
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
         unsafe {
             llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
         }
+        Ok(())
     }
 
     /// Integer division of the positions by factor of `d > 1`
     /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
     /// - lazily on next [`LlamaContext::decode`]
     /// - explicitly with [`Self::kv_cache_update`]
     ///
+    /// # Returns
+    /// A `Result` indicating whether the operation was successful. If either position
+    /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
+    ///
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update

@@ -101,14 +164,19 @@ impl LlamaContext<'_> {
     pub fn kv_cache_seq_div(
         &mut self,
         seq_id: i32,
-        p0: Option<u16>,
-        p1: Option<u16>,
+        p0: Option<u32>,
+        p1: Option<u32>,
         d: NonZeroU8,
-    ) {
-        let p0 = p0.map_or(-1, i32::from);
-        let p1 = p1.map_or(-1, i32::from);
+    ) -> Result<(), KvCacheConversionError> {
+        let p0 = p0
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+        let p1 = p1
+            .map_or(Ok(-1), i32::try_from)
+            .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
         let d = c_int::from(d.get());
         unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
+        Ok(())
     }
 
     /// Returns the largest position present in the KV cache for the specified sequence
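
In practice the KV-cache helpers now take `u32` positions (and an optional sequence id for removal) and report conversion failures instead of being limited to `u16`. A hedged sketch of the new fallible API, for example dropping a prompt prefix from sequence 0 and shifting the remaining cells (not part of this commit; the helper name and the `context::kv_cache` path for the error type are assumptions based on the file layout):

    use llama_cpp_2::context::kv_cache::KvCacheConversionError;
    use llama_cpp_2::context::LlamaContext;

    // Remove positions [0, n_keep) from sequence 0, then shift the survivors left.
    fn drop_prefix(ctx: &mut LlamaContext<'_>, n_keep: u32) -> Result<(), KvCacheConversionError> {
        // Positions are u32 now; anything that does not fit in an i32 returns an Err
        // before the FFI call is made (the old API capped positions at u16).
        ctx.clear_kv_cache_seq(Some(0), Some(0), Some(n_keep))?;
        let delta = -i32::try_from(n_keep).expect("n_keep fits in i32");
        ctx.kv_cache_seq_add(0, Some(n_keep), None, delta)?;
        Ok(())
    }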

llama-cpp-2/src/context/params.rs

Lines changed: 91 additions & 0 deletions
@@ -166,6 +166,97 @@ impl LlamaContextParams {
         self.context_params.n_batch
     }
 
+    /// Set the `n_ubatch`
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// # use std::num::NonZeroU32;
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_n_ubatch(512);
+    /// assert_eq!(params.n_ubatch(), 512);
+    /// ```
+    #[must_use]
+    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
+        self.context_params.n_ubatch = n_ubatch;
+        self
+    }
+
+    /// Get the `n_ubatch`
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.n_ubatch(), 512);
+    /// ```
+    #[must_use]
+    pub fn n_ubatch(&self) -> u32 {
+        self.context_params.n_ubatch
+    }
+
+    /// Set the `flash_attention` parameter
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_flash_attention(true);
+    /// assert_eq!(params.flash_attention(), true);
+    /// ```
+    #[must_use]
+    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
+        self.context_params.flash_attn = enabled;
+        self
+    }
+
+    /// Get the `flash_attention` parameter
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.flash_attention(), false);
+    /// ```
+    #[must_use]
+    pub fn flash_attention(&self) -> bool {
+        self.context_params.flash_attn
+    }
+
+    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_offload_kqv(false);
+    /// assert_eq!(params.offload_kqv(), false);
+    /// ```
+    #[must_use]
+    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
+        self.context_params.offload_kqv = enabled;
+        self
+    }
+
+    /// Get the `offload_kqv` parameter
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.offload_kqv(), true);
+    /// ```
+    #[must_use]
+    pub fn offload_kqv(&self) -> bool {
+        self.context_params.offload_kqv
+    }
+
     /// Set the type of rope scaling.
     ///
     /// # Examples
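
Taken together, the new setters slot into the existing `LlamaContextParams` builder chain. A small combined sketch (not part of this commit; the values are illustrative, and the defaults cited are the ones asserted by the doctests above):

    use llama_cpp_2::context::params::LlamaContextParams;

    // Defaults per the doctests above: n_ubatch = 512, flash_attention = false, offload_kqv = true.
    let params = LlamaContextParams::default()
        .with_n_ubatch(256)          // physical decode batch; must stay <= n_batch
        .with_flash_attention(true)  // opt in to flash-attention kernels
        .with_offload_kqv(false);    // keep the KV cache and KQV ops off the GPU
    assert_eq!(params.n_ubatch(), 256);
    assert!(params.flash_attention());
    assert!(!params.offload_kqv());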

llama-cpp-2/src/model.rs

Lines changed: 6 additions & 0 deletions
@@ -118,6 +118,12 @@ impl LlamaModel {
         LlamaToken(token)
     }
 
+    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
+    #[must_use]
+    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
+        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) }
+    }
+
     /// Get the decoder start token token.
     #[must_use]
     pub fn decode_start_token(&self) -> LlamaToken {
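
`is_eog_token` is broader than the old `new_token_id == model.token_eos()` comparison: it also matches end-of-turn and any other end-of-generation tokens the model defines, which is why the simple example above switched to it. A minimal sketch of the stop check (not part of this commit; the surrounding generation loop is assumed):

    use llama_cpp_2::model::LlamaModel;
    use llama_cpp_2::token::LlamaToken;

    // Stop when the model signals end of generation (EOS, end-of-turn, ...).
    fn should_stop(model: &LlamaModel, new_token_id: LlamaToken) -> bool {
        // The old check `new_token_id == model.token_eos()` only matched the single EOS token.
        model.is_eog_token(new_token_id)
    }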
