11//! Safe wrapper around `llama_batch`.
22
3- use crate :: token:: LlamaToken ;
3+ use crate :: token:: { self , LlamaToken } ;
44use llama_cpp_sys_2:: { llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id} ;
55
66/// A safe wrapper around `llama_batch`.
@@ -20,6 +20,9 @@ pub enum BatchAddError {
2020 /// There was not enough space in the batch to add the token.
2121 #[ error( "Insufficient Space of {0}" ) ]
2222 InsufficientSpace ( usize ) ,
23+ /// Empty buffer is provided for get_one
24+ #[ error( "Empty buffer" ) ]
25+ EmptyBuffer ,
2326}
2427
2528impl LlamaBatch {
@@ -154,18 +157,24 @@ impl LlamaBatch {
154157 ///
155158 /// NOTE: this is a helper function to facilitate transition to the new batch API
156159 ///
157- pub fn get_one ( tokens : & [ LlamaToken ] , pos_0 : llama_pos , seq_id : llama_seq_id ) -> Self {
158- unsafe {
159- let ptr = tokens. as_ptr ( ) as * mut i32 ;
160- let batch =
161- llama_cpp_sys_2:: llama_batch_get_one ( ptr, tokens. len ( ) as i32 , pos_0, seq_id) ;
162-
163- crate :: llama_batch:: LlamaBatch {
164- allocated : 0 ,
165- initialized_logits : vec ! [ ] ,
166- llama_batch : batch,
167- }
160+ pub fn get_one (
161+ tokens : & [ LlamaToken ] ,
162+ pos_0 : llama_pos ,
163+ seq_id : llama_seq_id ,
164+ ) -> Result < Self , BatchAddError > {
165+ if tokens. is_empty ( ) {
166+ return Err ( BatchAddError :: EmptyBuffer ) ;
168167 }
168+ let batch = unsafe {
169+ let ptr = tokens. as_ptr ( ) as * mut i32 ;
170+ llama_cpp_sys_2:: llama_batch_get_one ( ptr, tokens. len ( ) as i32 , pos_0, seq_id)
171+ } ;
172+ let batch = Self {
173+ allocated : 0 ,
174+ initialized_logits : vec ! [ ( tokens. len( ) - 1 ) as i32 ] ,
175+ llama_batch : batch,
176+ } ;
177+ Ok ( batch)
169178 }
170179
171180 /// Returns the number of tokens in the batch.
0 commit comments