Skip to content

Commit e373fb7

Browse files
authored
Merge pull request #268 from supabase/try-ort
feat: Add experimental Supabase.ai API for running inferences
2 parents d6c7833 + 135499a commit e373fb7

File tree

19 files changed

+62174
-169
lines changed

19 files changed

+62174
-169
lines changed

Cargo.lock

Lines changed: 569 additions & 151 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ members = [
1313
"./crates/sb_graph",
1414
"./crates/sb_module_loader",
1515
"./crates/sb_fs",
16+
"./crates/sb_ai"
1617
]
1718
resolver = "2"
1819

Dockerfile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
FROM rust:1.74.1-bookworm as builder
33
ARG TARGETPLATFORM
44
ARG GIT_V_VERSION
5+
ARG ONNXRUNTIME_VERSION=1.17.0
56
RUN apt-get update && apt-get install -y llvm-dev libclang-dev clang cmake
67
WORKDIR /usr/src/edge-runtime
78
RUN --mount=type=cache,target=/usr/local/cargo/registry,id=${TARGETPLATFORM} \
@@ -11,10 +12,16 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry,id=${TARGETPLATFORM} --m
1112
GIT_V_TAG=${GIT_V_VERSION} cargo build --release && \
1213
cargo strip && \
1314
mv /usr/src/edge-runtime/target/release/edge-runtime /root
15+
RUN curl -O https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-${ONNXRUNTIME_VERSION}.tgz && tar zxvf onnxruntime-node-${ONNXRUNTIME_VERSION}.tgz && \
16+
mv ./package/bin/napi-v3/$TARGETPLATFORM/libonnxruntime.so.${ONNXRUNTIME_VERSION} /root/libonnxruntime.so
1417

1518

1619
FROM debian:bookworm-slim
1720
RUN apt-get update && apt-get install -y libssl-dev && rm -rf /var/lib/apt/lists/*
1821
RUN apt-get remove -y perl && apt-get autoremove -y
1922
COPY --from=builder /root/edge-runtime /usr/local/bin/edge-runtime
23+
COPY --from=builder /root/libonnxruntime.so /usr/local/bin/libonnxruntime.so
24+
COPY ./models /etc/sb_ai/models
25+
ENV ORT_DYLIB_PATH=/usr/local/bin/libonnxruntime.so
26+
ENV SB_AI_MODELS_DIR=/etc/sb_ai/models
2027
ENTRYPOINT ["edge-runtime"]

crates/base/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pin-project = { version = "1.1.3" }
6262
ctor = { workspace = true }
6363
deno_canvas.workspace = true
6464
deno_webgpu.workspace = true
65+
sb_ai = { version = "0.1.0", path = "../sb_ai" }
6566

6667
[dev-dependencies]
6768
flaky_test = { version = "0.1.0", path = "../flaky_test" }
@@ -102,4 +103,5 @@ event_worker ={ version = "0.1.0", path = "../event_worker" }
102103
deno_broadcast_channel.workspace = true
103104
deno_core.workspace = true
104105
deno_canvas.workspace = true
105-
deno_webgpu.workspace = true
106+
deno_webgpu.workspace = true
107+
sb_ai = { version = "0.1.0", path = "../sb_ai" }

crates/base/build.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ mod supabase_startup_snapshot {
1111
use deno_http::DefaultHttpPropertyExtractor;
1212
use event_worker::js_interceptors::sb_events_js_interceptors;
1313
use event_worker::sb_user_event_worker;
14+
use sb_ai::sb_ai;
1415
use sb_core::http_start::sb_core_http;
1516
use sb_core::net::sb_core_net;
1617
use sb_core::permissions::sb_core_permissions;
@@ -196,6 +197,7 @@ mod supabase_startup_snapshot {
196197
deno_http::deno_http::init_ops_and_esm::<DefaultHttpPropertyExtractor>(),
197198
deno_io::deno_io::init_ops_and_esm(Some(Default::default())),
198199
deno_fs::deno_fs::init_ops_and_esm::<Permissions>(fs.clone()),
200+
sb_ai::init_ops_and_esm(),
199201
sb_env::init_ops_and_esm(),
200202
sb_os::sb_os::init_ops_and_esm(),
201203
sb_user_workers::init_ops_and_esm(),

crates/base/src/deno_runtime.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::snapshot;
3333
use event_worker::events::{EventMetadata, WorkerEventWithMetadata};
3434
use event_worker::js_interceptors::sb_events_js_interceptors;
3535
use event_worker::sb_user_event_worker;
36+
use sb_ai::sb_ai;
3637
use sb_core::cache::CacheSetting;
3738
use sb_core::cert::ValueRootCertStoreProvider;
3839
use sb_core::external_memory::custom_allocator;
@@ -277,6 +278,7 @@ impl DenoRuntime {
277278
deno_io::deno_io::init_ops(stdio),
278279
deno_fs::deno_fs::init_ops::<Permissions>(fs.clone()),
279280
sb_env_op::init_ops(),
281+
sb_ai::init_ops(),
280282
sb_os::sb_os::init_ops(),
281283
sb_user_workers::init_ops(),
282284
sb_user_event_worker::init_ops(),

crates/sb_ai/Cargo.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[package]
2+
name = "sb_ai"
3+
version = "0.1.0"
4+
authors = ["Supabase <[email protected]>"]
5+
edition = "2021"
6+
resolver = "2"
7+
license = "MIT"
8+
9+
[lib]
10+
path = "lib.rs"
11+
12+
[dependencies]
13+
anyhow.workspace = true
14+
deno_core.workspace = true
15+
log = { workspace = true }
16+
serde.workspace = true
17+
ort = { version = "2.0.0-alpha.4", default-features = false, features = [ "ndarray", "half", "load-dynamic" ] }
18+
ndarray = "0.15"
19+
ndarray-linalg = "0.15"
20+
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
21+
rand = "0.8"

crates/sb_ai/ai.js

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
const core = globalThis.Deno.core;

/**
 * Experimental Supabase AI surface exposed to user code.
 * Thin wrapper over the Rust `op_sb_ai_run_model` op.
 */
