Skip to content

Commit 061c97e

Browse files
authored
Implement Unmarshaller interface. Resolves #244 (#248)
1 parent d94c5e7 commit 061c97e

File tree

2 files changed

+145
-5
lines changed

2 files changed

+145
-5
lines changed

api_test.go

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package openai_test
22

33
import (
4+
"encoding/json"
5+
46
. "github.com/sashabaranov/go-openai"
57
"github.com/sashabaranov/go-openai/internal/test/checks"
68

@@ -110,7 +112,7 @@ func TestAPIError(t *testing.T) {
110112
c := NewClient(apiToken + "_invalid")
111113
ctx := context.Background()
112114
_, err = c.ListEngines(ctx)
113-
checks.NoError(t, err, "ListEngines did not fail")
115+
checks.HasError(t, err, "ListEngines should fail with an invalid key")
114116

115117
var apiErr *APIError
116118
if !errors.As(err, &apiErr) {
@@ -120,14 +122,108 @@ func TestAPIError(t *testing.T) {
120122
if apiErr.StatusCode != 401 {
121123
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
122124
}
123-
if *apiErr.Code != "invalid_api_key" {
124-
t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
125+
126+
switch v := apiErr.Code.(type) {
127+
case string:
128+
if v != "invalid_api_key" {
129+
t.Fatalf("Unexpected API error code: %s", v)
130+
}
131+
default:
132+
t.Fatalf("Unexpected API error code type: %T", v)
125133
}
134+
126135
if apiErr.Error() == "" {
127136
t.Fatal("Empty error message occurred")
128137
}
129138
}
130139

140+
func TestAPIErrorUnmarshalJSONInteger(t *testing.T) {
141+
var apiErr APIError
142+
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
143+
err := json.Unmarshal([]byte(response), &apiErr)
144+
checks.NoError(t, err, "Unexpected Unmarshal API response error")
145+
146+
switch v := apiErr.Code.(type) {
147+
case int:
148+
if v != 418 {
149+
t.Fatalf("Unexpected API code integer: %d; expected 418", v)
150+
}
151+
default:
152+
t.Fatalf("Unexpected API error code type: %T", v)
153+
}
154+
}
155+
156+
func TestAPIErrorUnmarshalJSONString(t *testing.T) {
157+
var apiErr APIError
158+
response := `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
159+
err := json.Unmarshal([]byte(response), &apiErr)
160+
checks.NoError(t, err, "Unexpected Unmarshal API response error")
161+
162+
switch v := apiErr.Code.(type) {
163+
case string:
164+
if v != "teapot" {
165+
t.Fatalf("Unexpected API code string: %s; expected `teapot`", v)
166+
}
167+
default:
168+
t.Fatalf("Unexpected API error code type: %T", v)
169+
}
170+
}
171+
172+
func TestAPIErrorUnmarshalJSONNoCode(t *testing.T) {
173+
// test integer code
174+
response := `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
175+
var apiErr APIError
176+
err := json.Unmarshal([]byte(response), &apiErr)
177+
checks.NoError(t, err, "Unexpected Unmarshal API response error")
178+
179+
switch v := apiErr.Code.(type) {
180+
case nil:
181+
default:
182+
t.Fatalf("Unexpected API error code type: %T", v)
183+
}
184+
}
185+
186+
func TestAPIErrorUnmarshalInvalidData(t *testing.T) {
187+
apiErr := APIError{}
188+
data := []byte(`--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`)
189+
err := apiErr.UnmarshalJSON(data)
190+
checks.HasError(t, err, "Expected error when unmarshaling invalid data")
191+
192+
if apiErr.Code != nil {
193+
t.Fatalf("Expected nil code, got %q", apiErr.Code)
194+
}
195+
if apiErr.Message != "" {
196+
t.Fatalf("Expected empty message, got %q", apiErr.Message)
197+
}
198+
if apiErr.Param != nil {
199+
t.Fatalf("Expected nil param, got %q", *apiErr.Param)
200+
}
201+
if apiErr.Type != "" {
202+
t.Fatalf("Expected empty type, got %q", apiErr.Type)
203+
}
204+
}
205+
206+
func TestAPIErrorUnmarshalJSONInvalidParam(t *testing.T) {
207+
var apiErr APIError
208+
response := `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}`
209+
err := json.Unmarshal([]byte(response), &apiErr)
210+
checks.HasError(t, err, "Param should be a string")
211+
}
212+
213+
func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) {
214+
var apiErr APIError
215+
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}`
216+
err := json.Unmarshal([]byte(response), &apiErr)
217+
checks.HasError(t, err, "Type should be a string")
218+
}
219+
220+
func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
221+
var apiErr APIError
222+
response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}`
223+
err := json.Unmarshal([]byte(response), &apiErr)
224+
checks.HasError(t, err, "Message should be a string")
225+
}
226+
131227
func TestRequestError(t *testing.T) {
132228
var err error
133229

error.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package openai
22

3-
import "fmt"
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
47

58
// APIError provides error information returned by the OpenAI API.
69
type APIError struct {
7-
Code *string `json:"code,omitempty"`
10+
Code any `json:"code,omitempty"`
811
Message string `json:"message"`
912
Param *string `json:"param,omitempty"`
1013
Type string `json:"type"`
@@ -25,6 +28,47 @@ func (e *APIError) Error() string {
2528
return e.Message
2629
}
2730

31+
func (e *APIError) UnmarshalJSON(data []byte) (err error) {
32+
var rawMap map[string]json.RawMessage
33+
err = json.Unmarshal(data, &rawMap)
34+
if err != nil {
35+
return
36+
}
37+
38+
err = json.Unmarshal(rawMap["message"], &e.Message)
39+
if err != nil {
40+
return
41+
}
42+
43+
err = json.Unmarshal(rawMap["type"], &e.Type)
44+
if err != nil {
45+
return
46+
}
47+
48+
// optional fields
49+
if _, ok := rawMap["param"]; ok {
50+
err = json.Unmarshal(rawMap["param"], &e.Param)
51+
if err != nil {
52+
return
53+
}
54+
}
55+
56+
if _, ok := rawMap["code"]; !ok {
57+
return nil
58+
}
59+
60+
// if the api returned a number, we need to force an integer
61+
// since the json package defaults to float64
62+
var intCode int
63+
err = json.Unmarshal(rawMap["code"], &intCode)
64+
if err == nil {
65+
e.Code = intCode
66+
return nil
67+
}
68+
69+
return json.Unmarshal(rawMap["code"], &e.Code)
70+
}
71+
2872
func (e *RequestError) Error() string {
2973
if e.Err != nil {
3074
return e.Err.Error()

0 commit comments

Comments
 (0)