@@ -7,7 +7,7 @@ use reqwest::{
77use serde:: Serialize ;
88use spin_world:: v2:: llm:: { self as wasi_llm} ;
99
10- use crate :: InferResponseBody ;
10+ use crate :: { EmbeddingResponseBody , InferResponseBody } ;
1111
1212pub ( crate ) struct OpenAIAgentEngine ;
1313
@@ -43,11 +43,10 @@ impl OpenAIAgentEngine {
4343 content: prompt,
4444 } ] ,
4545 model : model. into ( ) ,
46- max_completion_tokens : todo ! ( ) ,
47- frequency_penalty : Some ( params. repeat_penalty ) , // TODO: change to frequency_penalty
48- reasoning_effort : todo ! ( ) ,
49- audio : todo ! ( ) ,
50- verbosity : todo ! ( ) ,
46+ max_completion_tokens : Some ( params. max_tokens ) ,
47+ frequency_penalty : Some ( params. repeat_penalty ) , // TODO: Joshua: change to frequency_penalty
48+ reasoning_effort : Some ( ReasoningEffort :: Medium ) ,
49+ verbosity : Some ( Verbosity :: Low ) ,
5150 } ;
5251
5352 let resp = client
@@ -69,13 +68,52 @@ impl OpenAIAgentEngine {
6968 }
7069
7170 pub async fn generate_embeddings (
72- _auth_token : & str ,
73- _url : & Url ,
74- mut _client : Option < Client > ,
75- _model : wasi_llm:: EmbeddingModel ,
76- _data : Vec < String > ,
71+ auth_token : & str ,
72+ url : & Url ,
73+ mut client : Option < Client > ,
74+ model : wasi_llm:: EmbeddingModel ,
75+ data : Vec < String > ,
7776 ) -> Result < wasi_llm:: EmbeddingsResult , wasi_llm:: Error > {
78- todo ! ( "What's an embedding?" )
77+ let client = client. get_or_insert_with ( Default :: default) ;
78+
79+ let mut headers = HeaderMap :: new ( ) ;
80+ headers. insert (
81+ "authorization" ,
82+ HeaderValue :: from_str ( & format ! ( "bearer {}" , auth_token) ) . map_err ( |_| {
83+ wasi_llm:: Error :: RuntimeError ( "Failed to create authorization header" . to_string ( ) )
84+ } ) ?,
85+ ) ;
86+ spin_telemetry:: inject_trace_context ( & mut headers) ;
87+
88+ let body = CreateEmbeddingRequest {
89+ input : data,
90+ model : EmbeddingModel :: Custom ( model) ,
91+ encoding_format : None ,
92+ dimensions : None ,
93+ user : None ,
94+ } ;
95+
96+ let resp = client
97+ . request (
98+ reqwest:: Method :: POST ,
99+ url. join ( "/embeddings" ) . map_err ( |_| {
100+ wasi_llm:: Error :: RuntimeError ( "Failed to create URL" . to_string ( ) )
101+ } ) ?,
102+ )
103+ . headers ( headers)
104+ . body ( body)
105+ . send ( )
106+ . await
107+ . map_err ( |err| {
108+ wasi_llm:: Error :: RuntimeError ( format ! ( "POST /embed request error: {err}" ) )
109+ } ) ?;
110+
111+ match resp. json :: < EmbeddingResponseBody > ( ) . await {
112+ Ok ( val) => Ok ( val. into ( ) ) ,
113+ Err ( err) => Err ( wasi_llm:: Error :: RuntimeError ( format ! (
114+ "Failed to deserialize response for \" POST /embed\" : {err}"
115+ ) ) ) ,
116+ }
79117 }
80118}
81119
@@ -84,15 +122,13 @@ struct CreateChatCompletionRequest {
84122 messages : Vec < Message > ,
85123 model : Model ,
86124 #[ serde( skip_serializing_if = "Option::is_none" ) ]
87- pub max_completion_tokens : Option < u32 > ,
125+ max_completion_tokens : Option < u32 > ,
88126 #[ serde( skip_serializing_if = "Option::is_none" ) ]
89- pub frequency_penalty : Option < f32 > ,
127+ frequency_penalty : Option < f32 > ,
90128 #[ serde( skip_serializing_if = "Option::is_none" ) ]
91- pub reasoning_effort : Option < ReasoningEffort > ,
129+ reasoning_effort : Option < ReasoningEffort > ,
92130 #[ serde( skip_serializing_if = "Option::is_none" ) ]
93- pub audio : Option < AudioOptions > ,
94- #[ serde( skip_serializing_if = "Option::is_none" ) ]
95- pub verbosity : Option < Verbosity > ,
131+ verbosity : Option < Verbosity > ,
96132}
97133
98134impl From < CreateChatCompletionRequest > for Body {
@@ -101,53 +137,28 @@ impl From<CreateChatCompletionRequest> for Body {
101137 }
102138}
103139
104- #[ derive( Serialize , Debug ) ]
105- pub struct AudioOptions {
106- pub voice : String ,
107- pub format : String ,
108- }
109-
110140#[ derive( Serialize , Debug ) ]
111141enum Verbosity {
112142 Low ,
113- Medium ,
114- High ,
143+ _Medium ,
144+ _High ,
115145}
116146
117147#[ derive( Serialize , Debug ) ]
118148enum ReasoningEffort {
119- Minimal ,
120- Low ,
149+ _Minimal ,
150+ _Low ,
121151 Medium ,
122- High ,
152+ _High ,
123153}
124154
125155impl Display for ReasoningEffort {
126156 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
127157 match self {
128- ReasoningEffort :: Minimal => write ! ( f, "minimal" ) ,
129- ReasoningEffort :: Low => write ! ( f, "low" ) ,
158+ ReasoningEffort :: _Minimal => write ! ( f, "minimal" ) ,
159+ ReasoningEffort :: _Low => write ! ( f, "low" ) ,
130160 ReasoningEffort :: Medium => write ! ( f, "medium" ) ,
131- ReasoningEffort :: High => write ! ( f, "high" ) ,
132- }
133- }
134- }
135-
136- #[ derive( Serialize , Debug ) ]
137- enum InputType {
138- Text ,
139- Audio ,
140- Image ,
141- Video ,
142- }
143-
144- impl Display for InputType {
145- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
146- match self {
147- InputType :: Text => write ! ( f, "text" ) ,
148- InputType :: Audio => write ! ( f, "audio" ) ,
149- InputType :: Image => write ! ( f, "image" ) ,
150- InputType :: Video => write ! ( f, "video" ) ,
161+ ReasoningEffort :: _High => write ! ( f, "high" ) ,
151162 }
152163 }
153164}
@@ -221,19 +232,71 @@ struct Message {
221232
222233#[ derive( Serialize , Debug ) ]
223234enum Role {
224- System ,
235+ _System ,
225236 User ,
226- Assistant ,
227- Tool ,
237+ _Assistant ,
238+ _Tool ,
228239}
229240
230241impl Display for Role {
231242 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
232243 match self {
233- Role :: System => write ! ( f, "system" ) ,
244+ Role :: _System => write ! ( f, "system" ) ,
234245 Role :: User => write ! ( f, "user" ) ,
235- Role :: Assistant => write ! ( f, "assistant" ) ,
236- Role :: Tool => write ! ( f, "tool" ) ,
246+ Role :: _Assistant => write ! ( f, "assistant" ) ,
247+ Role :: _Tool => write ! ( f, "tool" ) ,
248+ }
249+ }
250+ }
251+
252+ #[ derive( Serialize , Debug ) ]
253+ pub struct CreateEmbeddingRequest {
254+ input : Vec < String > ,
255+ model : EmbeddingModel ,
256+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
257+ encoding_format : Option < EncodingFormat > ,
258+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
259+ dimensions : Option < u32 > ,
260+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
261+ user : Option < String > ,
262+ }
263+
264+ impl From < CreateEmbeddingRequest > for Body {
265+ fn from ( val : CreateEmbeddingRequest ) -> Self {
266+ Body :: from ( serde_json:: to_string ( & val) . unwrap ( ) )
267+ }
268+ }
269+
270+ #[ derive( Serialize , Debug ) ]
271+ enum EncodingFormat {
272+ _Float,
273+ _Base64,
274+ }
275+
276+ impl Display for EncodingFormat {
277+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
278+ match self {
279+ EncodingFormat :: _Float => write ! ( f, "float" ) ,
280+ EncodingFormat :: _Base64 => write ! ( f, "base64" ) ,
281+ }
282+ }
283+ }
284+
285+ #[ derive( Serialize , Debug ) ]
286+ enum EmbeddingModel {
287+ _TextEmbeddingAda002,
288+ _TextEmbedding3Small,
289+ _TextEmbedding3Large,
290+ Custom ( String ) ,
291+ }
292+
293+ impl Display for EmbeddingModel {
294+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
295+ match self {
296+ EmbeddingModel :: _TextEmbeddingAda002 => write ! ( f, "text-embedding-ada-002" ) ,
297+ EmbeddingModel :: _TextEmbedding3Small => write ! ( f, "text-embedding-3-small" ) ,
298+ EmbeddingModel :: _TextEmbedding3Large => write ! ( f, "text-embedding-3-large" ) ,
299+ EmbeddingModel :: Custom ( model) => write ! ( f, "{model}" ) ,
237300 }
238301 }
239302}
0 commit comments