|
1 | 1 | use anyhow::Result; |
2 | | -use reqwest::{Client, Url}; |
| 2 | +use reqwest::Url; |
3 | 3 | use serde::{Deserialize, Serialize}; |
4 | | -use spin_world::v2::llm::{self as wasi_llm}; |
5 | | - |
6 | | -use crate::{ |
7 | | - default::DefaultAgentEngine, |
8 | | - open_ai::OpenAIAgentEngine, |
9 | | - schema::{ChatCompletionChoice, Embedding}, |
| 4 | +use spin_world::{ |
| 5 | + async_trait, |
| 6 | + v2::llm::{self as wasi_llm}, |
10 | 7 | }; |
11 | 8 |
|
| 9 | +use crate::schema::{ChatCompletionChoice, Embedding}; |
| 10 | + |
12 | 11 | mod default; |
13 | 12 | mod open_ai; |
14 | 13 | mod schema; |
15 | 14 |
|
16 | | -#[derive(Clone)] |
17 | | -pub enum Agent { |
18 | | - //TODO: Joshua: Naming??! |
19 | | - Default { |
20 | | - auth_token: String, |
21 | | - url: Url, |
22 | | - client: Option<Client>, |
23 | | - }, |
24 | | - OpenAI { |
25 | | - auth_token: String, |
26 | | - url: Url, |
27 | | - client: Option<Client>, |
28 | | - }, |
| 15 | +pub struct RemoteHttpLlmEngine { |
| 16 | + worker: Box<dyn LlmWorker>, |
29 | 17 | } |
30 | 18 |
|
31 | | -impl Agent { |
32 | | - pub fn from(url: Url, auth_token: String, agent: Option<CustomLlm>) -> Self { |
33 | | - match agent { |
34 | | - Some(CustomLlm::OpenAi) => Agent::OpenAI { |
35 | | - auth_token, |
36 | | - url, |
37 | | - client: None, |
38 | | - }, |
39 | | - _ => Agent::Default { |
40 | | - auth_token, |
41 | | - url, |
42 | | - client: None, |
43 | | - }, |
44 | | - } |
| 19 | +impl RemoteHttpLlmEngine { |
| 20 | + pub fn new(url: Url, auth_token: String, custom_llm: CustomLlm) -> Self { |
| 21 | + let worker: Box<dyn LlmWorker> = match custom_llm { |
| 22 | + CustomLlm::OpenAi => Box::new(open_ai::OpenAIAgentEngine::new(auth_token, url, None)), |
| 23 | + CustomLlm::Default => Box::new(default::DefaultAgentEngine::new(auth_token, url, None)), |
| 24 | + }; |
| 25 | + Self { worker } |
45 | 26 | } |
46 | 27 | } |
47 | 28 |
|
48 | | -#[derive(Clone)] |
49 | | -pub struct RemoteHttpLlmEngine { |
50 | | - agent: Agent, |
| 29 | +#[async_trait] |
| 30 | +pub trait LlmWorker: Send + Sync { |
| 31 | + async fn infer( |
| 32 | + &mut self, |
| 33 | + model: wasi_llm::InferencingModel, |
| 34 | + prompt: String, |
| 35 | + params: wasi_llm::InferencingParams, |
| 36 | + ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error>; |
| 37 | + |
| 38 | + async fn generate_embeddings( |
| 39 | + &mut self, |
| 40 | + model: wasi_llm::EmbeddingModel, |
| 41 | + data: Vec<String>, |
| 42 | + ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>; |
| 43 | + |
| 44 | + fn url(&self) -> Url; |
51 | 45 | } |
52 | 46 |
|
53 | 47 | #[derive(Serialize)] |
@@ -150,74 +144,25 @@ struct OpenAIEmbeddingUsage { |
150 | 144 | } |
151 | 145 |
|
152 | 146 | impl RemoteHttpLlmEngine { |
153 | | - pub fn new(url: Url, auth_token: String, agent: Option<CustomLlm>) -> Self { |
154 | | - RemoteHttpLlmEngine { |
155 | | - agent: Agent::from(url, auth_token, agent), |
156 | | - } |
157 | | - } |
158 | | - |
159 | 147 | pub async fn infer( |
160 | 148 | &mut self, |
161 | 149 | model: wasi_llm::InferencingModel, |
162 | 150 | prompt: String, |
163 | 151 | params: wasi_llm::InferencingParams, |
164 | 152 | ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> { |
165 | | - match &self.agent { |
166 | | - Agent::OpenAI { |
167 | | - auth_token, |
168 | | - url, |
169 | | - client, |
170 | | - } => { |
171 | | - OpenAIAgentEngine::infer(auth_token, url, client.clone(), model, prompt, params) |
172 | | - .await |
173 | | - } |
174 | | - Agent::Default { |
175 | | - auth_token, |
176 | | - url, |
177 | | - client, |
178 | | - } => { |
179 | | - DefaultAgentEngine::infer(auth_token, url, client.clone(), model, prompt, params) |
180 | | - .await |
181 | | - } |
182 | | - } |
| 153 | + self.worker.infer(model, prompt, params).await |
183 | 154 | } |
184 | 155 |
|
185 | 156 | pub async fn generate_embeddings( |
186 | 157 | &mut self, |
187 | 158 | model: wasi_llm::EmbeddingModel, |
188 | 159 | data: Vec<String>, |
189 | 160 | ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> { |
190 | | - match &self.agent { |
191 | | - Agent::OpenAI { |
192 | | - auth_token, |
193 | | - url, |
194 | | - client, |
195 | | - } => { |
196 | | - OpenAIAgentEngine::generate_embeddings(auth_token, url, client.clone(), model, data) |
197 | | - .await |
198 | | - } |
199 | | - Agent::Default { |
200 | | - auth_token, |
201 | | - url, |
202 | | - client, |
203 | | - } => { |
204 | | - DefaultAgentEngine::generate_embeddings( |
205 | | - auth_token, |
206 | | - url, |
207 | | - client.clone(), |
208 | | - model, |
209 | | - data, |
210 | | - ) |
211 | | - .await |
212 | | - } |
213 | | - } |
| 161 | + self.worker.generate_embeddings(model, data).await |
214 | 162 | } |
215 | 163 |
|
216 | 164 | pub fn url(&self) -> Url { |
217 | | - match &self.agent { |
218 | | - Agent::OpenAI { url, .. } => url.clone(), |
219 | | - Agent::Default { url, .. } => url.clone(), |
220 | | - } |
| 165 | + self.worker.url() |
221 | 166 | } |
222 | 167 | } |
223 | 168 |
|
@@ -267,19 +212,11 @@ impl From<CreateEmbeddingResponse> for wasi_llm::EmbeddingsResult { |
267 | 212 | } |
268 | 213 | } |
269 | 214 |
|
270 | | -#[derive(Debug, serde::Deserialize, PartialEq)] |
| 215 | +#[derive(Debug, Default, serde::Deserialize, PartialEq)] |
| 216 | +#[serde(rename_all = "snake_case")] |
271 | 217 | pub enum CustomLlm { |
272 | 218 | /// Compatible with OpenAI's API alongside some other LLMs |
273 | 219 | OpenAi, |
274 | | -} |
275 | | - |
276 | | -impl TryFrom<&str> for CustomLlm { |
277 | | - type Error = anyhow::Error; |
278 | | - |
279 | | - fn try_from(value: &str) -> Result<Self, Self::Error> { |
280 | | - match value.to_lowercase().as_str() { |
281 | | - "open_ai" | "openai" => Ok(CustomLlm::OpenAi), |
282 | | - _ => Err(anyhow::anyhow!("Invalid custom LLM: {}", value)), |
283 | | - } |
284 | | - } |
| 220 | + #[default] |
| 221 | + Default, |
285 | 222 | } |
0 commit comments