Skip to content

Commit f30a6cb

Browse files
committed
added example + improved deserialisation + readability
Signed-off-by: Aminu Oluwaseun Joshua <[email protected]>
1 parent b552796 commit f30a6cb

File tree

9 files changed

+1002
-86
lines changed

9 files changed

+1002
-86
lines changed

crates/llm-remote-http/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,17 @@ struct InferResponseBody {
7676

7777
#[derive(Deserialize)]
7878
struct CreateChatCompletionResponse {
79+
#[serde(rename = "id")]
7980
_id: String,
81+
#[serde(rename = "object")]
8082
_object: String,
83+
#[serde(rename = "created")]
8184
_created: u64,
85+
#[serde(rename = "model")]
8286
_model: String,
87+
#[serde(rename = "choices")]
8388
choices: Vec<ChatCompletionChoice>,
89+
#[serde(rename = "usage")]
8490
usage: CompletionUsage,
8591
}
8692

crates/llm-remote-http/src/open_ai.rs

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use reqwest::{
22
header::{HeaderMap, HeaderValue},
3-
Body, Client, Url,
3+
Client, Url,
44
};
55
use serde::Serialize;
66
use spin_world::v2::llm::{self as wasi_llm};
77

88
use crate::{
9-
schema::{EmbeddingModels, EncodingFormat, Model, Prompt, Role},
10-
CreateChatCompletionResponse, EmbeddingResponseBody,
9+
schema::{EmbeddingModels, EncodingFormat, Model, Prompt, ResponseError, Role},
10+
CreateChatCompletionResponse, CreateEmbeddingResponse,
1111
};
1212

1313
pub(crate) struct OpenAIAgentEngine;
@@ -33,7 +33,7 @@ impl OpenAIAgentEngine {
3333
spin_telemetry::inject_trace_context(&mut headers);
3434

3535
let chat_url = url
36-
.join("/chat/completions")
36+
.join("/v1/chat/completions")
3737
.map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
3838

3939
tracing::info!("Sending remote inference request to {chat_url}");
@@ -51,15 +51,16 @@ impl OpenAIAgentEngine {
5151
let resp = client
5252
.request(reqwest::Method::POST, chat_url)
5353
.headers(headers)
54-
.body(body)
54+
.json(&body)
5555
.send()
5656
.await
5757
.map_err(|err| {
5858
wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
5959
})?;
6060

61-
match resp.json::<CreateChatCompletionResponse>().await {
62-
Ok(val) => Ok(val.into()),
61+
match resp.json::<CreateChatCompletionResponses>().await {
62+
Ok(CreateChatCompletionResponses::Success(val)) => Ok(val.into()),
63+
Ok(CreateChatCompletionResponses::Error { error }) => Err(error.into()),
6364
Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
6465
"Failed to deserialize response for \"POST /index\": {err}"
6566
))),
@@ -95,20 +96,21 @@ impl OpenAIAgentEngine {
9596
let resp = client
9697
.request(
9798
reqwest::Method::POST,
98-
url.join("/embeddings").map_err(|_| {
99+
url.join("/v1/embeddings").map_err(|_| {
99100
wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
100101
})?,
101102
)
102103
.headers(headers)
103-
.body(body)
104+
.json(&body)
104105
.send()
105106
.await
106107
.map_err(|err| {
107108
wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
108109
})?;
109110

110-
match resp.json::<EmbeddingResponseBody>().await {
111-
Ok(val) => Ok(val.into()),
111+
match resp.json::<CreateEmbeddingResponses>().await {
112+
Ok(CreateEmbeddingResponses::Success(val)) => Ok(val.into()),
113+
Ok(CreateEmbeddingResponses::Error { error }) => Err(error.into()),
112114
Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
113115
"Failed to deserialize response for \"POST /embed\": {err}"
114116
))),
@@ -130,12 +132,6 @@ struct CreateChatCompletionRequest {
130132
verbosity: Option<String>,
131133
}
132134

133-
impl From<CreateChatCompletionRequest> for Body {
134-
fn from(val: CreateChatCompletionRequest) -> Self {
135-
Body::from(serde_json::to_string(&val).unwrap())
136-
}
137-
}
138-
139135
#[derive(Serialize, Debug)]
140136
pub struct CreateEmbeddingRequest {
141137
input: Vec<String>,
@@ -148,8 +144,16 @@ pub struct CreateEmbeddingRequest {
148144
user: Option<String>,
149145
}
150146

151-
impl From<CreateEmbeddingRequest> for Body {
152-
fn from(val: CreateEmbeddingRequest) -> Self {
153-
Body::from(serde_json::to_string(&val).unwrap())
154-
}
147+
#[derive(serde::Deserialize)]
148+
#[serde(untagged)]
149+
enum CreateChatCompletionResponses {
150+
Success(CreateChatCompletionResponse),
151+
Error { error: ResponseError },
152+
}
153+
154+
#[derive(serde::Deserialize)]
155+
#[serde(untagged)]
156+
enum CreateEmbeddingResponses {
157+
Success(CreateEmbeddingResponse),
158+
Error { error: ResponseError },
155159
}

crates/llm-remote-http/src/schema.rs

Lines changed: 53 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,36 @@
1-
use std::fmt::Display;
2-
31
use serde::{Deserialize, Serialize};
42
use spin_world::v2::llm as wasi_llm;
53

64
/// LLM model
75
#[derive(Serialize, Debug)]
86
pub enum Model {
7+
#[serde(rename = "gpt-5")]
98
GPT5,
9+
#[serde(rename = "gpt-5-mini")]
1010
GPT5Mini,
11+
#[serde(rename = "gpt-5-nano")]
1112
GPT5Nano,
13+
#[serde(rename = "gpt-5-chat")]
1214
GPT5Chat,
15+
#[serde(rename = "gpt-4.5")]
1316
GPT45,
17+
#[serde(rename = "gpt-4.1")]
1418
GPT41,
19+
#[serde(rename = "gpt-4.1-mini")]
1520
GPT41Mini,
21+
#[serde(rename = "gpt-4.1-nano")]
1622
GPT41Nano,
23+
#[serde(rename = "gpt-4")]
1724
GPT4,
25+
#[serde(rename = "gpt-4o")]
1826
GPT4o,
27+
#[serde(rename = "gpt-4o-mini")]
1928
GPT4oMini,
29+
#[serde(rename = "o4-mini")]
2030
O4Mini,
31+
#[serde(rename = "o3")]
2132
O3,
33+
#[serde(rename = "o1")]
2234
O1,
2335
}
2436

@@ -48,27 +60,6 @@ impl TryFrom<&str> for Model {
4860
}
4961
}
5062

51-
impl Display for Model {
52-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53-
match self {
54-
Model::GPT5 => write!(f, "gpt-5"),
55-
Model::GPT5Mini => write!(f, "gpt-5-mini"),
56-
Model::GPT5Nano => write!(f, "gpt-5-nano"),
57-
Model::GPT5Chat => write!(f, "gpt-5-chat"),
58-
Model::GPT45 => write!(f, "gpt-4.5"),
59-
Model::GPT41 => write!(f, "gpt-4.1"),
60-
Model::GPT41Mini => write!(f, "gpt-4.1-mini"),
61-
Model::GPT41Nano => write!(f, "gpt-4.1-nano"),
62-
Model::GPT4 => write!(f, "gpt-4"),
63-
Model::GPT4o => write!(f, "gpt-4o"),
64-
Model::GPT4oMini => write!(f, "gpt-4o-mini"),
65-
Model::O4Mini => write!(f, "o4-mini"),
66-
Model::O3 => write!(f, "o3"),
67-
Model::O1 => write!(f, "o1"),
68-
}
69-
}
70-
}
71-
7263
#[derive(Serialize, Debug)]
7364
pub struct Prompt {
7465
role: Role,
@@ -83,23 +74,16 @@ impl Prompt {
8374

8475
#[derive(Serialize, Debug)]
8576
pub enum Role {
77+
#[serde(rename = "system")]
8678
System,
79+
#[serde(rename = "user")]
8780
User,
81+
#[serde(rename = "assistant")]
8882
Assistant,
83+
#[serde(rename = "tool")]
8984
Tool,
9085
}
9186

92-
impl Display for Role {
93-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94-
match self {
95-
Role::System => write!(f, "system"),
96-
Role::User => write!(f, "user"),
97-
Role::Assistant => write!(f, "assistant"),
98-
Role::Tool => write!(f, "tool"),
99-
}
100-
}
101-
}
102-
10387
impl TryFrom<&str> for Role {
10488
type Error = wasi_llm::Error;
10589

@@ -118,19 +102,12 @@ impl TryFrom<&str> for Role {
118102

119103
#[derive(Serialize, Debug)]
120104
pub enum EncodingFormat {
105+
#[serde(rename = "float")]
121106
Float,
107+
#[serde(rename = "base64")]
122108
Base64,
123109
}
124110

125-
impl Display for EncodingFormat {
126-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127-
match self {
128-
EncodingFormat::Float => write!(f, "float"),
129-
EncodingFormat::Base64 => write!(f, "base64"),
130-
}
131-
}
132-
}
133-
134111
impl TryFrom<&str> for EncodingFormat {
135112
type Error = wasi_llm::Error;
136113

@@ -147,23 +124,15 @@ impl TryFrom<&str> for EncodingFormat {
147124

148125
#[derive(Serialize, Debug)]
149126
pub enum EmbeddingModels {
127+
#[serde(rename = "text-embedding-ada-002")]
150128
TextEmbeddingAda002,
129+
#[serde(rename = "text-embedding-3-small")]
151130
TextEmbedding3Small,
131+
#[serde(rename = "text-embedding-3-large")]
152132
TextEmbedding3Large,
153133
Custom(String),
154134
}
155135

156-
impl Display for EmbeddingModels {
157-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158-
match self {
159-
EmbeddingModels::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
160-
EmbeddingModels::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
161-
EmbeddingModels::TextEmbedding3Large => write!(f, "text-embedding-3-large"),
162-
EmbeddingModels::Custom(model) => write!(f, "{model}"),
163-
}
164-
}
165-
}
166-
167136
impl TryFrom<&str> for EmbeddingModels {
168137
type Error = wasi_llm::Error;
169138

@@ -179,23 +148,16 @@ impl TryFrom<&str> for EmbeddingModels {
179148

180149
#[derive(Serialize, Debug)]
181150
enum ReasoningEffort {
151+
#[serde(rename = "minimal")]
182152
Minimal,
153+
#[serde(rename = "low")]
183154
Low,
155+
#[serde(rename = "medium")]
184156
Medium,
157+
#[serde(rename = "high")]
185158
High,
186159
}
187160

188-
impl Display for ReasoningEffort {
189-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190-
match self {
191-
ReasoningEffort::Minimal => write!(f, "minimal"),
192-
ReasoningEffort::Low => write!(f, "low"),
193-
ReasoningEffort::Medium => write!(f, "medium"),
194-
ReasoningEffort::High => write!(f, "high"),
195-
}
196-
}
197-
}
198-
199161
impl TryFrom<&str> for ReasoningEffort {
200162
type Error = wasi_llm::Error;
201163

@@ -236,42 +198,68 @@ impl TryFrom<&str> for Verbosity {
236198

237199
#[derive(Deserialize)]
238200
pub struct ChatCompletionChoice {
201+
#[serde(rename = "index")]
239202
/// The index of the choice in the list of choices
240203
_index: u32,
241204
pub message: ChatCompletionResponseMessage,
242205
/// The reason the model stopped generating tokens. This will be `stop` if the model hit a
243206
/// natural stop point or a provided stop sequence,
207+
#[serde(rename = "finish_reason")]
244208
_finish_reason: String,
245209
/// Log probability information for the choice.
210+
#[serde(rename = "logprobs")]
246211
_logprobs: Option<Logprobs>,
247212
}
248213

249214
#[derive(Deserialize)]
250215
/// A chat completion message generated by the model.
251216
pub struct ChatCompletionResponseMessage {
252217
/// The role of the author of this message
218+
#[serde(rename = "role")]
253219
_role: String,
254220
/// The contents of the message
255221
pub content: String,
256222
/// The refusal message generated by the model
223+
#[serde(rename = "refusal")]
257224
_refusal: Option<String>,
258225
}
259226

260227
#[derive(Deserialize)]
261228
pub struct Logprobs {
262229
/// A list of message content tokens with log probability information.
230+
#[serde(rename = "content")]
263231
_content: Option<Vec<String>>,
264232
/// A list of message refusal tokens with log probability information.
233+
#[serde(rename = "refusal")]
265234
_refusal: Option<Vec<String>>,
266235
}
267236

268237
#[derive(Deserialize)]
269238
pub struct Embedding {
270239
/// The index of the embedding in the list of embeddings..
240+
#[serde(rename = "index")]
271241
_index: u32,
272242
/// The embedding vector, which is a list of floats. The length of vector depends on the model as
273243
/// listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
274244
pub embedding: Vec<f32>,
275245
/// The object type, which is always "embedding"
246+
#[serde(rename = "object")]
276247
_object: String,
277248
}
249+
250+
#[derive(Deserialize, Default)]
251+
pub struct ResponseError {
252+
pub message: String,
253+
#[serde(rename = "type")]
254+
_t: String,
255+
#[serde(rename = "param")]
256+
_param: Option<String>,
257+
#[serde(rename = "code")]
258+
_code: String,
259+
}
260+
261+
impl From<ResponseError> for wasi_llm::Error {
262+
fn from(value: ResponseError) -> Self {
263+
wasi_llm::Error::RuntimeError(value.message)
264+
}
265+
}

examples/open-ai-rust/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
target/
2+
.spin/

0 commit comments

Comments
 (0)