Skip to content

Commit 64556ee

Browse files
committed
stamp: clippy & fmt
1 parent 6a03c86 commit 64556ee

File tree

5 files changed

+69
-57
lines changed

5 files changed

+69
-57
lines changed

ext/ai/lib.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,17 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
200200
.map(|i| *i as i64)
201201
.collect::<Vec<_>>();
202202

203-
let input_ids_array = TensorRef::from_array_view(([input_ids.len(), 1], &*input_ids))?;
204-
let attention_mask_array = TensorRef::from_array_view(([1, encoded_prompt.len()], &*attention_mask))?;
205-
let token_type_ids_array = TensorRef::from_array_view(([1, encoded_prompt.len()], &*token_type_ids))?;
206-
203+
let input_ids_array =
204+
TensorRef::from_array_view(([input_ids.len(), 1], &*input_ids))?;
205+
let attention_mask_array = TensorRef::from_array_view((
206+
[1, encoded_prompt.len()],
207+
&*attention_mask,
208+
))?;
209+
210+
let token_type_ids_array = TensorRef::from_array_view((
211+
[1, encoded_prompt.len()],
212+
&*token_type_ids,
213+
))?;
207214

208215
let Ok(mut guard) = session.lock() else {
209216
let err = anyhow!("failed to lock session");
@@ -223,10 +230,12 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
223230
let embeddings = embeddings.into_dimensionality::<Ix3>()?;
224231

225232
let result = if do_mean_pooling {
226-
let attention_mask_array_clone= Array1::from_iter(attention_mask.iter().cloned());
227-
let attention_mask_array_clone= attention_mask_array_clone.view()
228-
.insert_axis(Axis(0))
229-
.insert_axis(Axis(2));
233+
let attention_mask_array_clone =
234+
Array1::from_iter(attention_mask.iter().cloned());
235+
let attention_mask_array_clone = attention_mask_array_clone
236+
.view()
237+
.insert_axis(Axis(0))
238+
.insert_axis(Axis(2));
230239

231240
println!("attention_mask: {attention_mask_array_clone:?}");
232241
mean_pool(embeddings, attention_mask_array_clone)

ext/ai/onnxruntime/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,11 @@ pub async fn op_ai_ort_init_session(
5656
};
5757

5858
let mut state = state.borrow_mut();
59-
let mut sessions =
60-
{ state.try_take::<Vec<Arc<Mutex<Session>>>>().unwrap_or_default() };
59+
let mut sessions = {
60+
state
61+
.try_take::<Vec<Arc<Mutex<Session>>>>()
62+
.unwrap_or_default()
63+
};
6164

6265
sessions.push(model.get_session());
6366
state.put(sessions);

ext/ai/onnxruntime/model.rs

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,23 @@ impl Model {
3636
fn new(session_with_id: SessionWithId) -> Result<Self> {
3737
let (input_names, output_names) = {
3838
let Ok(session_guard) = session_with_id.session.lock() else {
39-
return Err(anyhow!("Could not lock model session {}", session_with_id.id));
40-
};
41-
42-
let input_names = session_guard
43-
.inputs
44-
.iter()
45-
.map(|input| input.name.clone())
46-
.collect::<Vec<_>>();
47-
48-
let output_names = session_guard
49-
.outputs
50-
.iter()
51-
.map(|output| output.name.clone())
52-
.collect::<Vec<_>>();
39+
return Err(anyhow!(
40+
"Could not lock model session {}",
41+
session_with_id.id
42+
));
43+
};
44+
45+
let input_names = session_guard
46+
.inputs
47+
.iter()
48+
.map(|input| input.name.clone())
49+
.collect::<Vec<_>>();
50+
51+
let output_names = session_guard
52+
.outputs
53+
.iter()
54+
.map(|output| output.name.clone())
55+
.collect::<Vec<_>>();
5356

5457
(input_names, output_names)
5558
};
@@ -75,13 +78,11 @@ impl Model {
7578
pub async fn from_id(id: &str) -> Option<Self> {
7679
let session = {
7780
get_session(id)
78-
.await
79-
.map(|it| SessionWithId::from((id.to_string(), it)))
81+
.await
82+
.map(|it| SessionWithId::from((id.to_string(), it)))
8083
};
8184

82-
let Some(session) = session else {
83-
return None;
84-
};
85+
let session = session?;
8586

8687
Self::new(session).ok()
8788
}

ext/ai/onnxruntime/session.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ pub(crate) fn get_session_builder() -> Result<SessionBuilder, AnyError> {
7676
}
7777

7878
fn cpu_execution_provider() -> ExecutionProviderDispatch {
79-
// NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to
80-
// False.
81-
//
82-
// Backgrounds:
83-
// [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18
84-
// [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50
85-
CPUExecutionProvider::default().build()
79+
// NOTE(Nyannacha): See the comment above. This makes `enable_cpu_mem_arena` set to
80+
// False.
81+
//
82+
// Backgrounds:
83+
// [1]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#9-18
84+
// [2]: https://docs.rs/ort/2.0.0-rc.4/src/ort/execution_providers/cpu.rs.html#46-50
85+
CPUExecutionProvider::default().build()
8686
}
8787

8888
fn cuda_execution_provider() -> Option<ExecutionProviderDispatch> {
@@ -92,13 +92,12 @@ fn cuda_execution_provider() -> Option<ExecutionProviderDispatch> {
9292

9393
if is_cuda_available {
9494
Some(cuda.build())
95-
}else{
95+
} else {
9696
None
9797
}
9898
}
9999

100-
fn get_execution_providers(
101-
) -> Vec<ExecutionProviderDispatch> {
100+
fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
102101
let cpu = cpu_execution_provider();
103102

104103
if let Some(cuda) = cuda_execution_provider() {

ext/ai/onnxruntime/tensor.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,23 +115,23 @@ pub enum JsTensorType {
115115
/// Brain 16-bit floating point number, equivalent to [`half::bf16`] (requires the `half` feature).
116116
Bfloat16,
117117
Complex64,
118-
Complex128,
119-
/// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values and no infinite
120-
/// values.
121-
Float8E4M3FN,
122-
/// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values, no infinite
123-
/// values, and no negative zero.
124-
Float8E4M3FNUZ,
125-
/// 8-bit floating point number with 5 exponent bits and 2 mantissa bits.
126-
Float8E5M2,
127-
/// 8-bit floating point number with 5 exponent bits and 2 mantissa bits, with only NaN values, no infinite
128-
/// values, and no negative zero.
129-
Float8E5M2FNUZ,
130-
/// 4-bit unsigned integer.
131-
Uint4,
132-
/// 4-bit signed integer.
133-
Int4,
134-
Undefined
118+
Complex128,
119+
/// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values and no infinite
120+
/// values.
121+
Float8E4M3FN,
122+
/// 8-bit floating point number with 4 exponent bits and 3 mantissa bits, with only NaN values, no infinite
123+
/// values, and no negative zero.
124+
Float8E4M3FNUZ,
125+
/// 8-bit floating point number with 5 exponent bits and 2 mantissa bits.
126+
Float8E5M2,
127+
/// 8-bit floating point number with 5 exponent bits and 2 mantissa bits, with only NaN values, no infinite
128+
/// values, and no negative zero.
129+
Float8E5M2FNUZ,
130+
/// 4-bit unsigned integer.
131+
Uint4,
132+
/// 4-bit signed integer.
133+
Int4,
134+
Undefined,
135135
}
136136

137137
#[derive(Serialize, Deserialize)]
@@ -282,7 +282,7 @@ impl ToJsTensor {
282282
TensorElementType::String => todo!(),
283283
TensorElementType::Float16 => todo!(),
284284
TensorElementType::Bfloat16 => todo!(),
285-
_ => todo!()
285+
_ => todo!(),
286286
};
287287

288288
Ok(Self {

0 commit comments

Comments
 (0)