Skip to content

Commit e522925

Browse files
fix(sb_ai): Sessions are dropped while still in use (#443)
* fix(sb_ai): error handling on `run session`
  - removing `unwrap()` from `run session`
* fix(sb_ai): cleanup logic, passing Session ref to js land
  - adding currently active session's refs to `OpState`
* stamp: clippy
* feat(k6): add scenario for `ort-rust-backend`
  - adding a scenario that uses `transformers.js` + `ort rust backend`
1 parent 390a5c1 commit e522925

File tree

5 files changed

+111
-11
lines changed

5 files changed

+111
-11
lines changed

crates/sb_ai/onnxruntime/mod.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,29 @@ pub(crate) mod onnx;
33
pub(crate) mod session;
44
mod tensor;
55

6-
use std::{borrow::Cow, collections::HashMap};
6+
use std::{borrow::Cow, cell::RefCell, collections::HashMap, rc::Rc, sync::Arc};
77

8-
use anyhow::Result;
9-
use deno_core::op2;
8+
use anyhow::{anyhow, Result};
9+
use deno_core::{op2, OpState};
1010

1111
use model_session::{ModelInfo, ModelSession};
12+
use ort::Session;
1213
use tensor::{JsTensor, ToJsTensor};
1314

1415
#[op2]
1516
#[to_v8]
16-
pub fn op_sb_ai_ort_init_session(#[buffer] model_bytes: &[u8]) -> Result<ModelInfo> {
17+
pub fn op_sb_ai_ort_init_session(
18+
state: Rc<RefCell<OpState>>,
19+
#[buffer] model_bytes: &[u8],
20+
) -> Result<ModelInfo> {
21+
let mut state = state.borrow_mut();
1722
let model_info = ModelSession::from_bytes(model_bytes)?;
1823

24+
let mut sessions = { state.try_take::<Vec<Arc<Session>>>().unwrap_or_default() };
25+
26+
sessions.push(model_info.inner());
27+
state.put(sessions);
28+
1929
Ok(model_info.info())
2030
}
2131

@@ -25,10 +35,11 @@ pub fn op_sb_ai_ort_run_session(
2535
#[string] model_id: String,
2636
#[serde] input_values: HashMap<String, JsTensor>,
2737
) -> Result<HashMap<String, ToJsTensor>> {
28-
let model = ModelSession::from_id(model_id).unwrap();
38+
let model = ModelSession::from_id(model_id.to_owned())
39+
.ok_or(anyhow!("could not found session for id={model_id:?}"))?;
40+
2941
let model_session = model.inner();
3042

31-
// println!("{model_session:?}");
3243
let input_values = input_values
3344
.into_iter()
3445
.map(|(key, value)| {
@@ -44,7 +55,9 @@ pub fn op_sb_ai_ort_run_session(
4455
// We need to `pop` over outputs to get 'value' ownership, since keys are attached to 'model_session' lifetime
4556
// it can't be iterated with `into_iter()`
4657
for _ in 0..outputs.len() {
47-
let (key, value) = outputs.pop_first().unwrap();
58+
let (key, value) = outputs.pop_first().ok_or(anyhow!(
59+
"could not retrieve output value from model session"
60+
))?;
4861

4962
let value = ToJsTensor::from_ort_tensor(value)?;
5063

crates/sb_ai/onnxruntime/session.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ pub fn cleanup() -> Result<usize, AnyError> {
145145
let mut to_be_removed = vec![];
146146

147147
for (key, session) in &mut *guard {
148+
// Since we're currently referencing the session at this point
149+
// It also increments the counter, so we need to check: counter > 1
148150
if Arc::strong_count(session) > 1 {
149151
continue;
150152
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import { env, pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/[email protected]';
2+
3+
// Ensure we do not use browser cache
4+
env.useBrowserCache = false;
5+
env.allowLocalModels = false;
6+
7+
const pipe = await pipeline('feature-extraction', 'supabase/gte-small', { device: 'auto' });
8+
9+
Deno.serve(async (req) => {
10+
const payload = await req.json();
11+
const text_for_embedding = payload.text_for_embedding;
12+
13+
// Generate embedding
14+
const embedding = await pipe(text_for_embedding, { pooling: 'mean', normalize: true });
15+
16+
return Response.json({
17+
length: embedding.ort_tensor.size,
18+
});
19+
});

examples/main/index.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@ console.log('main function started');
1010

1111
// cleanup unused sessions every 30s
1212
// setInterval(async () => {
13-
// const { activeUserWorkersCount } = await EdgeRuntime.getRuntimeMetrics();
14-
// if (activeUserWorkersCount > 0) {
15-
// return;
16-
// }
1713
// try {
1814
// const cleanupCount = await EdgeRuntime.ai.tryCleanupUnusedSession();
1915
// if (cleanupCount == 0) {

k6/specs/ort-rust-backend.ts

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
./scripts/run.sh
3+
4+
#!/usr/bin/env bash
5+
6+
GIT_V_TAG=0.1.1 cargo build --features cli/tracing && \
7+
EDGE_RUNTIME_WORKER_POOL_SIZE=8 \
8+
EDGE_RUNTIME_PORT=9998 RUST_BACKTRACE=full ./target/debug/edge-runtime "$@" start \
9+
--main-service ./examples/main \
10+
--event-worker ./examples/event-manager
11+
12+
*/
13+
14+
import http from "k6/http";
15+
16+
import { check, fail } from "k6";
17+
import { Options } from "k6/options";
18+
19+
import { target } from "../config";
20+
21+
/** @ts-ignore */
22+
import { randomIntBetween } from "https://jslib.k6.io/k6-utils/1.2.0/index.js";
23+
import { MSG_CANCELED } from "../constants";
24+
25+
export const options: Options = {
26+
scenarios: {
27+
simple: {
28+
executor: "constant-vus",
29+
vus: 12,
30+
duration: "3m",
31+
}
32+
}
33+
};
34+
35+
const GENERATORS = import("../generators");
36+
37+
export async function setup() {
38+
const pkg = await GENERATORS;
39+
return {
40+
words: pkg.makeText(1000)
41+
}
42+
}
43+
44+
export default function ort_rust_backend(data: { words: string[] }) {
45+
const wordIdx = randomIntBetween(0, data.words.length - 1);
46+
47+
console.debug(`WORD[${wordIdx}]: ${data.words[wordIdx]}`);
48+
const res = http.post(
49+
`${target}/k6-ort-rust-backend`,
50+
JSON.stringify({
51+
"text_for_embedding": data.words[wordIdx]
52+
})
53+
);
54+
55+
const isOk = check(res, {
56+
"status is 200": r => r.status === 200
57+
});
58+
59+
const isRequestCancelled = check(res, {
60+
"request cancelled": r => {
61+
const msg = r.json("msg");
62+
return r.status === 500 && msg === MSG_CANCELED;
63+
}
64+
});
65+
66+
if (!isOk && !isRequestCancelled) {
67+
console.log(res.body);
68+
fail("unexpected response");
69+
}
70+
}

0 commit comments

Comments
 (0)