@@ -20,6 +20,9 @@ pub enum BatchAddError {
2020 /// There was not enough space in the batch to add the token.
2121 #[ error( "Insufficient Space of {0}" ) ]
2222 InsufficientSpace ( usize ) ,
23+ /// Empty buffer is provided for get_one
24+ #[ error( "Empty buffer" ) ]
25+ EmptyBuffer ,
2326}
2427
2528impl LlamaBatch {
@@ -149,6 +152,31 @@ impl LlamaBatch {
149152 }
150153 }
151154
155+ /// llama_batch_get_one
156+ /// Return batch for single sequence of tokens starting at pos_0
157+ ///
158+ /// NOTE: this is a helper function to facilitate transition to the new batch API
159+ ///
160+ pub fn get_one (
161+ tokens : & [ LlamaToken ] ,
162+ pos_0 : llama_pos ,
163+ seq_id : llama_seq_id ,
164+ ) -> Result < Self , BatchAddError > {
165+ if tokens. is_empty ( ) {
166+ return Err ( BatchAddError :: EmptyBuffer ) ;
167+ }
168+ let batch = unsafe {
169+ let ptr = tokens. as_ptr ( ) as * mut i32 ;
170+ llama_cpp_sys_2:: llama_batch_get_one ( ptr, tokens. len ( ) as i32 , pos_0, seq_id)
171+ } ;
172+ let batch = Self {
173+ allocated : 0 ,
174+ initialized_logits : vec ! [ ( tokens. len( ) - 1 ) as i32 ] ,
175+ llama_batch : batch,
176+ } ;
177+ Ok ( batch)
178+ }
179+
152180 /// Returns the number of tokens in the batch.
153181 #[ must_use]
154182 pub fn n_tokens ( & self ) -> i32 {
@@ -170,7 +198,9 @@ impl Drop for LlamaBatch {
170198 /// # }
171199 fn drop ( & mut self ) {
172200 unsafe {
173- llama_batch_free ( self . llama_batch ) ;
201+ if self . allocated > 0 {
202+ llama_batch_free ( self . llama_batch ) ;
203+ }
174204 }
175205 }
176206}
0 commit comments