Skip to content

Commit 38c2c5d

Browse files
committed
WIP: added embeddings
Signed-off-by: Aminu Oluwaseun Joshua <[email protected]>
1 parent b770937 commit 38c2c5d

File tree

1 file changed

+120
-57
lines changed

1 file changed

+120
-57
lines changed

crates/llm-remote-http/src/open_ai.rs

Lines changed: 120 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use reqwest::{
77
use serde::Serialize;
88
use spin_world::v2::llm::{self as wasi_llm};
99

10-
use crate::InferResponseBody;
10+
use crate::{EmbeddingResponseBody, InferResponseBody};
1111

1212
pub(crate) struct OpenAIAgentEngine;
1313

@@ -43,11 +43,10 @@ impl OpenAIAgentEngine {
4343
content: prompt,
4444
}],
4545
model: model.into(),
46-
max_completion_tokens: todo!(),
47-
frequency_penalty: Some(params.repeat_penalty), // TODO: change to frequency_penalty
48-
reasoning_effort: todo!(),
49-
audio: todo!(),
50-
verbosity: todo!(),
46+
max_completion_tokens: Some(params.max_tokens),
47+
frequency_penalty: Some(params.repeat_penalty), // TODO: Joshua: change to frequency_penalty
48+
reasoning_effort: Some(ReasoningEffort::Medium),
49+
verbosity: Some(Verbosity::Low),
5150
};
5251

5352
let resp = client
@@ -69,13 +68,52 @@ impl OpenAIAgentEngine {
6968
}
7069

7170
pub async fn generate_embeddings(
72-
_auth_token: &str,
73-
_url: &Url,
74-
mut _client: Option<Client>,
75-
_model: wasi_llm::EmbeddingModel,
76-
_data: Vec<String>,
71+
auth_token: &str,
72+
url: &Url,
73+
mut client: Option<Client>,
74+
model: wasi_llm::EmbeddingModel,
75+
data: Vec<String>,
7776
) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
78-
todo!("What's an embedding?")
77+
let client = client.get_or_insert_with(Default::default);
78+
79+
let mut headers = HeaderMap::new();
80+
headers.insert(
81+
"authorization",
82+
HeaderValue::from_str(&format!("bearer {}", auth_token)).map_err(|_| {
83+
wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
84+
})?,
85+
);
86+
spin_telemetry::inject_trace_context(&mut headers);
87+
88+
let body = CreateEmbeddingRequest {
89+
input: data,
90+
model: EmbeddingModel::Custom(model),
91+
encoding_format: None,
92+
dimensions: None,
93+
user: None,
94+
};
95+
96+
let resp = client
97+
.request(
98+
reqwest::Method::POST,
99+
url.join("/embeddings").map_err(|_| {
100+
wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
101+
})?,
102+
)
103+
.headers(headers)
104+
.body(body)
105+
.send()
106+
.await
107+
.map_err(|err| {
108+
wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
109+
})?;
110+
111+
match resp.json::<EmbeddingResponseBody>().await {
112+
Ok(val) => Ok(val.into()),
113+
Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
114+
"Failed to deserialize response for \"POST /embed\": {err}"
115+
))),
116+
}
79117
}
80118
}
81119

@@ -84,15 +122,13 @@ struct CreateChatCompletionRequest {
84122
messages: Vec<Message>,
85123
model: Model,
86124
#[serde(skip_serializing_if = "Option::is_none")]
87-
pub max_completion_tokens: Option<u32>,
125+
max_completion_tokens: Option<u32>,
88126
#[serde(skip_serializing_if = "Option::is_none")]
89-
pub frequency_penalty: Option<f32>,
127+
frequency_penalty: Option<f32>,
90128
#[serde(skip_serializing_if = "Option::is_none")]
91-
pub reasoning_effort: Option<ReasoningEffort>,
129+
reasoning_effort: Option<ReasoningEffort>,
92130
#[serde(skip_serializing_if = "Option::is_none")]
93-
pub audio: Option<AudioOptions>,
94-
#[serde(skip_serializing_if = "Option::is_none")]
95-
pub verbosity: Option<Verbosity>,
131+
verbosity: Option<Verbosity>,
96132
}
97133

98134
impl From<CreateChatCompletionRequest> for Body {
@@ -101,53 +137,28 @@ impl From<CreateChatCompletionRequest> for Body {
101137
}
102138
}
103139

