Skip to content

Commit 9c598f4

Browse files
committed
feat: creating inference_api
- Exposing a user-friendly interface to consume the `onnx` backend
1 parent f023901 commit 9c598f4

File tree

5 files changed

+125
-6
lines changed

5 files changed

+125
-6
lines changed

examples/ort-raw-session/index.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
const { Tensor, RawSession } = Supabase.ai;
2+
3+
const session = await RawSession.fromHuggingFace('kallebysantos/vehicle-emission', {
4+
path: {
5+
modelFile: 'model.onnx',
6+
},
7+
});
8+
9+
Deno.serve(async (_req: Request) => {
10+
// sample data could be a JSON request
11+
const carsBatchInput = [{
12+
'Model_Year': 2021,
13+
'Engine_Size': 2.9,
14+
'Cylinders': 6,
15+
'Fuel_Consumption_in_City': 13.9,
16+
'Fuel_Consumption_in_City_Hwy': 10.3,
17+
'Fuel_Consumption_comb': 12.3,
18+
'Smog_Level': 3,
19+
}, {
20+
'Model_Year': 2023,
21+
'Engine_Size': 2.4,
22+
'Cylinders': 4,
23+
'Fuel_Consumption_in_City': 9.9,
24+
'Fuel_Consumption_in_City_Hwy': 7.0,
25+
'Fuel_Consumption_comb': 8.6,
26+
'Smog_Level': 3,
27+
}];
28+
29+
// Parsing objects to tensor input
30+
const inputTensors = {};
31+
session.inputs.forEach((inputKey) => {
32+
const values = carsBatchInput.map((item) => item[inputKey]);
33+
34+
inputTensors[inputKey] = new Tensor('float32', values, [values.length, 1]);
35+
});
36+
37+
const { emissions } = await session.run(inputTensors);
38+
// [ 289.01, 199.53]
39+
40+
return Response.json({ result: emissions });
41+
});

ext/ai/js/ai.js

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import "ext:ai/onnxruntime/onnx.js";
2-
import EventSourceStream from "ext:ai/util/event_source_stream.mjs";
1+
import 'ext:ai/onnxruntime/onnx.js';
2+
import InferenceAPI from 'ext:ai/onnxruntime/inference_api.js';
3+
import EventSourceStream from 'ext:ai/util/event_source_stream.mjs';
34

45
const core = globalThis.Deno.core;
56

@@ -257,7 +258,8 @@ const MAIN_WORKER_API = {
257258
};
258259

259260
const USER_WORKER_API = {
260-
Session,
261+
Session,
262+
...InferenceAPI
261263
};
262264

263265
export { MAIN_WORKER_API, USER_WORKER_API };
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import { InferenceSession, Tensor } from 'ext:ai/onnxruntime/onnx.js';
2+
3+
// Defaults used by `fromHuggingFace` to resolve a model file on the
// Hugging Face Hub; `template` placeholders are substituted at call time.
const DEFAULT_HUGGING_FACE_OPTIONS = {
  hostname: 'https://huggingface.co',
  path: {
    // Fix: the query parameter was misspelled "donwload"; the Hub expects
    // "download" to serve the file as a direct download.
    template: '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?download=true',
    revision: 'main',
    modelFile: 'model_quantized.onnx',
  },
};
11+
12+
/**
 * A user-friendly API over the low-level onnx backend session.
 */
class UserInferenceSession {
  // Underlying low-level InferenceSession instance.
  inner;

  // Backend session id plus the model's declared input/output names.
  id;
  inputs;
  outputs;

  constructor(session) {
    this.inner = session;

    this.id = session.sessionId;
    this.inputs = session.inputNames;
    this.outputs = session.outputNames;
  }

  /**
   * Creates a session from a model URL. The URL is handed to the backend
   * as raw bytes; the backend takes care of fetching the model.
   */
  static async fromUrl(modelUrl) {
    const urlText = modelUrl instanceof URL ? modelUrl.toString() : modelUrl;

    const urlBytes = new TextEncoder().encode(urlText);
    const innerSession = await InferenceSession.fromBuffer(urlBytes);

    return new UserInferenceSession(innerSession);
  }

  /**
   * Creates a session from a Hugging Face repository id (e.g. "user/model"),
   * filling the URL template from defaults merged with the given options.
   */
  static async fromHuggingFace(repoId, opts = {}) {
    const hostname = opts?.hostname ?? DEFAULT_HUGGING_FACE_OPTIONS.hostname;
    const pathOpts = { ...DEFAULT_HUGGING_FACE_OPTIONS.path, ...opts?.path };

    const modelPath = pathOpts.template
      .replaceAll('{REPO_ID}', repoId)
      .replaceAll('{REVISION}', pathOpts.revision)
      .replaceAll('{MODEL_FILE}', pathOpts.modelFile);

    // Validate before constructing, so a bad template fails with a clear error.
    if (!URL.canParse(modelPath, hostname)) {
      throw Error(`[Invalid URL] Couldn't parse the model path: "${modelPath}"`);
    }

    return await UserInferenceSession.fromUrl(new URL(modelPath, hostname));
  }

  /** Runs the model with the given named input tensors. */
  async run(inputs) {
    return await this.inner.run(inputs);
  }
}
65+
66+
/**
 * Public tensor type exposed to user code; currently identical to the
 * backend `Tensor`. (The pass-through constructor was removed — the
 * implicit constructor forwards arguments identically.)
 */
class UserTensor extends Tensor {}

// Names under which the API surfaces in `Supabase.ai` for user workers.
export default {
  RawSession: UserInferenceSession,
  Tensor: UserTensor,
};

ext/ai/js/onnxruntime/onnx.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class TensorProxy {
3131
}
3232
}
3333

34-
class Tensor {
34+
export class Tensor {
3535
/** @type {DataType} Type of the tensor. */
3636
type;
3737

@@ -67,7 +67,7 @@ class Tensor {
6767
}
6868
}
6969

70-
class InferenceSession {
70+
export class InferenceSession {
7171
sessionId;
7272
inputNames;
7373
outputNames;

ext/ai/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ deno_core::extension!(
5555
"util/event_stream_parser.mjs",
5656
"util/event_source_stream.mjs",
5757
"onnxruntime/onnx.js",
58-
"onnxruntime/cache_adapter.js"
58+
"onnxruntime/cache_adapter.js",
59+
"onnxruntime/inference_api.js"
5960
]
6061
);
6162

0 commit comments

Comments
 (0)