Skip to content

Commit 0ebae0b

Browse files
authored
Merge pull request #509 from brittlewis12/candidates
2 parents b1420f3 + 5f429c2 commit 0ebae0b

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

llama-cpp-2/src/context.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,45 @@ impl<'model> LlamaContext<'model> {
187187
}
188188
}
189189

190+
/// Get the logits for the last token in the context.
191+
///
192+
/// # Returns
193+
/// An iterator over unsorted `LlamaTokenData` containing the
194+
/// logits for the last token in the context.
195+
///
196+
/// # Panics
197+
///
198+
/// - underlying logits data is null
199+
pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
200+
(0_i32..).zip(self.get_logits()).map(|(i, logit)| {
201+
let token = LlamaToken::new(i);
202+
LlamaTokenData::new(token, *logit, 0_f32)
203+
})
204+
}
205+
206+
/// Token logits obtained from the last call to `decode()`.
207+
/// The logits for which `batch.logits[i] != 0` are stored contiguously
208+
/// in the order they have appeared in the batch.
209+
/// Rows: number of tokens for which `batch.logits[i] != 0`
210+
/// Cols: `n_vocab`
211+
///
212+
/// # Returns
213+
///
214+
/// A slice containing the logits for the last decoded token.
215+
/// The size corresponds to the `n_vocab` parameter of the context's model.
216+
///
217+
/// # Panics
218+
///
219+
/// - `n_vocab` does not fit into a usize
220+
/// - token data returned is null
221+
pub fn get_logits(&self) -> &[f32] {
222+
let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
223+
assert!(!data.is_null(), "logits data for last token is null");
224+
let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
225+
226+
unsafe { slice::from_raw_parts(data, len) }
227+
}
228+
190229
/// Get the logits for the ith token in the context.
191230
///
192231
/// # Panics

0 commit comments

Comments
 (0)