Skip to content

Commit b770937

Browse files
committed
WIP: Request API Setup
Signed-off-by: Aminu Oluwaseun Joshua <[email protected]>
1 parent 6d60290 commit b770937

File tree

4 files changed

+445
-103
lines changed

4 files changed

+445
-103
lines changed

crates/factor-llm/src/spin.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ impl LlmCompute {
122122
LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
123123
config.url,
124124
config.auth_token,
125+
config.agent,
125126
))),
126127
};
127128
Ok(engine)
@@ -132,6 +133,7 @@ impl LlmCompute {
132133
pub struct RemoteHttpCompute {
133134
url: Url,
134135
auth_token: String,
136+
agent: Option<String>,
135137
}
136138

137139
/// A noop engine used when the local engine feature is disabled.
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
use anyhow::Result;
2+
use reqwest::{
3+
header::{HeaderMap, HeaderValue},
4+
Client, Url,
5+
};
6+
use serde_json::json;
7+
use spin_world::v2::llm::{self as wasi_llm};
8+
9+
use crate::{EmbeddingResponseBody, InferRequestBodyParams, InferResponseBody};
10+
11+
pub(crate) struct DefaultAgentEngine;
12+
13+
impl DefaultAgentEngine {
14+
pub async fn infer(
15+
auth_token: &String,
16+
url: &Url,
17+
mut client: Option<Client>,
18+
model: wasi_llm::InferencingModel,
19+
prompt: String,
20+
params: wasi_llm::InferencingParams,
21+
) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
22+
let client = client.get_or_insert_with(Default::default);
23+
24+
let mut headers = HeaderMap::new();
25+
headers.insert(
26+
"authorization",
27+
HeaderValue::from_str(&format!("bearer {}", auth_token)).map_err(|_| {
28+
wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
29+
})?,
30+
);
31+
spin_telemetry::inject_trace_context(&mut headers);
32+
33+
let inference_options = InferRequestBodyParams {
34+
max_tokens: params.max_tokens,
35+
repeat_penalty: params.repeat_penalty,
36+
repeat_penalty_last_n_token_count: params.repeat_penalty_last_n_token_count,
37+
temperature: params.temperature,
38+
top_k: params.top_k,
39+
top_p: params.top_p,
40+
};
41+
let body = serde_json::to_string(&json!({
42+
"model": model,
43+
"prompt": prompt,
44+
"options": inference_options
45+
}))
46+
.map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
47+
48+
let infer_url = url
49+
.join("/infer")
50+
.map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
51+
tracing::info!("Sending remote inference request to {infer_url}");
52+
53+
let resp = client
54+
.request(reqwest::Method::POST, infer_url)
55+
.headers(headers)
56+
.body(body)
57+
.send()
58+
.await
59+
.map_err(|err| {
60+
wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
61+
})?;
62+
63+
match resp.json::<InferResponseBody>().await {
64+
Ok(val) => Ok(val.into()),
65+
Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
66+
"Failed to deserialize response for \"POST /index\": {err}"
67+
))),
68+
}
69+
}
70+
71+
pub async fn generate_embeddings(
72+
auth_token: &String,
73+
url: &Url,
74+
mut client: Option<Client>,
75+
model: wasi_llm::EmbeddingModel,
76+
data: Vec<String>,
77+
) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
78+
let client = client.get_or_insert_with(Default::default);
79+
80+
let mut headers = HeaderMap::new();
81+
headers.insert(
82+
"authorization",
83+
HeaderValue::from_str(&format!("bearer {}", auth_token)).map_err(|_| {
84+
wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
85+
})?,
86+
);
87+
spin_telemetry::inject_trace_context(&mut headers);
88+
89+
let body = serde_json::to_string(&json!({
90+
"model": model,
91+
"input": data
92+
}))
93+
.map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
94+
95+
let resp = client
96+
.request(
97+
reqwest::Method::POST,
98+
url.join("/embed").map_err(|_| {
99+
wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
100+
})?,
101+
)
102+
.headers(headers)
103+
.body(body)
104+
.send()
105+
.await
106+
.map_err(|err| {
107+
wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
108+
})?;
109+
110+
match resp.json::<EmbeddingResponseBody>().await {
111+
Ok(val) => Ok(val.into()),
112+
Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
113+
"Failed to deserialize response for \"POST /embed\": {err}"
114+
))),
115+
}
116+
}
117+
}

crates/llm-remote-http/src/lib.rs

Lines changed: 87 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,48 @@
11
use anyhow::Result;
2-
use reqwest::{
3-
header::{HeaderMap, HeaderValue},
4-
Client, Url,
5-
};
2+
use reqwest::{Client, Url};
63
use serde::{Deserialize, Serialize};
7-
use serde_json::json;
84
use spin_world::v2::llm::{self as wasi_llm};
95

6+
use crate::{default::DefaultAgentEngine, open_ai::OpenAIAgentEngine};
7+
8+
mod default;
109
mod open_ai;
1110

11+
#[derive(Clone)]
12+
pub enum Agent {
13+
//TODO: Joshua: Naming??!
14+
Default {
15+
auth_token: String,
16+
url: Url,
17+
client: Option<Client>,
18+
},
19+
OpenAI {
20+
auth_token: String,
21+
url: Url,
22+
client: Option<Client>,
23+
},
24+
}
25+
26+
impl Agent {
27+
pub fn from(url: Url, auth_token: String, agent: Option<String>) -> Self {
28+
match agent {
29+
Some(agent_name) if agent_name == *"open_ai" => Agent::OpenAI {
30+
auth_token,
31+
url,
32+
client: None,
33+
},
34+
_ => Agent::Default {
35+
auth_token,
36+
url,
37+
client: None,
38+
},
39+
}
40+
}
41+
}
42+
1243
#[derive(Clone)]
1344
pub struct RemoteHttpLlmEngine {
14-
auth_token: String,
15-
url: Url,
16-
client: Option<Client>,
45+
agent: Agent,
1746
}
1847

