Skip to content

Commit ec76c45

Browse files
committed
fix: input ids array with wrong shape
1 parent bcd3c42 commit ec76c45

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ext/ai/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,9 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
200200
.map(|i| *i as i64)
201201
.collect::<Vec<_>>();
202202

203+
// Convert our flattened arrays into 2-dimensional tensors of shape [N, L] -> Since we're not batching 'N' will be always = 1
203204
let input_ids_array =
204-
TensorRef::from_array_view(([input_ids.len(), 1], &*input_ids))?;
205+
TensorRef::from_array_view(([1, input_ids.len()], &*input_ids))?;
205206
let attention_mask_array = TensorRef::from_array_view((
206207
[1, encoded_prompt.len()],
207208
&*attention_mask,
@@ -237,7 +238,6 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
237238
.insert_axis(Axis(0))
238239
.insert_axis(Axis(2));
239240

240-
println!("attention_mask: {attention_mask_array_clone:?}");
241241
mean_pool(embeddings, attention_mask_array_clone)
242242
} else {
243243
embeddings.into_owned().remove_axis(Axis(0))

0 commit comments

Comments
 (0)