
Commit 4226ea7

Introduces custom support for OpenAI API specs.

While preserving the default HTTP client, this commit introduces an OpenAI client type that also supports APIs following OpenAI's specs. It also includes an example.

Signed-off-by: Aminu Oluwaseun Joshua <[email protected]>
1 parent dd6e0a8 commit 4226ea7
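
The diff to crates/factor-llm/src/spin.rs below imports an ApiType from spin_llm_remote_http and marks it #[serde(default)], so the type must implement Deserialize and Default. Its definition is not part of this view; the following is a minimal sketch of what such an enum could look like, with variant names that are assumptions rather than taken from the commit:

use serde::Deserialize;

// Hypothetical sketch of the ApiType re-exported by spin_llm_remote_http.
// The real definition is not shown in this view; variant names are guesses.
#[derive(Deserialize, Default, Debug, Clone, Copy, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ApiType {
    // Pre-existing behavior: the default HTTP client and wire format.
    #[default]
    Default,
    // OpenAI and OpenAI-compatible endpoints.
    OpenAi,
}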

File tree

11 files changed (+1556 −147 lines)


crates/factor-llm/src/spin.rs

Lines changed: 4 additions & 1 deletion
@@ -2,7 +2,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 
 use spin_factors::runtime_config::toml::GetTomlValue;
-use spin_llm_remote_http::RemoteHttpLlmEngine;
+use spin_llm_remote_http::{ApiType, RemoteHttpLlmEngine};
 use spin_world::async_trait;
 use spin_world::v1::llm::{self as v1};
 use spin_world::v2::llm::{self as v2};
@@ -122,6 +122,7 @@ impl LlmCompute {
         LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
             config.url,
             config.auth_token,
+            config.api_type,
         ))),
     };
     Ok(engine)
@@ -132,6 +133,8 @@ impl LlmCompute {
 pub struct RemoteHttpCompute {
     url: Url,
     auth_token: String,
+    #[serde(default)]
+    api_type: ApiType,
 }
 
 /// A noop engine used when the local engine feature is disabled.
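
The #[serde(default)] attribute on api_type is what keeps pre-existing runtime configs valid: a remote_http section that omits the field still deserializes, falling back to ApiType's Default impl. A self-contained sketch of that behavior, assuming a TOML-backed config and the hypothetical ApiType variants sketched above (the real struct uses Url rather than String):

use serde::Deserialize;

#[derive(Deserialize, Default, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ApiType {
    #[default]
    Default,
    OpenAi,
}

// Mirrors RemoteHttpCompute from the diff; `url` is a String here only to
// keep the sketch dependency-light.
#[derive(Deserialize, Debug)]
struct RemoteHttpCompute {
    url: String,
    auth_token: String,
    #[serde(default)]
    api_type: ApiType,
}

fn main() {
    // No `api_type` key: deserialization succeeds and falls back to the
    // default, so configs written before this commit keep working.
    let old_style = r#"
        url = "http://localhost:3000"
        auth_token = "secret"
    "#;
    let compute: RemoteHttpCompute = toml::from_str(old_style).unwrap();
    assert_eq!(compute.api_type, ApiType::Default);

    // An explicit key opts in to the OpenAI-style client; the exact string
    // ("open_ai" under snake_case renaming) is an assumption.
    let new_style = r#"
        url = "https://api.openai.com"
        auth_token = "secret"
        api_type = "open_ai"
    "#;
    let compute: RemoteHttpCompute = toml::from_str(new_style).unwrap();
    assert_eq!(compute.api_type, ApiType::OpenAi);
}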
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+use anyhow::Result;
+use reqwest::{
+    header::{HeaderMap, HeaderValue},
+    Client, Url,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use spin_world::{
+    async_trait,
+    v2::llm::{self as wasi_llm},
+};
+
+use crate::LlmWorker;
+
+pub(crate) struct AgentEngine {
+    auth_token: String,
+    url: Url,
+    client: Option<Client>,
+}
+
+impl AgentEngine {
+    pub fn new(auth_token: String, url: Url, client: Option<Client>) -> Self {
+        Self {
+            auth_token,
+            url,
+            client,
+        }
+    }
+}
+
+#[async_trait]
+impl LlmWorker for AgentEngine {
+    async fn infer(
+        &mut self,
+        model: wasi_llm::InferencingModel,
+        prompt: String,
+        params: wasi_llm::InferencingParams,
+    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
+        let client = self.client.get_or_insert_with(Default::default);
+
+        let mut headers = HeaderMap::new();
+        headers.insert(
+            "authorization",
+            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
+                wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
+            })?,
+        );
+        spin_telemetry::inject_trace_context(&mut headers);
+
+        let inference_options = InferRequestBodyParams {
+            max_tokens: params.max_tokens,
+            repeat_penalty: params.repeat_penalty,
+            repeat_penalty_last_n_token_count: params.repeat_penalty_last_n_token_count,
+            temperature: params.temperature,
+            top_k: params.top_k,
+            top_p: params.top_p,
+        };
+        let body = serde_json::to_string(&json!({
+            "model": model,
+            "prompt": prompt,
+            "options": inference_options
+        }))
+        .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
+
+        let infer_url = self
+            .url
+            .join("/infer")
+            .map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
+        tracing::info!("Sending remote inference request to {infer_url}");
+
+        let resp = client
+            .request(reqwest::Method::POST, infer_url)
+            .headers(headers)
+            .body(body)
+            .send()
+            .await
+            .map_err(|err| {
+                wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
+            })?;
+
+        match resp.json::<InferResponseBody>().await {
+            Ok(val) => Ok(val.into()),
+            Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
+                "Failed to deserialize response for \"POST /infer\": {err}"
+            ))),
+        }
+    }
+
+    async fn generate_embeddings(
+        &mut self,
+        model: wasi_llm::EmbeddingModel,
+        data: Vec<String>,
+    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
+        let client = self.client.get_or_insert_with(Default::default);
+
+        let mut headers = HeaderMap::new();
+        headers.insert(
+            "authorization",
+            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
+                wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
+            })?,
+        );
+        spin_telemetry::inject_trace_context(&mut headers);
+
+        let body = serde_json::to_string(&json!({
+            "model": model,
+            "input": data
+        }))
+        .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
+
+        let resp = client
+            .request(
+                reqwest::Method::POST,
+                self.url.join("/embed").map_err(|_| {
+                    wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
+                })?,
+            )
+            .headers(headers)
+            .body(body)
+            .send()
+            .await
+            .map_err(|err| {
+                wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
+            })?;
+
+        match resp.json::<EmbeddingResponseBody>().await {
+            Ok(val) => Ok(val.into()),
+            Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
+                "Failed to deserialize response for \"POST /embed\": {err}"
+            ))),
+        }
+    }
+
+    fn url(&self) -> Url {
+        self.url.clone()
+    }
+}
+
+#[derive(Serialize)]
+#[serde(rename_all(serialize = "camelCase"))]
+struct InferRequestBodyParams {
+    max_tokens: u32,
+    repeat_penalty: f32,
+    repeat_penalty_last_n_token_count: u32,
+    temperature: f32,
+    top_k: u32,
+    top_p: f32,
+}
+
+#[derive(Deserialize)]
+#[serde(rename_all(deserialize = "camelCase"))]
+pub struct InferUsage {
+    prompt_token_count: u32,
+    generated_token_count: u32,
+}
+
+#[derive(Deserialize)]
+pub struct InferResponseBody {
+    text: String,
+    usage: InferUsage,
+}
+
+#[derive(Deserialize)]
+#[serde(rename_all(deserialize = "camelCase"))]
+struct EmbeddingUsage {
+    prompt_token_count: u32,
+}
+
+#[derive(Deserialize)]
+struct EmbeddingResponseBody {
+    embeddings: Vec<Vec<f32>>,
+    usage: EmbeddingUsage,
+}
+
+impl From<InferResponseBody> for wasi_llm::InferencingResult {
+    fn from(value: InferResponseBody) -> Self {
+        Self {
+            text: value.text,
+            usage: wasi_llm::InferencingUsage {
+                prompt_token_count: value.usage.prompt_token_count,
+                generated_token_count: value.usage.generated_token_count,
+            },
+        }
+    }
+}
+
+impl From<EmbeddingResponseBody> for wasi_llm::EmbeddingsResult {
+    fn from(value: EmbeddingResponseBody) -> Self {
+        Self {
+            embeddings: value.embeddings,
+            usage: wasi_llm::EmbeddingsUsage {
+                prompt_token_count: value.usage.prompt_token_count,
+            },
+        }
+    }
+}
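
For reference, this sketch reproduces the request body that AgentEngine::infer builds and prints it, using the InferRequestBodyParams struct exactly as committed. The model name, prompt, and parameter values are placeholders; only the field names and the camelCase renaming come from the diff:

use serde::Serialize;
use serde_json::json;

// Copied from the new file above: serializes with camelCase keys such as
// "maxTokens" and "repeatPenaltyLastNTokenCount".
#[derive(Serialize)]
#[serde(rename_all(serialize = "camelCase"))]
struct InferRequestBodyParams {
    max_tokens: u32,
    repeat_penalty: f32,
    repeat_penalty_last_n_token_count: u32,
    temperature: f32,
    top_k: u32,
    top_p: f32,
}

fn main() {
    let options = InferRequestBodyParams {
        max_tokens: 128,
        repeat_penalty: 1.1,
        repeat_penalty_last_n_token_count: 64,
        temperature: 0.8,
        top_k: 40,
        top_p: 0.9,
    };
    // Same shape as the body sent to POST /infer in AgentEngine::infer.
    let body = json!({
        "model": "example-model", // placeholder
        "prompt": "Say hello",    // placeholder
        "options": options,
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());

    // The engine expects a response shaped like:
    //   { "text": "...", "usage": { "promptTokenCount": 1, "generatedTokenCount": 2 } }
    // which deserializes into InferResponseBody.
}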