104-
#[derive(Serialize, Debug)]
105-
pub struct AudioOptions {
106-
pub voice: String,
107-
pub format: String,
108-
}
109-
110140
#[derive(Serialize, Debug)]
111141
enum Verbosity {
112142
Low,
113-
Medium,
114-
High,
143+
_Medium,
144+
_High,
115145
}
116146

117147
#[derive(Serialize, Debug)]
118148
enum ReasoningEffort {
119-
Minimal,
120-
Low,
149+
_Minimal,
150+
_Low,
121151
Medium,
122-
High,
152+
_High,
123153
}
124154

125155
impl Display for ReasoningEffort {
126156
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127157
match self {
128-
ReasoningEffort::Minimal => write!(f, "minimal"),
129-
ReasoningEffort::Low => write!(f, "low"),
158+
ReasoningEffort::_Minimal => write!(f, "minimal"),
159+
ReasoningEffort::_Low => write!(f, "low"),
130160
ReasoningEffort::Medium => write!(f, "medium"),
131-
ReasoningEffort::High => write!(f, "high"),
132-
}
133-
}
134-
}
135-
136-
#[derive(Serialize, Debug)]
137-
enum InputType {
138-
Text,
139-
Audio,
140-
Image,
141-
Video,
142-
}
143-
144-
impl Display for InputType {
145-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146-
match self {
147-
InputType::Text => write!(f, "text"),
148-
InputType::Audio => write!(f, "audio"),
149-
InputType::Image => write!(f, "image"),
150-
InputType::Video => write!(f, "video"),
161+
ReasoningEffort::_High => write!(f, "high"),
151162
}
152163
}
153164
}
@@ -221,19 +232,71 @@ struct Message {
221232

222233
#[derive(Serialize, Debug)]
223234
enum Role {
224-
System,
235+
_System,
225236
User,
226-
Assistant,
227-
Tool,
237+
_Assistant,
238+
_Tool,
228239
}
229240

230241
impl Display for Role {
231242
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232243
match self {
233-
Role::System => write!(f, "system"),
244+
Role::_System => write!(f, "system"),
234245
Role::User => write!(f, "user"),
235-
Role::Assistant => write!(f, "assistant"),
236-
Role::Tool => write!(f, "tool"),
246+
Role::_Assistant => write!(f, "assistant"),
247+
Role::_Tool => write!(f, "tool"),
248+
}
249+
}
250+
}
251+
252+
#[derive(Serialize, Debug)]
253+
pub struct CreateEmbeddingRequest {
254+
input: Vec<String>,
255+
model: EmbeddingModel,
256+
#[serde(skip_serializing_if = "Option::is_none")]
257+
encoding_format: Option<EncodingFormat>,
258+
#[serde(skip_serializing_if = "Option::is_none")]
259+
dimensions: Option<u32>,
260+
#[serde(skip_serializing_if = "Option::is_none")]
261+
user: Option<String>,
262+
}
263+
264+
impl From<CreateEmbeddingRequest> for Body {
265+
fn from(val: CreateEmbeddingRequest) -> Self {
266+
Body::from(serde_json::to_string(&val).unwrap())
267+
}
268+
}
269+
270+
#[derive(Serialize, Debug)]
271+
enum EncodingFormat {
272+
_Float,
273+
_Base64,
274+
}
275+
276+
impl Display for EncodingFormat {
277+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278+
match self {
279+
EncodingFormat::_Float => write!(f, "float"),
280+
EncodingFormat::_Base64 => write!(f, "base64"),
281+
}
282+
}
283+
}
284+
285+
#[derive(Serialize, Debug)]
286+
enum EmbeddingModel {
287+
_TextEmbeddingAda002,
288+
_TextEmbedding3Small,
289+
_TextEmbedding3Large,
290+
Custom(String),
291+
}
292+
293+
impl Display for EmbeddingModel {
294+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295+
match self {
296+
EmbeddingModel::_TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
297+
EmbeddingModel::_TextEmbedding3Small => write!(f, "text-embedding-3-small"),
298+
EmbeddingModel::_TextEmbedding3Large => write!(f, "text-embedding-3-large"),
299+
EmbeddingModel::Custom(model) => write!(f, "{model}"),
237300
}
238301
}
239302
}

0 commit comments

Comments
 (0)