
Commit ccb434b

Swap out unsafe for a nested LLamaBatch::add call
1 parent 0b0e850 commit ccb434b


llama-cpp-2/src/llama_batch.rs

Lines changed: 1 addition & 18 deletions
@@ -108,26 +108,9 @@ impl LlamaBatch {
         if self.allocated < n_tokens_0 as usize + n_tokens {
             return Err(BatchAddError::InsufficientSpace(self.allocated));
         }
-        if n_tokens == 0 {
-            return Ok(())
-        }
 
-        self.llama_batch.n_tokens += n_tokens as i32;
         for (i, token) in tokens.iter().enumerate() {
-            let j = n_tokens_0 as usize + i;
-            unsafe {
-                self.llama_batch.token.add(j).write(token.0);
-                self.llama_batch.pos.add(j).write(i as i32);
-                let seq_id_ptr = *self.llama_batch.seq_id.add(j);
-                seq_id_ptr.write(seq_id);
-                self.llama_batch.n_seq_id.add(j).write(1);
-
-                let write_logits = logits_all || i == n_tokens - 1;
-                self.llama_batch.logits.add(j).write(write_logits as i8);
-                if write_logits {
-                    self.initialized_logits.push(j as i32);
-                }
-            }
+            self.add(*token, i as llama_pos, &[seq_id], logits_all || i == n_tokens - 1)?;
         }
 
         Ok(())
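
The removed unsafe block repeated, per token, the bookkeeping that the crate's single-token add method already performs: write the token, position, sequence ids, and logits flag into the batch's parallel arrays at the current index, then advance the token count. Below is a minimal, self-contained sketch of that pattern. The Vec-backed fields and the SketchBatch/BatchAddError names stand in for llama.cpp's raw llama_batch pointers and the crate's real types, so this is illustrative of the technique, not the crate's exact implementation.

    // Sketch of the single-token `add` pattern the commit delegates to.
    // Vec-backed buffers replace the raw C pointers for clarity.

    #[derive(Debug)]
    struct BatchAddError;

    struct SketchBatch {
        allocated: usize,
        n_tokens: usize,
        token: Vec<i32>,           // token id per slot
        pos: Vec<i32>,             // position per slot
        seq_id: Vec<Vec<i32>>,     // sequence ids per slot
        n_seq_id: Vec<i32>,
        logits: Vec<i8>,
        initialized_logits: Vec<i32>,
    }

    impl SketchBatch {
        /// Append one token: bounds-check, fill every parallel array at the
        /// current index, record whether logits were requested, then bump
        /// `n_tokens`. This is the bookkeeping the removed loop body duplicated.
        fn add(&mut self, token: i32, pos: i32, seq_ids: &[i32], logits: bool) -> Result<(), BatchAddError> {
            let j = self.n_tokens;
            if j >= self.allocated {
                return Err(BatchAddError);
            }
            self.token[j] = token;
            self.pos[j] = pos;
            self.seq_id[j] = seq_ids.to_vec();
            self.n_seq_id[j] = seq_ids.len() as i32;
            self.logits[j] = i8::from(logits);
            if logits {
                self.initialized_logits.push(j as i32);
            }
            self.n_tokens += 1;
            Ok(())
        }
    }

    fn main() {
        let cap = 8;
        let mut batch = SketchBatch {
            allocated: cap,
            n_tokens: 0,
            token: vec![0; cap],
            pos: vec![0; cap],
            seq_id: vec![Vec::new(); cap],
            n_seq_id: vec![0; cap],
            logits: vec![0; cap],
            initialized_logits: Vec::new(),
        };
        let tokens = [100, 200, 300];
        let n = tokens.len();
        let seq_id = 0;
        let logits_all = false;
        // Same shape as the new loop body in the diff: one `add` call per token,
        // requesting logits only for the last token unless `logits_all` is set.
        for (i, token) in tokens.iter().enumerate() {
            batch
                .add(*token, i as i32, &[seq_id], logits_all || i == n - 1)
                .expect("batch has room");
        }
        assert_eq!(batch.n_tokens, 3);
        assert_eq!(batch.initialized_logits, vec![2]);
    }

Centralizing the writes this way keeps the pointer arithmetic (and its safety invariants) in one place instead of two, at the cost of repeating the capacity check on each call.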
