
Commit b552796

Handles response
Signed-off-by: Aminu Oluwaseun Joshua <[email protected]>
1 parent 543b0e7 commit b552796

4 files changed, 142 insertions(+), 16 deletions(-)

crates/factor-llm/src/spin.rs

2 additions & 2 deletions

@@ -122,7 +122,7 @@ impl LlmCompute {
             LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
                 config.url,
                 config.auth_token,
-                config.agent,
+                config.custom_llm.and_then(|c| c.as_str().try_into().ok()),
             ))),
         };
         Ok(engine)
@@ -133,7 +133,7 @@ impl LlmCompute {
 pub struct RemoteHttpCompute {
     url: Url,
     auth_token: String,
-    agent: Option<String>,
+    custom_llm: Option<String>,
 }

 /// A noop engine used when the local engine feature is disabled.
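
In plain terms, the new wiring maps the optional `custom_llm` string from runtime config into an `Option<CustomLlm>`: `and_then` passes `None` straight through, and `try_into().ok()` quietly drops names that do not parse. A minimal standalone sketch of that conversion, mirroring the `CustomLlm` type added in lib.rs below (a plain `String` error stands in for `anyhow::Error` to keep the sketch dependency-free):

// Mirror of the CustomLlm enum and TryFrom impl added in lib.rs.
#[derive(Debug, PartialEq)]
pub enum CustomLlm {
    OpenAi,
}

impl TryFrom<&str> for CustomLlm {
    type Error = String;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value.to_lowercase().as_str() {
            "open_ai" | "openai" => Ok(CustomLlm::OpenAi),
            _ => Err(format!("Invalid custom LLM: {value}")),
        }
    }
}

fn main() {
    let configs: [Option<String>; 3] = [Some("OpenAI".into()), Some("llama".into()), None];
    for custom_llm in configs {
        // Same shape as `config.custom_llm.and_then(|c| c.as_str().try_into().ok())`.
        let agent: Option<CustomLlm> = custom_llm.and_then(|c| c.as_str().try_into().ok());
        println!("{agent:?}"); // Some(OpenAi), then None, then None
    }
}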

crates/llm-remote-http/src/lib.rs

87 additions & 4 deletions

@@ -3,7 +3,11 @@ use reqwest::{Client, Url};
 use serde::{Deserialize, Serialize};
 use spin_world::v2::llm::{self as wasi_llm};

-use crate::{default::DefaultAgentEngine, open_ai::OpenAIAgentEngine};
+use crate::{
+    default::DefaultAgentEngine,
+    open_ai::OpenAIAgentEngine,
+    schema::{ChatCompletionChoice, Embedding},
+};

 mod default;
 mod open_ai;
@@ -25,9 +29,9 @@ pub enum Agent {
 }

 impl Agent {
-    pub fn from(url: Url, auth_token: String, agent: Option<String>) -> Self {
+    pub fn from(url: Url, auth_token: String, agent: Option<CustomLlm>) -> Self {
         match agent {
-            Some(agent_name) if agent_name == *"open_ai" => Agent::OpenAI {
+            Some(CustomLlm::OpenAi) => Agent::OpenAI {
                 auth_token,
                 url,
                 client: None,
@@ -70,6 +74,23 @@ struct InferResponseBody {
     usage: InferUsage,
 }

+#[derive(Deserialize)]
+struct CreateChatCompletionResponse {
+    _id: String,
+    _object: String,
+    _created: u64,
+    _model: String,
+    choices: Vec<ChatCompletionChoice>,
+    usage: CompletionUsage,
+}
+
+#[derive(Deserialize)]
+struct CompletionUsage {
+    completion_tokens: u32,
+    prompt_tokens: u32,
+    _total_tokens: u32,
+}
+
 #[derive(Deserialize)]
 #[serde(rename_all(deserialize = "camelCase"))]
 struct EmbeddingUsage {
@@ -82,8 +103,31 @@ struct EmbeddingResponseBody {
     usage: EmbeddingUsage,
 }

+#[derive(Deserialize)]
+struct CreateEmbeddingResponse {
+    _object: String,
+    _model: String,
+    data: Vec<Embedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+impl CreateEmbeddingResponse {
+    fn embeddings(&self) -> Vec<Vec<f32>> {
+        self.data
+            .iter()
+            .map(|embedding| embedding.embedding.clone())
+            .collect()
+    }
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: u32,
+    _total_tokens: u32,
+}
+
 impl RemoteHttpLlmEngine {
-    pub fn new(url: Url, auth_token: String, agent: Option<String>) -> Self {
+    pub fn new(url: Url, auth_token: String, agent: Option<CustomLlm>) -> Self {
         RemoteHttpLlmEngine {
             agent: Agent::from(url, auth_token, agent),
         }
@@ -166,6 +210,18 @@ impl From<InferResponseBody> for wasi_llm::InferencingResult {
     }
 }

+impl From<CreateChatCompletionResponse> for wasi_llm::InferencingResult {
+    fn from(value: CreateChatCompletionResponse) -> Self {
+        Self {
+            text: value.choices[0].message.content.clone(),
+            usage: wasi_llm::InferencingUsage {
+                prompt_token_count: value.usage.prompt_tokens,
+                generated_token_count: value.usage.completion_tokens,
+            },
+        }
+    }
+}
+
 impl From<EmbeddingResponseBody> for wasi_llm::EmbeddingsResult {
     fn from(value: EmbeddingResponseBody) -> Self {
         Self {
@@ -176,3 +232,30 @@ impl From<EmbeddingResponseBody> for wasi_llm::EmbeddingsResult {
         }
     }
 }
+
+impl From<CreateEmbeddingResponse> for wasi_llm::EmbeddingsResult {
+    fn from(value: CreateEmbeddingResponse) -> Self {
+        Self {
+            embeddings: value.embeddings(),
+            usage: wasi_llm::EmbeddingsUsage {
+                prompt_token_count: value.usage.prompt_tokens,
+            },
+        }
+    }
+}
+
+#[derive(Debug, serde::Deserialize, PartialEq)]
+pub enum CustomLlm {
+    OpenAi,
+}
+
+impl TryFrom<&str> for CustomLlm {
+    type Error = anyhow::Error;
+
+    fn try_from(value: &str) -> Result<Self, Self::Error> {
+        match value.to_lowercase().as_str() {
+            "open_ai" | "openai" => Ok(CustomLlm::OpenAi),
+            _ => Err(anyhow::anyhow!("Invalid custom LLM: {}", value)),
+        }
+    }
+}
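
This is the heart of the commit: the engine now parses an OpenAI-style chat completion body and converts it into the WASI inferencing result. A standalone sketch of that round trip (assumes serde and serde_json as dependencies; the structs are trimmed mirrors of the ones added above, keeping only the fields the From impls actually read, and serde ignores unknown JSON keys by default):

use serde::Deserialize;

#[derive(Deserialize)]
struct Message {
    content: String,
}

#[derive(Deserialize)]
struct Choice {
    message: Message,
}

#[derive(Deserialize)]
struct CompletionUsage {
    completion_tokens: u32,
    prompt_tokens: u32,
}

#[derive(Deserialize)]
struct CreateChatCompletionResponse {
    choices: Vec<Choice>,
    usage: CompletionUsage,
}

fn main() -> Result<(), serde_json::Error> {
    // Extra keys ("id", "model", ...) are ignored because they are not fields here.
    let payload = r#"{
        "id": "chatcmpl-123",
        "model": "gpt-5",
        "choices": [{ "message": { "content": "Hello!" } }],
        "usage": { "completion_tokens": 2, "prompt_tokens": 5 }
    }"#;
    let resp: CreateChatCompletionResponse = serde_json::from_str(payload)?;
    // Same mapping as From<CreateChatCompletionResponse> for InferencingResult.
    println!(
        "text={:?} prompt_tokens={} generated_tokens={}",
        resp.choices[0].message.content,
        resp.usage.prompt_tokens,
        resp.usage.completion_tokens,
    );
    Ok(())
}

Two caveats about the real structs above, as written: serde matches JSON keys to field names verbatim, so the underscore-prefixed fields expect keys spelled like "_id" unless a #[serde(rename = "...")] attribute is added, and `value.choices[0]` in the From impl will panic if `choices` comes back empty.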

crates/llm-remote-http/src/open_ai.rs

7 additions & 7 deletions

@@ -6,8 +6,8 @@ use serde::Serialize;
 use spin_world::v2::llm::{self as wasi_llm};

 use crate::{
-    schema::{EmbeddingModels, EncodingFormat, Message, Model, Role},
-    EmbeddingResponseBody, InferResponseBody,
+    schema::{EmbeddingModels, EncodingFormat, Model, Prompt, Role},
+    CreateChatCompletionResponse, EmbeddingResponseBody,
 };

 pub(crate) struct OpenAIAgentEngine;
@@ -39,11 +39,11 @@ impl OpenAIAgentEngine {
         tracing::info!("Sending remote inference request to {chat_url}");

         let body = CreateChatCompletionRequest {
-            // TODO: Joshua: make Role customizable
-            messages: vec![Message::new(Role::User, prompt)],
+            // TODO: Make Role customizable
+            messages: vec![Prompt::new(Role::User, prompt)],
             model: model.as_str().try_into()?,
             max_completion_tokens: Some(params.max_tokens),
-            frequency_penalty: Some(params.repeat_penalty), // TODO: Joshua: change to frequency_penalty
+            frequency_penalty: Some(params.repeat_penalty),
             reasoning_effort: None,
             verbosity: None,
         };
@@ -58,7 +58,7 @@ impl OpenAIAgentEngine {
                 wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
             })?;

-        match resp.json::<InferResponseBody>().await {
+        match resp.json::<CreateChatCompletionResponse>().await {
             Ok(val) => Ok(val.into()),
             Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
                 "Failed to deserialize response for \"POST /index\": {err}"
@@ -118,7 +118,7 @@ impl OpenAIAgentEngine {

 #[derive(Serialize, Debug)]
 struct CreateChatCompletionRequest {
-    messages: Vec<Message>,
+    messages: Vec<Prompt>,
     model: Model,
     #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
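
On the request side, `CreateChatCompletionRequest` serializes directly into the OpenAI-style chat completions body, with `skip_serializing_if` dropping unset optional fields from the JSON. A rough standalone sketch of that behavior (role and model are simplified to plain strings in place of the crate's `Role` and `Model` enums, the skip attribute on `frequency_penalty` is assumed to match the one shown on `max_completion_tokens`, and "gpt-5" is just an illustrative model string):

use serde::Serialize;

#[derive(Serialize)]
struct Prompt {
    role: String,
    content: String,
}

#[derive(Serialize)]
struct CreateChatCompletionRequest {
    messages: Vec<Prompt>,
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
}

fn main() {
    let body = CreateChatCompletionRequest {
        messages: vec![Prompt {
            role: "user".into(),
            content: "Say hello".into(),
        }],
        model: "gpt-5".into(),
        max_completion_tokens: Some(100),
        frequency_penalty: None,
    };
    // frequency_penalty is None, so it is omitted from the serialized body entirely.
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}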

crates/llm-remote-http/src/schema.rs

46 additions & 3 deletions

@@ -1,8 +1,9 @@
 use std::fmt::Display;

-use serde::Serialize;
+use serde::{Deserialize, Serialize};
 use spin_world::v2::llm as wasi_llm;

+/// LLM model
 #[derive(Serialize, Debug)]
 pub enum Model {
     GPT5,
@@ -69,12 +70,12 @@ impl Display for Model {
 }

 #[derive(Serialize, Debug)]
-pub struct Message {
+pub struct Prompt {
     role: Role,
     content: String,
 }

-impl Message {
+impl Prompt {
     pub fn new(role: Role, content: String) -> Self {
         Self { role, content }
     }
@@ -232,3 +233,45 @@ impl TryFrom<&str> for Verbosity {
         }
     }
 }
+
+#[derive(Deserialize)]
+pub struct ChatCompletionChoice {
+    /// The index of the choice in the list of choices.
+    _index: u32,
+    pub message: ChatCompletionResponseMessage,
+    /// The reason the model stopped generating tokens. This will be `stop` if the model hit a
+    /// natural stop point or a provided stop sequence.
+    _finish_reason: String,
+    /// Log probability information for the choice.
+    _logprobs: Option<Logprobs>,
+}
+
+/// A chat completion message generated by the model.
+#[derive(Deserialize)]
+pub struct ChatCompletionResponseMessage {
+    /// The role of the author of this message.
+    _role: String,
+    /// The contents of the message.
+    pub content: String,
+    /// The refusal message generated by the model.
+    _refusal: Option<String>,
+}
+
+#[derive(Deserialize)]
+pub struct Logprobs {
+    /// A list of message content tokens with log probability information.
+    _content: Option<Vec<String>>,
+    /// A list of message refusal tokens with log probability information.
+    _refusal: Option<Vec<String>>,
+}
+
+#[derive(Deserialize)]
+pub struct Embedding {
+    /// The index of the embedding in the list of embeddings.
+    _index: u32,
+    /// The embedding vector, which is a list of floats. The length of the vector depends on the
+    /// model, as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
+    pub embedding: Vec<f32>,
+    /// The object type, which is always "embedding".
+    _object: String,
+}
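
These schema types feed the `CreateEmbeddingResponse::embeddings()` helper in lib.rs, which flattens the response's `data` array into a `Vec<Vec<f32>>`. A rough standalone sketch of that flattening (assumes serde and serde_json; trimmed mirrors of `CreateEmbeddingResponse` and `Embedding` keeping only the consumed fields):

use serde::Deserialize;

#[derive(Deserialize)]
struct Embedding {
    embedding: Vec<f32>,
}

#[derive(Deserialize)]
struct CreateEmbeddingResponse {
    data: Vec<Embedding>,
}

impl CreateEmbeddingResponse {
    // Same shape as the embeddings() helper added in lib.rs.
    fn embeddings(&self) -> Vec<Vec<f32>> {
        self.data.iter().map(|e| e.embedding.clone()).collect()
    }
}

fn main() -> Result<(), serde_json::Error> {
    let payload = r#"{
        "data": [
            { "embedding": [1.0, 2.0] },
            { "embedding": [3.0, 4.0] }
        ]
    }"#;
    let resp: CreateEmbeddingResponse = serde_json::from_str(payload)?;
    // One inner vector per embedding object, in response order.
    assert_eq!(resp.embeddings(), vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    Ok(())
}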