class SupabaseAI {
	/**
	 * Run the named model against a prompt.
	 *
	 * @param {string} name - Model identifier (currently only "gte").
	 * @param {string} prompt - Input text to run inference on.
	 * @returns The op's result (an embedding vector for "gte").
	 */
	runModel(name, prompt) {
		return core.ops.op_sb_ai_run_model(name, prompt);
	}
}

export { SupabaseAI };

crates/sb_ai/lib.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
use anyhow::{bail, Error};
2+
use deno_core::error::AnyError;
3+
use deno_core::op2;
4+
use deno_core::OpState;
5+
use ndarray::{Array1, Array2, Axis, Ix2};
6+
use ndarray_linalg::norm::{normalize, NormalizeAxis};
7+
use ort::{inputs, GraphOptimizationLevel, Session, Tensor};
8+
use std::path::Path;
9+
use tokenizers::normalizers::bert::BertNormalizer;
10+
use tokenizers::Tokenizer;
11+
12+
// Registers the `sb_ai` Deno extension: exposes the `op_sb_ai_run_model`
// op to JS and ships `ai.js` as the extension's ESM entry point
// (importable from JS as `ext:sb_ai/ai.js`).
deno_core::extension!(
    sb_ai,
    ops = [op_sb_ai_run_model],
    esm_entry_point = "ext:sb_ai/ai.js",
    esm = ["ai.js",]
);
18+
19+
fn run_gte(state: &mut OpState, prompt: String) -> Result<Vec<f32>, Error> {
20+
// Create the ONNX Runtime environment, for all sessions created in this process.
21+
ort::init().with_name("GTE").commit()?;
22+
23+
let models_dir = std::env::var("SB_AI_MODELS_DIR").unwrap_or("/etc/sb_ai/models".to_string());
24+
25+
let mut session = state.try_take::<Session>();
26+
if session.is_none() {
27+
session = Some(
28+
Session::builder()?
29+
.with_optimization_level(GraphOptimizationLevel::Disable)?
30+
.with_intra_threads(1)?
31+
.with_model_from_file(
32+
Path::new(&models_dir)
33+
.join("gte")
34+
.join("gte_small_quantized.onnx"),
35+
)?,
36+
);
37+
}
38+
let session = session.unwrap();
39+
40+
// Load the tokenizer and encode the prompt into a sequence of tokens.
41+
let mut tokenizer = state.try_take::<Tokenizer>();
42+
if tokenizer.is_none() {
43+
tokenizer = Some(
44+
Tokenizer::from_file(
45+
Path::new(&models_dir)
46+
.join("gte")
47+
.join("gte_small_tokenizer.json"),
48+
)
49+
.map_err(anyhow::Error::msg)?,
50+
)
51+
}
52+
let mut tokenizer = tokenizer.unwrap();
53+
54+
let tokenizer_impl = tokenizer
55+
.with_normalizer(BertNormalizer::default())
56+
.with_padding(None)
57+
.with_truncation(None)
58+
.map_err(anyhow::Error::msg)?;
59+
60+
let tokens = tokenizer_impl
61+
.encode(prompt, true)
62+
.map_err(anyhow::Error::msg)?
63+
.get_ids()
64+
.iter()
65+
.map(|i| *i as i64)
66+
.collect::<Vec<_>>();
67+
68+
let tokens = Array1::from_iter(tokens.iter().cloned());
69+
70+
let array = tokens.view().insert_axis(Axis(0));
71+
let dims = array.raw_dim();
72+
let token_type_ids = Array2::<i64>::zeros(dims);
73+
let attention_mask = Array2::<i64>::ones(dims);
74+
let outputs = session.run(inputs! {
75+
"input_ids" => array,
76+
"token_type_ids" => token_type_ids,
77+
"attention_mask" => attention_mask,
78+
}?)?;
79+
80+
let embeddings: Tensor<f32> = outputs["last_hidden_state"].extract_tensor()?;
81+
82+
let embeddings_view = embeddings.view();
83+
let mean_pool = embeddings_view.mean_axis(Axis(1)).unwrap();
84+
let (normalized, _) = normalize(
85+
mean_pool.into_dimensionality::<Ix2>().unwrap(),
86+
NormalizeAxis::Row,
87+
);
88+
89+
let slice = normalized.view().to_slice().unwrap().to_vec();
90+
91+
drop(outputs);
92+
93+
state.put::<Session>(session);
94+
state.put::<Tokenizer>(tokenizer);
95+
96+
Ok(slice)
97+
}
98+
99+
#[op2]
100+
#[serde]
101+
pub fn op_sb_ai_run_model(
102+
state: &mut OpState,
103+
#[string] name: String,
104+
#[string] prompt: String,
105+
) -> Result<Vec<f32>, AnyError> {
106+
if name == "gte" {
107+
run_gte(state, prompt)
108+
} else {
109+
bail!("model not supported")
110+
}
111+
}

crates/sb_core/cache/http_cache/local.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ mod manifest {
572572
self.serialized.modules.insert(url, new_data);
573573
}
574574

575+
#[allow(deprecated)]
575576
pub fn remove(&mut self, url: &Url, sub_path: &LocalCacheSubPath) -> bool {
576577
if self.serialized.modules.remove(url).is_some() {
577578
if let Some(reverse_mapping) = &mut self.reverse_mapping {

0 commit comments

Comments
 (0)