11use anyhow:: Result ;
2- use reqwest:: {
3- header:: { HeaderMap , HeaderValue } ,
4- Client , Url ,
5- } ;
2+ use reqwest:: { Client , Url } ;
63use serde:: { Deserialize , Serialize } ;
7- use serde_json:: json;
84use spin_world:: v2:: llm:: { self as wasi_llm} ;
95
6+ use crate :: { default:: DefaultAgentEngine , open_ai:: OpenAIAgentEngine } ;
7+
8+ mod default;
109mod open_ai;
1110
11+ #[ derive( Clone ) ]
12+ pub enum Agent {
13+ //TODO: Joshua: Naming??!
14+ Default {
15+ auth_token : String ,
16+ url : Url ,
17+ client : Option < Client > ,
18+ } ,
19+ OpenAI {
20+ auth_token : String ,
21+ url : Url ,
22+ client : Option < Client > ,
23+ } ,
24+ }
25+
26+ impl Agent {
27+ pub fn from ( url : Url , auth_token : String , agent : Option < String > ) -> Self {
28+ match agent {
29+ Some ( agent_name) if agent_name == * "open_ai" => Agent :: OpenAI {
30+ auth_token,
31+ url,
32+ client : None ,
33+ } ,
34+ _ => Agent :: Default {
35+ auth_token,
36+ url,
37+ client : None ,
38+ } ,
39+ }
40+ }
41+ }
42+
1243#[ derive( Clone ) ]
1344pub struct RemoteHttpLlmEngine {
14- auth_token : String ,
15- url : Url ,
16- client : Option < Client > ,
45+ agent : Agent ,
1746}
1847
1948#[ derive( Serialize ) ]
@@ -53,59 +82,35 @@ struct EmbeddingResponseBody {
5382}
5483
5584impl RemoteHttpLlmEngine {
85+ pub fn new ( url : Url , auth_token : String , agent : Option < String > ) -> Self {
86+ RemoteHttpLlmEngine {
87+ agent : Agent :: from ( url, auth_token, agent) ,
88+ }
89+ }
90+
5691 pub async fn infer (
5792 & mut self ,
5893 model : wasi_llm:: InferencingModel ,
5994 prompt : String ,
6095 params : wasi_llm:: InferencingParams ,
6196 ) -> Result < wasi_llm:: InferencingResult , wasi_llm:: Error > {
62- let client = self . client . get_or_insert_with ( Default :: default) ;
63-
64- let mut headers = HeaderMap :: new ( ) ;
65- headers. insert (
66- "authorization" ,
67- HeaderValue :: from_str ( & format ! ( "bearer {}" , self . auth_token) ) . map_err ( |_| {
68- wasi_llm:: Error :: RuntimeError ( "Failed to create authorization header" . to_string ( ) )
69- } ) ?,
70- ) ;
71- spin_telemetry:: inject_trace_context ( & mut headers) ;
72-
73- let inference_options = InferRequestBodyParams {
74- max_tokens : params. max_tokens ,
75- repeat_penalty : params. repeat_penalty ,
76- repeat_penalty_last_n_token_count : params. repeat_penalty_last_n_token_count ,
77- temperature : params. temperature ,
78- top_k : params. top_k ,
79- top_p : params. top_p ,
80- } ;
81- let body = serde_json:: to_string ( & json ! ( {
82- "model" : model,
83- "prompt" : prompt,
84- "options" : inference_options
85- } ) )
86- . map_err ( |_| wasi_llm:: Error :: RuntimeError ( "Failed to serialize JSON" . to_string ( ) ) ) ?;
87-
88- let infer_url = self
89- . url
90- . join ( "/infer" )
91- . map_err ( |_| wasi_llm:: Error :: RuntimeError ( "Failed to create URL" . to_string ( ) ) ) ?;
92- tracing:: info!( "Sending remote inference request to {infer_url}" ) ;
93-
94- let resp = client
95- . request ( reqwest:: Method :: POST , infer_url)
96- . headers ( headers)
97- . body ( body)
98- . send ( )
99- . await
100- . map_err ( |err| {
101- wasi_llm:: Error :: RuntimeError ( format ! ( "POST /infer request error: {err}" ) )
102- } ) ?;
103-
104- match resp. json :: < InferResponseBody > ( ) . await {
105- Ok ( val) => Ok ( val. into ( ) ) ,
106- Err ( err) => Err ( wasi_llm:: Error :: RuntimeError ( format ! (
107- "Failed to deserialize response for \" POST /index\" : {err}"
108- ) ) ) ,
97+ match & self . agent {
98+ Agent :: OpenAI {
99+ auth_token,
100+ url,
101+ client,
102+ } => {
103+ OpenAIAgentEngine :: infer ( auth_token, url, client. clone ( ) , model, prompt, params)
104+ . await
105+ }
106+ Agent :: Default {
107+ auth_token,
108+ url,
109+ client,
110+ } => {
111+ DefaultAgentEngine :: infer ( auth_token, url, client. clone ( ) , model, prompt, params)
112+ . await
113+ }
109114 }
110115 }
111116
@@ -114,48 +119,37 @@ impl RemoteHttpLlmEngine {
114119 model : wasi_llm:: EmbeddingModel ,
115120 data : Vec < String > ,
116121 ) -> Result < wasi_llm:: EmbeddingsResult , wasi_llm:: Error > {
117- let client = self . client . get_or_insert_with ( Default :: default) ;
118-
119- let mut headers = HeaderMap :: new ( ) ;
120- headers. insert (
121- "authorization" ,
122- HeaderValue :: from_str ( & format ! ( "bearer {}" , self . auth_token) ) . map_err ( |_| {
123- wasi_llm:: Error :: RuntimeError ( "Failed to create authorization header" . to_string ( ) )
124- } ) ?,
125- ) ;
126- spin_telemetry:: inject_trace_context ( & mut headers) ;
127-
128- let body = serde_json:: to_string ( & json ! ( {
129- "model" : model,
130- "input" : data
131- } ) )
132- . map_err ( |_| wasi_llm:: Error :: RuntimeError ( "Failed to serialize JSON" . to_string ( ) ) ) ?;
133-
134- let resp = client
135- . request (
136- reqwest:: Method :: POST ,
137- self . url . join ( "/embed" ) . map_err ( |_| {
138- wasi_llm:: Error :: RuntimeError ( "Failed to create URL" . to_string ( ) )
139- } ) ?,
140- )
141- . headers ( headers)
142- . body ( body)
143- . send ( )
144- . await
145- . map_err ( |err| {
146- wasi_llm:: Error :: RuntimeError ( format ! ( "POST /embed request error: {err}" ) )
147- } ) ?;
148-
149- match resp. json :: < EmbeddingResponseBody > ( ) . await {
150- Ok ( val) => Ok ( val. into ( ) ) ,
151- Err ( err) => Err ( wasi_llm:: Error :: RuntimeError ( format ! (
152- "Failed to deserialize response for \" POST /embed\" : {err}"
153- ) ) ) ,
122+ match & self . agent {
123+ Agent :: OpenAI {
124+ auth_token,
125+ url,
126+ client,
127+ } => {
128+ OpenAIAgentEngine :: generate_embeddings ( auth_token, url, client. clone ( ) , model, data)
129+ . await
130+ }
131+ Agent :: Default {
132+ auth_token,
133+ url,
134+ client,
135+ } => {
136+ DefaultAgentEngine :: generate_embeddings (
137+ auth_token,
138+ url,
139+ client. clone ( ) ,
140+ model,
141+ data,
142+ )
143+ . await
144+ }
154145 }
155146 }
156147
157148 pub fn url ( & self ) -> Url {
158- self . url . clone ( )
149+ match & self . agent {
150+ Agent :: OpenAI { url, .. } => url. clone ( ) ,
151+ Agent :: Default { url, .. } => url. clone ( ) ,
152+ }
159153 }
160154}
161155
@@ -181,13 +175,3 @@ impl From<EmbeddingResponseBody> for wasi_llm::EmbeddingsResult {
181175 }
182176 }
183177}
184-
185- impl RemoteHttpLlmEngine {
186- pub fn new ( url : Url , auth_token : String ) -> Self {
187- RemoteHttpLlmEngine {
188- url,
189- auth_token,
190- client : None ,
191- }
192- }
193- }
0 commit comments