Skip to content

Commit d6ab1b3

Browse files
authored
fix: chat stream resp error (#259)
1 parent 3b10c03 commit d6ab1b3

File tree

8 files changed

+146
-33
lines changed

8 files changed

+146
-33
lines changed

api_test.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
package openai_test
22

33
import (
4-
"encoding/json"
5-
6-
. "github.com/sashabaranov/go-openai"
7-
"github.com/sashabaranov/go-openai/internal/test/checks"
8-
94
"context"
5+
"encoding/json"
106
"errors"
117
"io"
128
"os"
139
"testing"
10+
11+
. "github.com/sashabaranov/go-openai"
12+
"github.com/sashabaranov/go-openai/internal/test/checks"
1413
)
1514

1615
func TestAPI(t *testing.T) {
@@ -119,8 +118,8 @@ func TestAPIError(t *testing.T) {
119118
t.Fatalf("Error is not an APIError: %+v", err)
120119
}
121120

122-
if apiErr.StatusCode != 401 {
123-
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
121+
if apiErr.HTTPStatusCode != 401 {
122+
t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode)
124123
}
125124

126125
switch v := apiErr.Code.(type) {
@@ -239,8 +238,8 @@ func TestRequestError(t *testing.T) {
239238
t.Fatalf("Error is not a RequestError: %+v", err)
240239
}
241240

242-
if reqErr.StatusCode != 418 {
243-
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
241+
if reqErr.HTTPStatusCode != 418 {
242+
t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode)
244243
}
245244

246245
if reqErr.Unwrap() == nil {

chat_stream.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai
33
import (
44
"bufio"
55
"context"
6+
"net/http"
67
)
78

89
type ChatCompletionStreamChoiceDelta struct {
@@ -53,6 +54,9 @@ func (c *Client) CreateChatCompletionStream(
5354
if err != nil {
5455
return
5556
}
57+
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
58+
return nil, c.handleErrorResp(resp)
59+
}
5660

5761
stream = &ChatCompletionStream{
5862
streamReader: &streamReader[ChatCompletionStreamResponse]{

chat_stream_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,57 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
204204
t.Logf("%+v\n", apiErr)
205205
}
206206

207+
func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
208+
server := test.NewTestServer()
209+
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
210+
w.Header().Set("Content-Type", "application/json")
211+
w.WriteHeader(429)
212+
213+
// Send test responses
214+
dataBytes := []byte(`{"error":{` +
215+
`"message": "You are sending requests too quickly.",` +
216+
`"type":"rate_limit_reached",` +
217+
`"param":null,` +
218+
`"code":"rate_limit_reached"}}`)
219+
220+
_, err := w.Write(dataBytes)
221+
checks.NoError(t, err, "Write error")
222+
})
223+
ts := server.OpenAITestServer()
224+
ts.Start()
225+
defer ts.Close()
226+
227+
// Client portion of the test
228+
config := DefaultConfig(test.GetTestToken())
229+
config.BaseURL = ts.URL + "/v1"
230+
config.HTTPClient.Transport = &tokenRoundTripper{
231+
test.GetTestToken(),
232+
http.DefaultTransport,
233+
}
234+
235+
client := NewClientWithConfig(config)
236+
ctx := context.Background()
237+
238+
request := ChatCompletionRequest{
239+
MaxTokens: 5,
240+
Model: GPT3Dot5Turbo,
241+
Messages: []ChatCompletionMessage{
242+
{
243+
Role: ChatMessageRoleUser,
244+
Content: "Hello!",
245+
},
246+
},
247+
Stream: true,
248+
}
249+
250+
var apiErr *APIError
251+
_, err := client.CreateChatCompletionStream(ctx, request)
252+
if !errors.As(err, &apiErr) {
253+
t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
254+
}
255+
t.Logf("%+v\n", apiErr)
256+
}
257+
207258
// Helper funcs.
208259
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
209260
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {

client.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
7272
defer res.Body.Close()
7373

7474
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
75-
var errRes ErrorResponse
76-
err = json.NewDecoder(res.Body).Decode(&errRes)
77-
if err != nil || errRes.Error == nil {
78-
reqErr := RequestError{
79-
StatusCode: res.StatusCode,
80-
Err: err,
81-
}
82-
return fmt.Errorf("error, %w", &reqErr)
83-
}
84-
errRes.Error.StatusCode = res.StatusCode
85-
return fmt.Errorf("error, status code: %d, message: %w", res.StatusCode, errRes.Error)
75+
return c.handleErrorResp(res)
8676
}
8777

8878
if v != nil {
@@ -132,3 +122,17 @@ func (c *Client) newStreamRequest(
132122
}
133123
return req, nil
134124
}
125+
126+
func (c *Client) handleErrorResp(resp *http.Response) error {
127+
var errRes ErrorResponse
128+
err := json.NewDecoder(resp.Body).Decode(&errRes)
129+
if err != nil || errRes.Error == nil {
130+
reqErr := RequestError{
131+
HTTPStatusCode: resp.StatusCode,
132+
Err: err,
133+
}
134+
return fmt.Errorf("error, %w", &reqErr)
135+
}
136+
errRes.Error.HTTPStatusCode = resp.StatusCode
137+
return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error)
138+
}

error.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ import (
77

88
// APIError provides error information returned by the OpenAI API.
99
type APIError struct {
10-
Code any `json:"code,omitempty"`
11-
Message string `json:"message"`
12-
Param *string `json:"param,omitempty"`
13-
Type string `json:"type"`
14-
StatusCode int `json:"-"`
10+
Code any `json:"code,omitempty"`
11+
Message string `json:"message"`
12+
Param *string `json:"param,omitempty"`
13+
Type string `json:"type"`
14+
HTTPStatusCode int `json:"-"`
1515
}
1616

1717
// RequestError provides informations about generic request errors.
1818
type RequestError struct {
19-
StatusCode int
20-
Err error
19+
HTTPStatusCode int
20+
Err error
2121
}
2222

2323
type ErrorResponse struct {
@@ -73,7 +73,7 @@ func (e *RequestError) Error() string {
7373
if e.Err != nil {
7474
return e.Err.Error()
7575
}
76-
return fmt.Sprintf("status code %d", e.StatusCode)
76+
return fmt.Sprintf("status code %d", e.HTTPStatusCode)
7777
}
7878

7979
func (e *RequestError) Unwrap() error {

error_accumulator_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"errors"
7+
"net/http"
78
"testing"
89

910
"github.com/sashabaranov/go-openai/internal/test"
@@ -71,7 +72,11 @@ func TestErrorByteWriteErrors(t *testing.T) {
7172

7273
func TestErrorAccumulatorWriteErrors(t *testing.T) {
7374
var err error
74-
ts := test.NewTestServer().OpenAITestServer()
75+
server := test.NewTestServer()
76+
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
77+
http.Error(w, "error", 200)
78+
})
79+
ts := server.OpenAITestServer()
7580
ts.Start()
7681
defer ts.Close()
7782

stream.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bufio"
55
"context"
66
"errors"
7+
"net/http"
78
)
89

910
var (
@@ -43,6 +44,9 @@ func (c *Client) CreateCompletionStream(
4344
if err != nil {
4445
return
4546
}
47+
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
48+
return nil, c.handleErrorResp(resp)
49+
}
4650

4751
stream = &CompletionStream{
4852
streamReader: &streamReader[CompletionResponse]{

stream_test.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
package openai_test
22

33
import (
4-
. "github.com/sashabaranov/go-openai"
5-
"github.com/sashabaranov/go-openai/internal/test"
6-
"github.com/sashabaranov/go-openai/internal/test/checks"
7-
84
"context"
95
"errors"
106
"io"
117
"net/http"
128
"net/http/httptest"
139
"testing"
10+
11+
. "github.com/sashabaranov/go-openai"
12+
"github.com/sashabaranov/go-openai/internal/test"
13+
"github.com/sashabaranov/go-openai/internal/test/checks"
1414
)
1515

1616
func TestCompletionsStreamWrongModel(t *testing.T) {
@@ -171,6 +171,52 @@ func TestCreateCompletionStreamError(t *testing.T) {
171171
t.Logf("%+v\n", apiErr)
172172
}
173173

174+
func TestCreateCompletionStreamRateLimitError(t *testing.T) {
175+
server := test.NewTestServer()
176+
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
177+
w.Header().Set("Content-Type", "application/json")
178+
w.WriteHeader(429)
179+
180+
// Send test responses
181+
dataBytes := []byte(`{"error":{` +
182+
`"message": "You are sending requests too quickly.",` +
183+
`"type":"rate_limit_reached",` +
184+
`"param":null,` +
185+
`"code":"rate_limit_reached"}}`)
186+
187+
_, err := w.Write(dataBytes)
188+
checks.NoError(t, err, "Write error")
189+
})
190+
ts := server.OpenAITestServer()
191+
ts.Start()
192+
defer ts.Close()
193+
194+
// Client portion of the test
195+
config := DefaultConfig(test.GetTestToken())
196+
config.BaseURL = ts.URL + "/v1"
197+
config.HTTPClient.Transport = &tokenRoundTripper{
198+
test.GetTestToken(),
199+
http.DefaultTransport,
200+
}
201+
202+
client := NewClientWithConfig(config)
203+
ctx := context.Background()
204+
205+
request := CompletionRequest{
206+
MaxTokens: 5,
207+
Model: GPT3Ada,
208+
Prompt: "Hello!",
209+
Stream: true,
210+
}
211+
212+
var apiErr *APIError
213+
_, err := client.CreateCompletionStream(ctx, request)
214+
if !errors.As(err, &apiErr) {
215+
t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError")
216+
}
217+
t.Logf("%+v\n", apiErr)
218+
}
219+
174220
// A "tokenRoundTripper" is a struct that implements the RoundTripper
175221
// interface, specifically to handle the authentication token by adding a token
176222
// to the request header. We need this because the API requires that each

0 commit comments

Comments
 (0)