
Commit 4394e4b

Merge pull request #3238 from seun-ja/implement-open-ai-api-llm-backend
Implement open ai api llm backend
2 parents f052ff5 + 4226ea7

File tree

11 files changed (+1556, -147 lines)


crates/factor-llm/src/spin.rs

Lines changed: 4 additions & 1 deletion
@@ -2,7 +2,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 
 use spin_factors::runtime_config::toml::GetTomlValue;
-use spin_llm_remote_http::RemoteHttpLlmEngine;
+use spin_llm_remote_http::{ApiType, RemoteHttpLlmEngine};
 use spin_world::async_trait;
 use spin_world::v1::llm::{self as v1};
 use spin_world::v2::llm::{self as v2};
@@ -122,6 +122,7 @@ impl LlmCompute {
             LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
                 config.url,
                 config.auth_token,
+                config.api_type,
             ))),
         };
         Ok(engine)
@@ -132,6 +133,8 @@ impl LlmCompute {
 pub struct RemoteHttpCompute {
     url: Url,
     auth_token: String,
+    #[serde(default)]
+    api_type: ApiType,
 }
 
 /// A noop engine used when the local engine feature is disabled.
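Because the new api_type field carries #[serde(default)], runtime configs written before this change (with only url and auth_token) still deserialize, and the field falls back to ApiType::default(). Below is a minimal standalone sketch of that behavior; the enum here is a hypothetical stand-in, since the real ApiType definition (its variants and serde renames) is not part of this hunk, and the url field is simplified to a String.

use serde::Deserialize;

// Stand-in for spin_llm_remote_http::ApiType; the real variant names
// and serde renames may differ.
#[derive(Debug, Default, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ApiType {
    #[default]
    Spin,
    OpenAi,
}

// Simplified mirror of RemoteHttpCompute (url narrowed to String).
#[derive(Debug, Deserialize)]
struct RemoteHttpCompute {
    url: String,
    auth_token: String,
    #[serde(default)]
    api_type: ApiType,
}

fn main() {
    // An old-style config without `api_type` still parses,
    // defaulting to the existing Spin protocol...
    let old: RemoteHttpCompute =
        toml::from_str("url = \"https://llm.example.com\"\nauth_token = \"secret\"").unwrap();
    assert_eq!(old.api_type, ApiType::Spin);

    // ...while a new config can opt into the OpenAI-compatible backend.
    let new: RemoteHttpCompute = toml::from_str(
        "url = \"https://api.openai.com\"\nauth_token = \"secret\"\napi_type = \"open_ai\"",
    )
    .unwrap();
    assert_eq!(new.api_type, ApiType::OpenAi);
}

Defaulting the field is what keeps the change backward compatible with existing runtime-config files.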
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
use anyhow::Result;
use reqwest::{
    header::{HeaderMap, HeaderValue},
    Client, Url,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use spin_world::{
    async_trait,
    v2::llm::{self as wasi_llm},
};

use crate::LlmWorker;

pub(crate) struct AgentEngine {
    auth_token: String,
    url: Url,
    client: Option<Client>,
}

impl AgentEngine {
    pub fn new(auth_token: String, url: Url, client: Option<Client>) -> Self {
        Self {
            auth_token,
            url,
            client,
        }
    }
}

#[async_trait]
impl LlmWorker for AgentEngine {
    async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
        let client = self.client.get_or_insert_with(Default::default);

        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
                wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
            })?,
        );
        spin_telemetry::inject_trace_context(&mut headers);

        let inference_options = InferRequestBodyParams {
            max_tokens: params.max_tokens,
            repeat_penalty: params.repeat_penalty,
            repeat_penalty_last_n_token_count: params.repeat_penalty_last_n_token_count,
            temperature: params.temperature,
            top_k: params.top_k,
            top_p: params.top_p,
        };
        let body = serde_json::to_string(&json!({
            "model": model,
            "prompt": prompt,
            "options": inference_options
        }))
        .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;

        let infer_url = self
            .url
            .join("/infer")
            .map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
        tracing::info!("Sending remote inference request to {infer_url}");

        let resp = client
            .request(reqwest::Method::POST, infer_url)
            .headers(headers)
            .body(body)
            .send()
            .await
            .map_err(|err| {
                wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
            })?;

        match resp.json::<InferResponseBody>().await {
            Ok(val) => Ok(val.into()),
            Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
                "Failed to deserialize response for \"POST /infer\": {err}"
            ))),
        }
    }

    async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
        let client = self.client.get_or_insert_with(Default::default);

        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
                wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
            })?,
        );
        spin_telemetry::inject_trace_context(&mut headers);

        let body = serde_json::to_string(&json!({
            "model": model,
            "input": data
        }))
        .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;

        let resp = client
            .request(
                reqwest::Method::POST,
                self.url.join("/embed").map_err(|_| {
                    wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
                })?,
            )
            .headers(headers)
            .body(body)
            .send()
            .await
            .map_err(|err| {
                wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
            })?;

        match resp.json::<EmbeddingResponseBody>().await {
            Ok(val) => Ok(val.into()),
            Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
                "Failed to deserialize response for \"POST /embed\": {err}"
            ))),
        }
    }

    fn url(&self) -> Url {
        self.url.clone()
    }
}

#[derive(Serialize)]
#[serde(rename_all(serialize = "camelCase"))]
struct InferRequestBodyParams {
    max_tokens: u32,
    repeat_penalty: f32,
    repeat_penalty_last_n_token_count: u32,
    temperature: f32,
    top_k: u32,
    top_p: f32,
}

#[derive(Deserialize)]
#[serde(rename_all(deserialize = "camelCase"))]
pub struct InferUsage {
    prompt_token_count: u32,
    generated_token_count: u32,
}

#[derive(Deserialize)]
pub struct InferResponseBody {
    text: String,
    usage: InferUsage,
}

#[derive(Deserialize)]
#[serde(rename_all(deserialize = "camelCase"))]
struct EmbeddingUsage {
    prompt_token_count: u32,
}

#[derive(Deserialize)]
struct EmbeddingResponseBody {
    embeddings: Vec<Vec<f32>>,
    usage: EmbeddingUsage,
}

impl From<InferResponseBody> for wasi_llm::InferencingResult {
    fn from(value: InferResponseBody) -> Self {
        Self {
            text: value.text,
            usage: wasi_llm::InferencingUsage {
                prompt_token_count: value.usage.prompt_token_count,
                generated_token_count: value.usage.generated_token_count,
            },
        }
    }
}

impl From<EmbeddingResponseBody> for wasi_llm::EmbeddingsResult {
    fn from(value: EmbeddingResponseBody) -> Self {
        Self {
            embeddings: value.embeddings,
            usage: wasi_llm::EmbeddingsUsage {
                prompt_token_count: value.usage.prompt_token_count,
            },
        }
    }
}
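The rename_all(serialize = "camelCase") attribute on InferRequestBodyParams means the snake_case Rust fields cross the wire as camelCase JSON keys inside the "options" object of the POST /infer body. A quick standalone check of that shape; the model name and parameter values here are made up for illustration:

use serde::Serialize;
use serde_json::json;

// Same shape as the struct in the diff above.
#[derive(Serialize)]
#[serde(rename_all(serialize = "camelCase"))]
struct InferRequestBodyParams {
    max_tokens: u32,
    repeat_penalty: f32,
    repeat_penalty_last_n_token_count: u32,
    temperature: f32,
    top_k: u32,
    top_p: f32,
}

fn main() {
    let options = InferRequestBodyParams {
        max_tokens: 100,
        repeat_penalty: 1.1,
        repeat_penalty_last_n_token_count: 64,
        temperature: 0.8,
        top_k: 40,
        top_p: 0.9,
    };
    // Hypothetical model name, for illustration only.
    let body = json!({ "model": "llama2-chat", "prompt": "Hello", "options": options });
    // The options object serializes with camelCase keys, e.g.:
    //   "maxTokens":100, "repeatPenalty":1.1, "repeatPenaltyLastNTokenCount":64,
    //   "temperature":0.8, "topK":40, "topP":0.9
    println!("{body}");
}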
