@@ -14,8 +14,8 @@ use std::time::Duration;
1414use anyhow:: { bail, Context , Result } ;
1515use clap:: Parser ;
1616use hf_hub:: api:: sync:: ApiBuilder ;
17- use llama_cpp_2:: context:: LlamaContext ;
1817
18+ use llama_cpp_2:: context:: LlamaContext ;
1919use llama_cpp_2:: context:: params:: LlamaContextParams ;
2020use llama_cpp_2:: ggml_time_us;
2121use llama_cpp_2:: llama_backend:: LlamaBackend ;
@@ -149,35 +149,33 @@ fn main() -> Result<()> {
149149 // we use this object to submit token data for decoding
150150 let mut batch = LlamaBatch :: new ( n_ctx, 1 ) ;
151151
152- // Amount of tokens in the current batch
153- let mut s_batch = 0 ;
152+ let mut max_seq_id_batch = 0 ;
154153 let mut output = Vec :: with_capacity ( tokens_lines_list. len ( ) ) ;
155154
156155 let t_main_start = ggml_time_us ( ) ;
157156
158157 for tokens in & tokens_lines_list {
159158 // Flush the batch if the next prompt would exceed our batch size
160159 if ( batch. n_tokens ( ) as usize + tokens. len ( ) ) > n_ctx {
161- batch_decode ( & mut ctx, & mut batch, s_batch , & mut output, normalise) ?;
162- s_batch = 0 ;
160+ batch_decode ( & mut ctx, & mut batch, max_seq_id_batch , & mut output, normalise) ?;
161+ max_seq_id_batch = 0 ;
163162 }
164163
165- batch. add_sequence ( & tokens, s_batch , false ) ?;
166- s_batch += 1 ;
164+ batch. add_sequence ( & tokens, max_seq_id_batch , false ) ?;
165+ max_seq_id_batch += 1 ;
167166 }
168167 // Handle final batch
169- batch_decode ( & mut ctx, & mut batch, s_batch , & mut output, normalise) ?;
168+ batch_decode ( & mut ctx, & mut batch, max_seq_id_batch , & mut output, normalise) ?;
170169
171170 let t_main_end = ggml_time_us ( ) ;
172171
173172 for ( i, embeddings) in output. iter ( ) . enumerate ( ) {
174173 eprintln ! ( "Embeddings {i}: {embeddings:?}" ) ;
175- eprintln ! ( " \n " ) ;
174+ eprintln ! ( ) ;
176175 }
177176
178177 let duration = Duration :: from_micros ( ( t_main_end - t_main_start) as u64 ) ;
179178 let total_tokens: usize = tokens_lines_list. iter ( ) . map ( |v| v. len ( ) ) . sum ( ) ;
180-
181179 eprintln ! (
182180 "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n " ,
183181 total_tokens,
0 commit comments