
Commit a2d3bc2

refactor RemoteHttpCompute config
CustomLlm is now parsed directly by serde, which catches unsupported CustomLlm values. Also introduces a new trait, LlmWorker, which every LLM engine implements.

Signed-off-by: Aminu Oluwaseun Joshua <[email protected]>
1 parent 8a51cee commit a2d3bc2

4 files changed (+122, -140 lines)

crates/factor-llm/src/spin.rs

Lines changed: 4 additions & 3 deletions
@@ -2,7 +2,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 
 use spin_factors::runtime_config::toml::GetTomlValue;
-use spin_llm_remote_http::RemoteHttpLlmEngine;
+use spin_llm_remote_http::{CustomLlm, RemoteHttpLlmEngine};
 use spin_world::async_trait;
 use spin_world::v1::llm::{self as v1};
 use spin_world::v2::llm::{self as v2};
@@ -122,7 +122,7 @@ impl LlmCompute {
             LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
                 config.url,
                 config.auth_token,
-                config.custom_llm.and_then(|c| c.as_str().try_into().ok()),
+                config.custom_llm,
             ))),
         };
         Ok(engine)
@@ -133,7 +133,8 @@ impl LlmCompute {
 pub struct RemoteHttpCompute {
     url: Url,
     auth_token: String,
-    custom_llm: Option<String>,
+    #[serde(default)]
+    custom_llm: CustomLlm,
 }
 
 /// A noop engine used when the local engine feature is disabled.
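With #[serde(default)] on the new custom_llm: CustomLlm field, the enum is deserialized straight from the runtime config, and an unsupported value is rejected at parse time instead of being silently dropped by the old try_into().ok() path. A minimal sketch of that behavior, assuming the toml crate and a standalone struct that mirrors RemoteHttpCompute (the real url field is a reqwest::Url, simplified here to String):

use serde::Deserialize;

#[derive(Debug, Default, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
enum CustomLlm {
    OpenAi,
    #[default]
    Default,
}

#[derive(Deserialize)]
struct RemoteHttpCompute {
    url: String, // the real struct uses reqwest::Url
    auth_token: String,
    #[serde(default)]
    custom_llm: CustomLlm,
}

fn main() {
    // Omitting custom_llm falls back to CustomLlm::Default.
    let cfg: RemoteHttpCompute =
        toml::from_str("url = \"http://example.test\"\nauth_token = \"tok\"").unwrap();
    assert_eq!(cfg.custom_llm, CustomLlm::Default);

    // "open_ai" (the snake_case of the variant name) selects the OpenAI worker.
    let cfg: RemoteHttpCompute = toml::from_str(
        "url = \"http://example.test\"\nauth_token = \"tok\"\ncustom_llm = \"open_ai\"",
    )
    .unwrap();
    assert_eq!(cfg.custom_llm, CustomLlm::OpenAi);

    // An unsupported value is now a deserialization error, not a silent fallback.
    assert!(toml::from_str::<RemoteHttpCompute>(
        "url = \"http://example.test\"\nauth_token = \"tok\"\ncustom_llm = \"mystery\""
    )
    .is_err());
}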

crates/llm-remote-http/src/default.rs

Lines changed: 36 additions & 17 deletions
@@ -4,27 +4,43 @@ use reqwest::{
     Client, Url,
 };
 use serde_json::json;
-use spin_world::v2::llm::{self as wasi_llm};
+use spin_world::{
+    async_trait,
+    v2::llm::{self as wasi_llm},
+};
 
-use crate::{EmbeddingResponseBody, InferRequestBodyParams, InferResponseBody};
+use crate::{EmbeddingResponseBody, InferRequestBodyParams, InferResponseBody, LlmWorker};
 
-pub(crate) struct DefaultAgentEngine;
+pub(crate) struct DefaultAgentEngine {
+    auth_token: String,
+    url: Url,
+    client: Option<Client>,
+}
 
 impl DefaultAgentEngine {
-    pub async fn infer(
-        auth_token: &String,
-        url: &Url,
-        mut client: Option<Client>,
+    pub fn new(auth_token: String, url: Url, client: Option<Client>) -> Self {
+        Self {
+            auth_token,
+            url,
+            client,
+        }
+    }
+}
+
+#[async_trait]
+impl LlmWorker for DefaultAgentEngine {
+    async fn infer(
+        &mut self,
         model: wasi_llm::InferencingModel,
        prompt: String,
         params: wasi_llm::InferencingParams,
     ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
-        let client = client.get_or_insert_with(Default::default);
+        let client = self.client.get_or_insert_with(Default::default);
 
         let mut headers = HeaderMap::new();
         headers.insert(
             "authorization",
-            HeaderValue::from_str(&format!("bearer {}", auth_token)).map_err(|_| {
+            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
                 wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
             })?,
         );
@@ -45,7 +61,8 @@ impl DefaultAgentEngine {
             }))
             .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
 
-        let infer_url = url
+        let infer_url = self
+            .url
             .join("/infer")
             .map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
         tracing::info!("Sending remote inference request to {infer_url}");
@@ -68,19 +85,17 @@ impl DefaultAgentEngine {
         }
     }
 
-    pub async fn generate_embeddings(
-        auth_token: &String,
-        url: &Url,
-        mut client: Option<Client>,
+    async fn generate_embeddings(
+        &mut self,
         model: wasi_llm::EmbeddingModel,
         data: Vec<String>,
     ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
-        let client = client.get_or_insert_with(Default::default);
+        let client = self.client.get_or_insert_with(Default::default);
 
         let mut headers = HeaderMap::new();
         headers.insert(
             "authorization",
-            HeaderValue::from_str(&format!("bearer {}", auth_token)).map_err(|_| {
+            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
                 wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
             })?,
         );
@@ -95,7 +110,7 @@ impl DefaultAgentEngine {
         let resp = client
             .request(
                 reqwest::Method::POST,
-                url.join("/embed").map_err(|_| {
+                self.url.join("/embed").map_err(|_| {
                     wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
                 })?,
             )
@@ -114,4 +129,8 @@ impl DefaultAgentEngine {
             ))),
         }
     }
+
+    fn url(&self) -> Url {
+        self.url.clone()
+    }
 }
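One detail worth noting in both trait methods: the HTTP client lives in an Option<Client> and is materialized with get_or_insert_with(Default::default) on first use. A tiny illustration of that pattern in isolation (the Lazy name below is invented for the sketch):

use reqwest::Client;

// Illustrative only: a slot that builds the client on demand and caches it.
struct Lazy {
    client: Option<Client>,
}

impl Lazy {
    fn client(&mut self) -> &Client {
        // The first call inserts Client::default(); later calls reuse it,
        // so reqwest's connection pool survives across requests.
        self.client.get_or_insert_with(Default::default)
    }
}

This is also why infer and generate_embeddings take &mut self in the trait: they may need to fill that slot.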

crates/llm-remote-http/src/lib.rs

Lines changed: 38 additions & 101 deletions
@@ -1,53 +1,47 @@
 use anyhow::Result;
-use reqwest::{Client, Url};
+use reqwest::Url;
 use serde::{Deserialize, Serialize};
-use spin_world::v2::llm::{self as wasi_llm};
-
-use crate::{
-    default::DefaultAgentEngine,
-    open_ai::OpenAIAgentEngine,
-    schema::{ChatCompletionChoice, Embedding},
+use spin_world::{
+    async_trait,
+    v2::llm::{self as wasi_llm},
 };
 
+use crate::schema::{ChatCompletionChoice, Embedding};
+
 mod default;
 mod open_ai;
 mod schema;
 
-#[derive(Clone)]
-pub enum Agent {
-    //TODO: Joshua: Naming??!
-    Default {
-        auth_token: String,
-        url: Url,
-        client: Option<Client>,
-    },
-    OpenAI {
-        auth_token: String,
-        url: Url,
-        client: Option<Client>,
-    },
+pub struct RemoteHttpLlmEngine {
+    worker: Box<dyn LlmWorker>,
 }
 
-impl Agent {
-    pub fn from(url: Url, auth_token: String, agent: Option<CustomLlm>) -> Self {
-        match agent {
-            Some(CustomLlm::OpenAi) => Agent::OpenAI {
-                auth_token,
-                url,
-                client: None,
-            },
-            _ => Agent::Default {
-                auth_token,
-                url,
-                client: None,
-            },
-        }
+impl RemoteHttpLlmEngine {
+    pub fn new(url: Url, auth_token: String, custom_llm: CustomLlm) -> Self {
+        let worker: Box<dyn LlmWorker> = match custom_llm {
+            CustomLlm::OpenAi => Box::new(open_ai::OpenAIAgentEngine::new(auth_token, url, None)),
+            CustomLlm::Default => Box::new(default::DefaultAgentEngine::new(auth_token, url, None)),
+        };
+        Self { worker }
     }
 }
 
-#[derive(Clone)]
-pub struct RemoteHttpLlmEngine {
-    agent: Agent,
+#[async_trait]
+pub trait LlmWorker: Send + Sync {
+    async fn infer(
+        &mut self,
+        model: wasi_llm::InferencingModel,
+        prompt: String,
+        params: wasi_llm::InferencingParams,
+    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error>;
+
+    async fn generate_embeddings(
+        &mut self,
+        model: wasi_llm::EmbeddingModel,
+        data: Vec<String>,
+    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>;
+
+    fn url(&self) -> Url;
 }
 
 #[derive(Serialize)]
@@ -150,74 +144,25 @@ struct OpenAIEmbeddingUsage {
 }
 
 impl RemoteHttpLlmEngine {
-    pub fn new(url: Url, auth_token: String, agent: Option<CustomLlm>) -> Self {
-        RemoteHttpLlmEngine {
-            agent: Agent::from(url, auth_token, agent),
-        }
-    }
-
     pub async fn infer(
         &mut self,
         model: wasi_llm::InferencingModel,
         prompt: String,
         params: wasi_llm::InferencingParams,
     ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
-        match &self.agent {
-            Agent::OpenAI {
-                auth_token,
-                url,
-                client,
-            } => {
-                OpenAIAgentEngine::infer(auth_token, url, client.clone(), model, prompt, params)
-                    .await
-            }
-            Agent::Default {
-                auth_token,
-                url,
-                client,
-            } => {
-                DefaultAgentEngine::infer(auth_token, url, client.clone(), model, prompt, params)
-                    .await
-            }
-        }
+        self.worker.infer(model, prompt, params).await
     }
 
     pub async fn generate_embeddings(
         &mut self,
         model: wasi_llm::EmbeddingModel,
         data: Vec<String>,
     ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
-        match &self.agent {
-            Agent::OpenAI {
-                auth_token,
-                url,
-                client,
-            } => {
-                OpenAIAgentEngine::generate_embeddings(auth_token, url, client.clone(), model, data)
-                    .await
-            }
-            Agent::Default {
-                auth_token,
-                url,
-                client,
-            } => {
-                DefaultAgentEngine::generate_embeddings(
-                    auth_token,
-                    url,
-                    client.clone(),
-                    model,
-                    data,
-                )
-                .await
-            }
-        }
+        self.worker.generate_embeddings(model, data).await
     }
 
     pub fn url(&self) -> Url {
-        match &self.agent {
-            Agent::OpenAI { url, .. } => url.clone(),
-            Agent::Default { url, .. } => url.clone(),
-        }
+        self.worker.url()
     }
 }
 
@@ -267,19 +212,11 @@ impl From<CreateEmbeddingResponse> for wasi_llm::EmbeddingsResult {
     }
 }
 
-#[derive(Debug, serde::Deserialize, PartialEq)]
+#[derive(Debug, Default, serde::Deserialize, PartialEq)]
+#[serde(rename_all = "snake_case")]
 pub enum CustomLlm {
     /// Compatible with OpenAI's API alongside some other LLMs
     OpenAi,
-}
-
-impl TryFrom<&str> for CustomLlm {
-    type Error = anyhow::Error;
-
-    fn try_from(value: &str) -> Result<Self, Self::Error> {
-        match value.to_lowercase().as_str() {
-            "open_ai" | "openai" => Ok(CustomLlm::OpenAi),
-            _ => Err(anyhow::anyhow!("Invalid custom LLM: {}", value)),
-        }
-    }
+    #[default]
+    Default,
 }
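The payoff of the trait is that RemoteHttpLlmEngine no longer needs a match arm per backend; a new engine just implements LlmWorker. A minimal sketch with a hypothetical EchoEngine (invented here for illustration; this assumes crate-internal code where LlmWorker is in scope, and record fields following Spin's v2 llm interface):

use reqwest::Url;
use spin_world::{async_trait, v2::llm as wasi_llm};

use crate::LlmWorker;

// Hypothetical worker used only for illustration; not part of this commit.
struct EchoEngine {
    url: Url,
}

#[async_trait]
impl LlmWorker for EchoEngine {
    async fn infer(
        &mut self,
        _model: wasi_llm::InferencingModel,
        prompt: String,
        _params: wasi_llm::InferencingParams,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
        // Echo the prompt back instead of calling a remote service.
        Ok(wasi_llm::InferencingResult {
            text: prompt,
            usage: wasi_llm::InferencingUsage {
                prompt_token_count: 0,
                generated_token_count: 0,
            },
        })
    }

    async fn generate_embeddings(
        &mut self,
        _model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
        // One empty vector per input string.
        Ok(wasi_llm::EmbeddingsResult {
            embeddings: data.iter().map(|_| Vec::new()).collect(),
            usage: wasi_llm::EmbeddingsUsage {
                prompt_token_count: 0,
            },
        })
    }

    fn url(&self) -> Url {
        self.url.clone()
    }
}

Wiring such a worker in would still mean adding a CustomLlm variant (or another constructor), since RemoteHttpLlmEngine::new selects the worker by matching on the enum.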
