Skip to content

Commit cb0ecd9

Browse files
committed
wrap llama_batch_get_one
1 parent 77af620 commit cb0ecd9

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

llama-cpp-2/src/llama_batch.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,25 @@ impl LlamaBatch {
149149
}
150150
}
151151

152+
/// llama_batch_get_one
153+
/// Return batch for single sequence of tokens starting at pos_0
154+
///
155+
/// NOTE: this is a helper function to facilitate transition to the new batch API
156+
///
157+
pub fn get_one(tokens: &[LlamaToken], pos_0: llama_pos, seq_id: llama_seq_id) -> Self {
158+
unsafe {
159+
let ptr = tokens.as_ptr() as *mut i32;
160+
let batch =
161+
llama_cpp_sys_2::llama_batch_get_one(ptr, tokens.len() as i32, pos_0, seq_id);
162+
163+
crate::llama_batch::LlamaBatch {
164+
allocated: 0,
165+
initialized_logits: vec![],
166+
llama_batch: batch,
167+
}
168+
}
169+
}
170+
152171
/// Returns the number of tokens in the batch.
153172
#[must_use]
154173
pub fn n_tokens(&self) -> i32 {
@@ -170,7 +189,9 @@ impl Drop for LlamaBatch {
170189
/// # }
171190
fn drop(&mut self) {
172191
unsafe {
173-
llama_batch_free(self.llama_batch);
192+
if self.allocated > 0 {
193+
llama_batch_free(self.llama_batch);
194+
}
174195
}
175196
}
176197
}

0 commit comments

Comments
 (0)