
Commit 7cabe4d

Merge pull request #802 from tmetsch/main
fix: create fresh batch for each line in embedding example
2 parents: 1fbcc4b + 628bf4c


examples/embeddings/src/main.rs

Lines changed: 10 additions & 28 deletions
@@ -150,40 +150,22 @@ fn main() -> Result<()> {
     }
 
     std::io::stderr().flush()?;
-
-    // create a llama_batch with the size of the context
-    // we use this object to submit token data for decoding
-    let mut batch = LlamaBatch::new(n_ctx, 1);
-
-    let mut max_seq_id_batch = 0;
     let mut output = Vec::with_capacity(tokens_lines_list.len());
 
     let t_main_start = ggml_time_us();
 
     for tokens in &tokens_lines_list {
-        // Flush the batch if the next prompt would exceed our batch size
-        if (batch.n_tokens() as usize + tokens.len()) > n_ctx {
-            batch_decode(
-                &mut ctx,
-                &mut batch,
-                max_seq_id_batch,
-                &mut output,
-                normalise,
-            )?;
-            max_seq_id_batch = 0;
-        }
-
-        batch.add_sequence(tokens, max_seq_id_batch, false)?;
-        max_seq_id_batch += 1;
+        // Create a fresh batch for each sequence
+        let mut batch = LlamaBatch::new(n_ctx, 1);
+        batch.add_sequence(tokens, 0, false)?;
+        batch_decode(
+            &mut ctx,
+            &mut batch,
+            1, // Only one sequence in this batch
+            &mut output,
+            normalise,
+        )?;
     }
-    // Handle final batch
-    batch_decode(
-        &mut ctx,
-        &mut batch,
-        max_seq_id_batch,
-        &mut output,
-        normalise,
-    )?;
 
     let t_main_end = ggml_time_us();
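
For context, this is how the rewritten loop reads when the added lines are assembled in place. It is a sketch, not verified against the full file: `tokens_lines_list`, `ctx`, `n_ctx`, `normalise`, and the `batch_decode` helper are all defined earlier in examples/embeddings/src/main.rs, and the `LlamaBatch` calls are exactly the ones visible in the diff above.

    let mut output = Vec::with_capacity(tokens_lines_list.len());

    let t_main_start = ggml_time_us();

    for tokens in &tokens_lines_list {
        // A fresh batch per input line: no tokens or sequence ids carry over
        // from the previous decode, so each line is embedded independently.
        let mut batch = LlamaBatch::new(n_ctx, 1);
        batch.add_sequence(tokens, 0, false)?;

        // Decode the single sequence and append its (optionally normalised)
        // embedding to `output`.
        batch_decode(&mut ctx, &mut batch, 1, &mut output, normalise)?;
    }

    let t_main_end = ggml_time_us();

The trade-off is simplicity over batching throughput: the old code packed several lines into one reused batch and flushed it when the next prompt would overflow `n_ctx`, whereas the new code issues one decode call per line. Per the commit title, the reused batch was the source of the bug this change fixes.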