1948
#[derive(Serialize)]
@@ -53,59 +82,35 @@ struct EmbeddingResponseBody {
5382
}
5483

5584
impl RemoteHttpLlmEngine {
85+
pub fn new(url: Url, auth_token: String, agent: Option<String>) -> Self {
86+
RemoteHttpLlmEngine {
87+
agent: Agent::from(url, auth_token, agent),
88+
}
89+
}
90+
5691
pub async fn infer(
5792
&mut self,
5893
model: wasi_llm::InferencingModel,
5994
prompt: String,
6095
params: wasi_llm::InferencingParams,
6196
) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
62-
let client = self.client.get_or_insert_with(Default::default);
63-
64-
let mut headers = HeaderMap::new();
65-
headers.insert(
66-
"authorization",
67-
HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
68-
wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
69-
})?,
70-
);
71-
spin_telemetry::inject_trace_context(&mut headers);
72-
73-
let inference_options = InferRequestBodyParams {
74-
max_tokens: params.max_tokens,
75-
repeat_penalty: params.repeat_penalty,
76-
repeat_penalty_last_n_token_count: params.repeat_penalty_last_n_token_count,
77-
temperature: params.temperature,
78-
top_k: params.top_k,
79-
top_p: params.top_p,
80-
};
81-
let body = serde_json::to_string(&json!({
82-
"model": model,
83-
"prompt": prompt,
84-
"options": inference_options
85-
}))
86-
.map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
87-
88-
let infer_url = self
89-
.url
90-
.join("/infer")
91-
.map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
92-
tracing::info!("Sending remote inference request to {infer_url}");
93-
94-
let resp = client
95-
.request(reqwest::Method::POST, infer_url)
96-
.headers(headers)
97-
.body(body)
98-
.send()
99-
.await
100-
.map_err(|err| {
101-
wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
102-
})?;
103-
104-
match resp.json::<InferResponseBody>().await {
105-
Ok(val) => Ok(val.into()),
106-
Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
107-
"Failed to deserialize response for \"POST /index\": {err}"
108-
))),
97+
match &self.agent {
98+
Agent::OpenAI {
99+
auth_token,
100+
url,
101+
client,
102+
} => {
103+
OpenAIAgentEngine::infer(auth_token, url, client.clone(), model, prompt, params)
104+
.await
105+
}
106+
Agent::Default {
107+
auth_token,
108+
url,
109+
client,
110+
} => {
111+
DefaultAgentEngine::infer(auth_token, url, client.clone(), model, prompt, params)
112+
.await
113+
}
109114
}
110115
}
111116

@@ -114,48 +119,37 @@ impl RemoteHttpLlmEngine {
114119
model: wasi_llm::EmbeddingModel,
115120
data: Vec<String>,
116121
) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
117-
let client = self.client.get_or_insert_with(Default::default);
118-
119-
let mut headers = HeaderMap::new();
120-
headers.insert(
121-
"authorization",
122-
HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
123-
wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
124-
})?,
125-
);
126-
spin_telemetry::inject_trace_context(&mut headers);
127-
128-
let body = serde_json::to_string(&json!({
129-
"model": model,
130-
"input": data
131-
}))
132-
.map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
133-
134-
let resp = client
135-
.request(
136-
reqwest::Method::POST,
137-
self.url.join("/embed").map_err(|_| {
138-
wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
139-
})?,
140-
)
141-
.headers(headers)
142-
.body(body)
143-
.send()
144-
.await
145-
.map_err(|err| {
146-
wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
147-
})?;
148-
149-
match resp.json::<EmbeddingResponseBody>().await {
150-
Ok(val) => Ok(val.into()),
151-
Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
152-
"Failed to deserialize response for \"POST /embed\": {err}"
153-
))),
122+
match &self.agent {
123+
Agent::OpenAI {
124+
auth_token,
125+
url,
126+
client,
127+
} => {
128+
OpenAIAgentEngine::generate_embeddings(auth_token, url, client.clone(), model, data)
129+
.await
130+
}
131+
Agent::Default {
132+
auth_token,
133+
url,
134+
client,
135+
} => {
136+
DefaultAgentEngine::generate_embeddings(
137+
auth_token,
138+
url,
139+
client.clone(),
140+
model,
141+
data,
142+
)
143+
.await
144+
}
154145
}
155146
}
156147

157148
pub fn url(&self) -> Url {
158-
self.url.clone()
149+
match &self.agent {
150+
Agent::OpenAI { url, .. } => url.clone(),
151+
Agent::Default { url, .. } => url.clone(),
152+
}
159153
}
160154
}
161155

@@ -181,13 +175,3 @@ impl From<EmbeddingResponseBody> for wasi_llm::EmbeddingsResult {
181175
}
182176
}
183177
}
184-
185-
impl RemoteHttpLlmEngine {
186-
pub fn new(url: Url, auth_token: String) -> Self {
187-
RemoteHttpLlmEngine {
188-
url,
189-
auth_token,
190-
client: None,
191-
}
192-
}
193-
}

0 commit comments

Comments
 (0)