Skip to content

Commit 9126146

Browse files
committed
Use a better name for the s_batch variable and remove excessive whitespace
1 parent fa7b508 commit 9126146

File tree

1 file changed: +8 additions, −10 deletions

embeddings/src/main.rs

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@ use std::time::Duration;
1414
use anyhow::{bail, Context, Result};
1515
use clap::Parser;
1616
use hf_hub::api::sync::ApiBuilder;
17-
use llama_cpp_2::context::LlamaContext;
1817

18+
use llama_cpp_2::context::LlamaContext;
1919
use llama_cpp_2::context::params::LlamaContextParams;
2020
use llama_cpp_2::ggml_time_us;
2121
use llama_cpp_2::llama_backend::LlamaBackend;
@@ -149,35 +149,33 @@ fn main() -> Result<()> {
149149
// we use this object to submit token data for decoding
150150
let mut batch = LlamaBatch::new(n_ctx, 1);
151151

152-
// Amount of tokens in the current batch
153-
let mut s_batch = 0;
152+
let mut max_seq_id_batch = 0;
154153
let mut output = Vec::with_capacity(tokens_lines_list.len());
155154

156155
let t_main_start = ggml_time_us();
157156

158157
for tokens in &tokens_lines_list {
159158
// Flush the batch if the next prompt would exceed our batch size
160159
if (batch.n_tokens() as usize + tokens.len()) > n_ctx {
161-
batch_decode(&mut ctx, &mut batch, s_batch, &mut output, normalise)?;
162-
s_batch = 0;
160+
batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
161+
max_seq_id_batch = 0;
163162
}
164163

165-
batch.add_sequence(&tokens, s_batch, false)?;
166-
s_batch += 1;
164+
batch.add_sequence(&tokens, max_seq_id_batch, false)?;
165+
max_seq_id_batch += 1;
167166
}
168167
// Handle final batch
169-
batch_decode(&mut ctx, &mut batch, s_batch, &mut output, normalise)?;
168+
batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
170169

171170
let t_main_end = ggml_time_us();
172171

173172
for (i, embeddings) in output.iter().enumerate() {
174173
eprintln!("Embeddings {i}: {embeddings:?}");
175-
eprintln!("\n");
174+
eprintln!();
176175
}
177176

178177
let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
179178
let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum();
180-
181179
eprintln!(
182180
"Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n",
183181
total_tokens,

0 commit comments

Comments
 (0)