Skip to content

Commit d3eade6

Browse files
authored
Merge pull request #579 from tinglou/main
wrap llama_batch_get_one
2 parents 42aaeeb + 2822e3a commit d3eade6

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

llama-cpp-2/src/llama_batch.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ pub enum BatchAddError {
2020
/// There was not enough space in the batch to add the token.
2121
#[error("Insufficient Space of {0}")]
2222
InsufficientSpace(usize),
23+
/// Empty buffer is provided for get_one
24+
#[error("Empty buffer")]
25+
EmptyBuffer,
2326
}
2427

2528
impl LlamaBatch {
@@ -149,6 +152,31 @@ impl LlamaBatch {
149152
}
150153
}
151154

155+
/// llama_batch_get_one
156+
/// Return batch for single sequence of tokens starting at pos_0
157+
///
158+
/// NOTE: this is a helper function to facilitate transition to the new batch API
159+
///
160+
pub fn get_one(
161+
tokens: &[LlamaToken],
162+
pos_0: llama_pos,
163+
seq_id: llama_seq_id,
164+
) -> Result<Self, BatchAddError> {
165+
if tokens.is_empty() {
166+
return Err(BatchAddError::EmptyBuffer);
167+
}
168+
let batch = unsafe {
169+
let ptr = tokens.as_ptr() as *mut i32;
170+
llama_cpp_sys_2::llama_batch_get_one(ptr, tokens.len() as i32, pos_0, seq_id)
171+
};
172+
let batch = Self {
173+
allocated: 0,
174+
initialized_logits: vec![(tokens.len() - 1) as i32],
175+
llama_batch: batch,
176+
};
177+
Ok(batch)
178+
}
179+
152180
/// Returns the number of tokens in the batch.
153181
#[must_use]
154182
pub fn n_tokens(&self) -> i32 {
@@ -170,7 +198,9 @@ impl Drop for LlamaBatch {
170198
/// # }
171199
fn drop(&mut self) {
172200
unsafe {
173-
llama_batch_free(self.llama_batch);
201+
if self.allocated > 0 {
202+
llama_batch_free(self.llama_batch);
203+
}
174204
}
175205
}
176206
}

0 commit comments

Comments
 (0)