Skip to content

Commit e49d771

Browse files
authored
support for parsing error response message fields even if they are arrays (#381) (#384)
1 parent f0770cf commit e49d771

File tree

2 files changed

+111
-8
lines changed

2 files changed

+111
-8
lines changed

api_test.go

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,108 @@ func TestAPIError(t *testing.T) {
137137
}
138138
}
139139

140+
func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) {
141+
type testCase struct {
142+
name string
143+
response string
144+
hasError bool
145+
checkFn func(t *testing.T, apiErr APIError)
146+
}
147+
testCases := []testCase{
148+
{
149+
name: "parse succeeds when the message is string",
150+
response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`,
151+
hasError: false,
152+
checkFn: func(t *testing.T, apiErr APIError) {
153+
expected := "foo"
154+
if apiErr.Message != expected {
155+
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
156+
}
157+
},
158+
},
159+
{
160+
name: "parse succeeds when the message is array with single item",
161+
response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`,
162+
hasError: false,
163+
checkFn: func(t *testing.T, apiErr APIError) {
164+
expected := "foo"
165+
if apiErr.Message != expected {
166+
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
167+
}
168+
},
169+
},
170+
{
171+
name: "parse succeeds when the message is array with multiple items",
172+
response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`,
173+
hasError: false,
174+
checkFn: func(t *testing.T, apiErr APIError) {
175+
expected := "foo, bar, baz"
176+
if apiErr.Message != expected {
177+
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
178+
}
179+
},
180+
},
181+
{
182+
name: "parse succeeds when the message is empty array",
183+
response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`,
184+
hasError: false,
185+
checkFn: func(t *testing.T, apiErr APIError) {
186+
if apiErr.Message != "" {
187+
t.Fatalf("Unexpected API message: %v; expected: empty", apiErr)
188+
}
189+
},
190+
},
191+
{
192+
name: "parse succeeds when the message is null",
193+
response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`,
194+
hasError: false,
195+
checkFn: func(t *testing.T, apiErr APIError) {
196+
if apiErr.Message != "" {
197+
t.Fatalf("Unexpected API message: %v; expected: empty", apiErr)
198+
}
199+
},
200+
},
201+
{
202+
name: "parse failed when the message is object",
203+
response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`,
204+
hasError: true,
205+
},
206+
{
207+
name: "parse failed when the message is int",
208+
response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`,
209+
hasError: true,
210+
},
211+
{
212+
name: "parse failed when the message is float",
213+
response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`,
214+
hasError: true,
215+
},
216+
{
217+
name: "parse failed when the message is bool",
218+
response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`,
219+
hasError: true,
220+
},
221+
{
222+
name: "parse failed when the message is not exists",
223+
response: `{"type":"invalid_request_error","param":null,"code":null}`,
224+
hasError: true,
225+
},
226+
}
227+
for _, tc := range testCases {
228+
t.Run(tc.name, func(t *testing.T) {
229+
var apiErr APIError
230+
err := json.Unmarshal([]byte(tc.response), &apiErr)
231+
if (err != nil) != tc.hasError {
232+
t.Errorf("Unexpected error: %v", err)
233+
return
234+
}
235+
if tc.checkFn != nil {
236+
tc.checkFn(t, apiErr)
237+
}
238+
})
239+
}
240+
}
241+
140242
func TestAPIErrorUnmarshalJSONInteger(t *testing.T) {
141243
var apiErr APIError
142244
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
@@ -217,13 +319,6 @@ func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) {
217319
checks.HasError(t, err, "Type should be a string")
218320
}
219321

220-
func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
221-
var apiErr APIError
222-
response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}`
223-
err := json.Unmarshal([]byte(response), &apiErr)
224-
checks.HasError(t, err, "Message should be a string")
225-
}
226-
227322
func TestRequestError(t *testing.T) {
228323
client, server, teardown := setupOpenAITestServer()
229324
defer teardown()

error.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai
33
import (
44
"encoding/json"
55
"fmt"
6+
"strings"
67
)
78

89
// APIError provides error information returned by the OpenAI API.
@@ -41,7 +42,14 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
4142

4243
err = json.Unmarshal(rawMap["message"], &e.Message)
4344
if err != nil {
44-
return
45+
// If the parameter field of a function call is invalid as a JSON schema
46+
// refs: https://github.com/sashabaranov/go-openai/issues/381
47+
var messages []string
48+
err = json.Unmarshal(rawMap["message"], &messages)
49+
if err != nil {
50+
return
51+
}
52+
e.Message = strings.Join(messages, ", ")
4553
}
4654

4755
// optional fields for azure openai

0 commit comments

Comments
 (0)