Commit ccc7b34

Introduce support for OpenAI-compatible API specs

While preserving the default HTTP client, this commit introduces an OpenAI client type that also supports APIs conforming to OpenAI's spec. An example is included.

Signed-off-by: Aminu Oluwaseun Joshua <[email protected]>

1 parent dd6e0a8 · commit ccc7b34
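The new client type is selected through Spin's runtime configuration. A minimal sketch of what the relevant table might look like, assuming the existing llm_compute remote_http settings; the "open_ai" value for the new api_type field is an assumption based on snake_case serde naming of the ApiType enum in the llm-remote-http crate:

# Hypothetical runtime-config sketch (not from this commit): the
# "open_ai" value assumes serde's snake_case naming of ApiType variants.
[llm_compute]
type = "remote_http"
url = "https://llm.example.com"
auth_token = "<token>"
api_type = "open_ai"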

File tree: 11 files changed, +1584 −115 lines


crates/factor-llm/src/spin.rs

4 additions, 1 deletion

@@ -2,7 +2,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 
 use spin_factors::runtime_config::toml::GetTomlValue;
-use spin_llm_remote_http::RemoteHttpLlmEngine;
+use spin_llm_remote_http::{ApiType, RemoteHttpLlmEngine};
 use spin_world::async_trait;
 use spin_world::v1::llm::{self as v1};
 use spin_world::v2::llm::{self as v2};
@@ -122,6 +122,7 @@ impl LlmCompute {
         LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
             config.url,
             config.auth_token,
+            config.api_type,
         ))),
     };
     Ok(engine)
@@ -132,6 +133,8 @@ impl LlmCompute {
 pub struct RemoteHttpCompute {
     url: Url,
     auth_token: String,
+    #[serde(default)]
+    api_type: ApiType,
 }
 
 /// A noop engine used when the local engine feature is disabled.
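The #[serde(default)] attribute keeps api_type optional, so existing runtime configs continue to deserialize unchanged. The ApiType enum itself lives in the llm-remote-http crate and is not shown in this hunk; a minimal sketch of the shape such an enum could take, assuming snake_case serde naming and a default variant for backward compatibility:

use serde::Deserialize;

// Hypothetical sketch only; the real definition in the llm-remote-http
// crate may use different variant names.
#[derive(Debug, Clone, Copy, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiType {
    // Spin's default remote HTTP protocol (POST /infer, POST /embed).
    #[default]
    Default,
    // An API compatible with OpenAI's spec.
    OpenAi,
}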
New file (in the llm-remote-http crate): 136 additions, 0 deletions

@@ -0,0 +1,136 @@
use anyhow::Result;
use reqwest::{
    header::{HeaderMap, HeaderValue},
    Client, Url,
};
use serde_json::json;
use spin_world::{
    async_trait,
    v2::llm::{self as wasi_llm},
};

use crate::{EmbeddingResponseBody, InferRequestBodyParams, InferResponseBody, LlmWorker};

pub(crate) struct DefaultAgentEngine {
    auth_token: String,
    url: Url,
    client: Option<Client>,
}

impl DefaultAgentEngine {
    pub fn new(auth_token: String, url: Url, client: Option<Client>) -> Self {
        Self {
            auth_token,
            url,
            client,
        }
    }
}

#[async_trait]
impl LlmWorker for DefaultAgentEngine {
    async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
        // Lazily create the shared reqwest client on first use.
        let client = self.client.get_or_insert_with(Default::default);

        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
                wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
            })?,
        );
        spin_telemetry::inject_trace_context(&mut headers);

        let inference_options = InferRequestBodyParams {
            max_tokens: params.max_tokens,
            repeat_penalty: params.repeat_penalty,
            repeat_penalty_last_n_token_count: params.repeat_penalty_last_n_token_count,
            temperature: params.temperature,
            top_k: params.top_k,
            top_p: params.top_p,
        };
        let body = serde_json::to_string(&json!({
            "model": model,
            "prompt": prompt,
            "options": inference_options
        }))
        .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;

        // Note: joining "/infer" with a leading slash resolves against the
        // URL root, replacing any base path on self.url.
        let infer_url = self
            .url
            .join("/infer")
            .map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
        tracing::info!("Sending remote inference request to {infer_url}");

        let resp = client
            .request(reqwest::Method::POST, infer_url)
            .headers(headers)
            .body(body)
            .send()
            .await
            .map_err(|err| {
                wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
            })?;

        match resp.json::<InferResponseBody>().await {
            Ok(val) => Ok(val.into()),
            Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
                "Failed to deserialize response for \"POST /infer\": {err}"
            ))),
        }
    }

    async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
        let client = self.client.get_or_insert_with(Default::default);

        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
                wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
            })?,
        );
        spin_telemetry::inject_trace_context(&mut headers);

        let body = serde_json::to_string(&json!({
            "model": model,
            "input": data
        }))
        .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;

        let resp = client
            .request(
                reqwest::Method::POST,
                self.url.join("/embed").map_err(|_| {
                    wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
                })?,
            )
            .headers(headers)
            .body(body)
            .send()
            .await
            .map_err(|err| {
                wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
            })?;

        match resp.json::<EmbeddingResponseBody>().await {
            Ok(val) => Ok(val.into()),
            Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
                "Failed to deserialize response for \"POST /embed\": {err}"
            ))),
        }
    }

    fn url(&self) -> Url {
        self.url.clone()
    }
}
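For orientation, a hypothetical caller of this default engine could look like the following. The URL, token, and model name are illustrative; wasi_llm::InferencingModel is assumed to be a String alias, and the InferencingParams fields mirror those forwarded in infer above.

use reqwest::Url;

// Hypothetical usage sketch (not part of this commit).
async fn demo() -> anyhow::Result<()> {
    let mut engine = DefaultAgentEngine::new(
        "my-auth-token".to_string(),
        Url::parse("http://localhost:3000")?,
        None, // a reqwest::Client is created lazily on first use
    );
    let result = engine
        .infer(
            "llama2-chat".to_string(), // assumed String model id
            "Say hello".to_string(),
            wasi_llm::InferencingParams {
                max_tokens: 100,
                repeat_penalty: 1.1,
                repeat_penalty_last_n_token_count: 64,
                temperature: 0.8,
                top_k: 40,
                top_p: 0.9,
            },
        )
        .await
        .map_err(|e| anyhow::anyhow!("inference failed: {e:?}"))?;
    println!("{}", result.text);
    Ok(())
}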
