
Commit eb1bb49

feat: support for authorization header on model fetch
1 parent b17079b commit eb1bb49

7 files changed: +122 −89 lines
Lines changed: 64 additions & 61 deletions

@@ -1,90 +1,93 @@

This hunk adds the const core = globalThis.Deno.core; binding at the top of the module, renames the ops op_sb_ai_ort_run_session and op_sb_ai_ort_encode_tensor_audio to op_ai_ort_run_session and op_ai_ort_encode_tensor_audio, splits the invalid-URL throw across several lines, and re-indents the rest of the file. The module after the change:

const core = globalThis.Deno.core;
import { InferenceSession, Tensor } from 'ext:ai/onnxruntime/onnx.js';

const DEFAULT_HUGGING_FACE_OPTIONS = {
  hostname: 'https://huggingface.co',
  path: {
    template: '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true',
    revision: 'main',
    modelFile: 'model_quantized.onnx',
  },
};

/**
 * An user friendly API for onnx backend
 */
class UserInferenceSession {
  inner;

  id;
  inputs;
  outputs;

  constructor(session) {
    this.inner = session;

    this.id = session.sessionId;
    this.inputs = session.inputNames;
    this.outputs = session.outputNames;
  }

  static async fromUrl(modelUrl) {
    if (modelUrl instanceof URL) {
      modelUrl = modelUrl.toString();
    }

    const encoder = new TextEncoder();
    const modelUrlBuffer = encoder.encode(modelUrl);
    const session = await InferenceSession.fromBuffer(modelUrlBuffer);

    return new UserInferenceSession(session);
  }

  static async fromHuggingFace(repoId, opts = {}) {
    const hostname = opts?.hostname ?? DEFAULT_HUGGING_FACE_OPTIONS.hostname;
    const pathOpts = {
      ...DEFAULT_HUGGING_FACE_OPTIONS.path,
      ...opts?.path,
    };

    const modelPath = pathOpts.template
      .replaceAll('{REPO_ID}', repoId)
      .replaceAll('{REVISION}', pathOpts.revision)
      .replaceAll('{MODEL_FILE}', pathOpts.modelFile);

    if (!URL.canParse(modelPath, hostname)) {
      throw Error(
        `[Invalid URL] Couldn't parse the model path: "${modelPath}"`,
      );
    }

    return await UserInferenceSession.fromUrl(new URL(modelPath, hostname));
  }

  async run(inputs) {
    const outputs = await core.ops.op_ai_ort_run_session(this.id, inputs);

    // Parse to Tensor
    for (const key in outputs) {
      if (Object.hasOwn(outputs, key)) {
        const { type, data, dims } = outputs[key];

        outputs[key] = new UserTensor(type, data.buffer, dims);
      }
    }

    return outputs;
  }
}

class UserTensor extends Tensor {
  constructor(type, data, dim) {
    super(type, data, dim);
  }

  async tryEncodeAudio(sampleRate) {
    return await core.ops.op_ai_ort_encode_tensor_audio(this.data, sampleRate);
  }
}

export default {
  RawSession: UserInferenceSession,
  RawTensor: UserTensor,
};
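
For context, a minimal sketch of how the RawSession / RawTensor exports above are meant to be used. The binding name ai and the repo id are illustrative assumptions, not part of this commit:

// Sketch only: `ai` stands for this module's default export; the repo id is illustrative.
const session = await ai.RawSession.fromHuggingFace('org/some-onnx-model');

// The session exposes the id and the input/output names reported by the runtime.
console.log(session.id, session.inputs, session.outputs);

// RawTensor wraps (type, data, dims), matching the constructor shown above.
const tensor = new ai.RawTensor('float32', new Float32Array([0, 1, 2, 3]), [1, 4]);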

ext/ai/js/onnxruntime/onnx.js

Lines changed: 11 additions & 2 deletions

@@ -18,7 +18,7 @@ const DataTypeMap = Object.freeze({
 class TensorProxy {
   get(target, property) {
     switch (property) {
-      case "data":
+      case 'data':
         return target.data?.c ?? target.data;

       default:

@@ -86,6 +86,15 @@ export class InferenceSession {
     return new InferenceSession(id, inputs, outputs);
   }

+  static async fromRequest(modelUrl, authorization) {
+    const [id, inputs, outputs] = await core.ops.op_ai_ort_init_session(
+      modelUrl,
+      authorization,
+    );
+
+    return new InferenceSession(id, inputs, outputs);
+  }
+
   async run(inputs) {
     const sessionInputs = {};

@@ -125,4 +134,4 @@ const onnxruntime = {
   },
 };

-globalThis[Symbol.for("onnxruntime")] = onnxruntime;
+globalThis[Symbol.for('onnxruntime')] = onnxruntime;
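
A minimal sketch of how the new fromRequest entry point might be called. The model URL and token below are placeholders; the authorization string is forwarded to op_ai_ort_init_session and, per the Rust changes in this commit, ends up as the request's Authorization header:

// Sketch only: URL and token are placeholders, not values from this commit.
const session = await InferenceSession.fromRequest(
  'https://example.com/private/model_quantized.onnx',
  'Bearer <access-token>', // sent as the Authorization header when the model is fetched
);

console.log(session.inputNames, session.outputNames);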

ext/ai/lib.rs

Lines changed: 2 additions & 1 deletion

@@ -118,7 +118,7 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
     let handle = handle.clone();
     move || {
       handle.block_on(async move {
-        load_session_from_url(Url::parse(consts::GTE_SMALL_MODEL_URL).unwrap())
+        load_session_from_url(Url::parse(consts::GTE_SMALL_MODEL_URL).unwrap(), None)
           .await
       })
     }

@@ -143,6 +143,7 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
       "tokenizer",
       Url::parse(consts::GTE_SMALL_TOKENIZER_URL).unwrap(),
       None,
+      None
     )
     .map_err(AnyError::from)
     .and_then(|it| {

ext/ai/onnxruntime/mod.rs

Lines changed: 16 additions & 14 deletions
(the remaining additions and deletions in this hunk are indentation-only; those lines appear once, as context, below)

@@ -37,26 +37,28 @@ use tokio_util::bytes::BufMut;
 #[op2(async)]
 #[to_v8]
 pub async fn op_ai_ort_init_session(
   state: Rc<RefCell<OpState>>,
   #[buffer] model_bytes: JsBuffer,
+  // Maybe improve the code style to enum payload or something else
+  #[string] req_authorization: Option<String>,
 ) -> Result<ModelInfo> {
   let model_bytes = model_bytes.into_parts().to_boxed_slice();
   let model_bytes_or_url = str::from_utf8(&model_bytes)
     .map_err(AnyError::from)
     .and_then(|utf8_str| Url::parse(utf8_str).map_err(AnyError::from));

   let model = match model_bytes_or_url {
     Ok(model_url) => {
       trace!(kind = "url", url = %model_url);
-      Model::from_url(model_url).await?
+      Model::from_url(model_url, req_authorization).await?
     }
     Err(_) => {
       trace!(kind = "bytes", len = model_bytes.len());
       Model::from_bytes(&model_bytes).await?
     }
   };

   let mut state = state.borrow_mut();
   let mut sessions =
     { state.try_take::<Vec<Arc<Session>>>().unwrap_or_default() };

ext/ai/onnxruntime/model.rs

Lines changed: 5 additions & 3 deletions

@@ -71,9 +71,11 @@ impl Model {
     .map(Self::new)
   }

-  pub async fn from_url(model_url: Url) -> Result<Self> {
-    load_session_from_url(model_url).await.map(Self::new)
-  }
+  pub async fn from_url(model_url: Url, authorization: Option<String>) -> Result<Self> {
+    load_session_from_url(model_url, authorization)
+      .await
+      .map(Self::new)
+  }

   pub async fn from_bytes(model_bytes: &[u8]) -> Result<Self> {
     load_session_from_bytes(model_bytes).await.map(Self::new)

ext/ai/onnxruntime/session.rs

Lines changed: 10 additions & 8 deletions
(the remaining additions and deletions in these hunks are indentation-only; those lines appear once, as context, below)

@@ -154,9 +154,10 @@ pub(crate) async fn load_session_from_bytes(

 #[instrument(level = "debug", fields(%model_url), err)]
 pub(crate) async fn load_session_from_url(
   model_url: Url,
+  authorization: Option<String>,
 ) -> Result<SessionWithId, Error> {
   let session_id = fxhash::hash(model_url.as_str()).to_string();

   let mut sessions = SESSIONS.lock().await;

@@ -165,12 +166,13 @@ pub(crate) async fn load_session_from_url(
     return Ok((session_id, session.clone()).into());
   }

   let model_file_path = crate::utils::fetch_and_cache_from_url(
     "model",
     model_url,
     Some(session_id.to_string()),
+    authorization,
   )
   .await?;

   let model_bytes = tokio::fs::read(model_file_path).await?;
   let session = create_session(model_bytes.as_slice())?;

ext/ai/utils.rs

Lines changed: 14 additions & 0 deletions

@@ -20,6 +20,7 @@ pub async fn fetch_and_cache_from_url(
   kind: &'static str,
   url: Url,
   cache_id: Option<String>,
+  authorization: Option<String>,
 ) -> Result<PathBuf, AnyError> {
   let cache_id = cache_id.unwrap_or(fxhash::hash(url.as_str()).to_string());
   let download_dir = std::env::var("EXT_AI_CACHE_DIR")

@@ -91,13 +92,26 @@ pub async fn fetch_and_cache_from_url(

   use reqwest::*;

+  let mut headers = header::HeaderMap::new();
+
+  if let Some(authorization) = authorization {
+    let mut authorization =
+      header::HeaderValue::from_str(authorization.as_str())?;
+    authorization.set_sensitive(true);
+
+    headers.insert(header::AUTHORIZATION, authorization);
+  };
+
   let resp = Client::builder()
     .http1_only()
+    .default_headers(headers)
     .build()
     .context("failed to create http client")?
     .get(url.clone())
     .send()
     .await
+    .context("failed to download")?
+    .error_for_status()
     .context("failed to download")?;

   let file = tokio::fs::File::create(&filepath)
