Skip to content

Commit 1394329

Browse files
authored
move error_accumulator into internal pkg (#304) (#335)
* move error_accumulator into internal pkg (#304) * move error_accumulator into internal pkg (#304) * add a test for ErrTooManyEmptyStreamMessages in stream_reader (#304)
1 parent fa694c6 commit 1394329

12 files changed

+249
-201
lines changed

chat_stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (c *Client) CreateChatCompletionStream(
6666
emptyMessagesLimit: c.config.EmptyMessagesLimit,
6767
reader: bufio.NewReader(resp.Body),
6868
response: resp,
69-
errAccumulator: newErrorAccumulator(),
69+
errAccumulator: utils.NewErrorAccumulator(),
7070
unmarshaler: &utils.JSONUnmarshaler{},
7171
},
7272
}

chat_stream_test.go

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
package openai_test
1+
package openai //nolint:testpackage // testing private field
22

33
import (
4-
. "github.com/sashabaranov/go-openai"
4+
utils "github.com/sashabaranov/go-openai/internal"
55
"github.com/sashabaranov/go-openai/internal/test"
66
"github.com/sashabaranov/go-openai/internal/test/checks"
77

@@ -63,9 +63,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
6363
// Client portion of the test
6464
config := DefaultConfig(test.GetTestToken())
6565
config.BaseURL = server.URL + "/v1"
66-
config.HTTPClient.Transport = &tokenRoundTripper{
67-
test.GetTestToken(),
68-
http.DefaultTransport,
66+
config.HTTPClient.Transport = &test.TokenRoundTripper{
67+
Token: test.GetTestToken(),
68+
Fallback: http.DefaultTransport,
6969
}
7070

7171
client := NewClientWithConfig(config)
@@ -170,9 +170,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
170170
// Client portion of the test
171171
config := DefaultConfig(test.GetTestToken())
172172
config.BaseURL = server.URL + "/v1"
173-
config.HTTPClient.Transport = &tokenRoundTripper{
174-
test.GetTestToken(),
175-
http.DefaultTransport,
173+
config.HTTPClient.Transport = &test.TokenRoundTripper{
174+
Token: test.GetTestToken(),
175+
Fallback: http.DefaultTransport,
176176
}
177177

178178
client := NewClientWithConfig(config)
@@ -227,9 +227,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
227227
// Client portion of the test
228228
config := DefaultConfig(test.GetTestToken())
229229
config.BaseURL = ts.URL + "/v1"
230-
config.HTTPClient.Transport = &tokenRoundTripper{
231-
test.GetTestToken(),
232-
http.DefaultTransport,
230+
config.HTTPClient.Transport = &test.TokenRoundTripper{
231+
Token: test.GetTestToken(),
232+
Fallback: http.DefaultTransport,
233233
}
234234

235235
client := NewClientWithConfig(config)
@@ -255,6 +255,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
255255
t.Logf("%+v\n", apiErr)
256256
}
257257

258+
func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
259+
var err error
260+
server := test.NewTestServer()
261+
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
262+
http.Error(w, "error", 200)
263+
})
264+
ts := server.OpenAITestServer()
265+
ts.Start()
266+
defer ts.Close()
267+
268+
config := DefaultConfig(test.GetTestToken())
269+
config.BaseURL = ts.URL + "/v1"
270+
client := NewClientWithConfig(config)
271+
272+
ctx := context.Background()
273+
274+
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
275+
checks.NoError(t, err)
276+
277+
stream.errAccumulator = &utils.DefaultErrorAccumulator{
278+
Buffer: &test.FailingErrorBuffer{},
279+
}
280+
281+
_, err = stream.Recv()
282+
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error())
283+
}
284+
258285
// Helper funcs.
259286
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
260287
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {

error_accumulator.go

Lines changed: 0 additions & 53 deletions
This file was deleted.

error_accumulator_test.go

Lines changed: 0 additions & 100 deletions
This file was deleted.

internal/error_accumulator.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io"
7+
)
8+
9+
type ErrorAccumulator interface {
10+
Write(p []byte) error
11+
Bytes() []byte
12+
}
13+
14+
type errorBuffer interface {
15+
io.Writer
16+
Len() int
17+
Bytes() []byte
18+
}
19+
20+
type DefaultErrorAccumulator struct {
21+
Buffer errorBuffer
22+
}
23+
24+
func NewErrorAccumulator() ErrorAccumulator {
25+
return &DefaultErrorAccumulator{
26+
Buffer: &bytes.Buffer{},
27+
}
28+
}
29+
30+
func (e *DefaultErrorAccumulator) Write(p []byte) error {
31+
_, err := e.Buffer.Write(p)
32+
if err != nil {
33+
return fmt.Errorf("error accumulator write error, %w", err)
34+
}
35+
return nil
36+
}
37+
38+
func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) {
39+
if e.Buffer.Len() == 0 {
40+
return
41+
}
42+
errBytes = e.Buffer.Bytes()
43+
return
44+
}

internal/error_accumulator_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package openai_test
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"testing"
7+
8+
utils "github.com/sashabaranov/go-openai/internal"
9+
"github.com/sashabaranov/go-openai/internal/test"
10+
)
11+
12+
func TestErrorAccumulatorBytes(t *testing.T) {
13+
accumulator := &utils.DefaultErrorAccumulator{
14+
Buffer: &bytes.Buffer{},
15+
}
16+
17+
errBytes := accumulator.Bytes()
18+
if len(errBytes) != 0 {
19+
t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes))
20+
}
21+
22+
err := accumulator.Write([]byte("{}"))
23+
if err != nil {
24+
t.Fatalf("%+v", err)
25+
}
26+
27+
errBytes = accumulator.Bytes()
28+
if len(errBytes) == 0 {
29+
t.Fatalf("Did not return error bytes when has error: %s", string(errBytes))
30+
}
31+
}
32+
33+
func TestErrorByteWriteErrors(t *testing.T) {
34+
accumulator := &utils.DefaultErrorAccumulator{
35+
Buffer: &test.FailingErrorBuffer{},
36+
}
37+
err := accumulator.Write([]byte("{"))
38+
if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) {
39+
t.Fatalf("Did not return error when write failed: %v", err)
40+
}
41+
}

internal/test/failer.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package test
2+
3+
import "errors"
4+
5+
var (
6+
ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
7+
)
8+
9+
type FailingErrorBuffer struct{}
10+
11+
func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) {
12+
return 0, ErrTestErrorAccumulatorWriteFailed
13+
}
14+
15+
func (b *FailingErrorBuffer) Len() int {
16+
return 0
17+
}
18+
19+
func (b *FailingErrorBuffer) Bytes() []byte {
20+
return []byte{}
21+
}

internal/test/helpers.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package test
33
import (
44
"github.com/sashabaranov/go-openai/internal/test/checks"
55

6+
"net/http"
67
"os"
78
"testing"
89
)
@@ -27,3 +28,26 @@ func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {
2728

2829
return path, func() { os.RemoveAll(path) }
2930
}
31+
32+
// TokenRoundTripper is a struct that implements the RoundTripper
33+
// interface, specifically to handle the authentication token by adding a token
34+
// to the request header. We need this because the API requires that each
35+
// request include a valid API token in the headers for authentication and
36+
// authorization.
37+
type TokenRoundTripper struct {
38+
Token string
39+
Fallback http.RoundTripper
40+
}
41+
42+
// RoundTrip takes an *http.Request as input and returns an
43+
// *http.Response and an error.
44+
//
45+
// It is expected to use the provided request to create a connection to an HTTP
46+
// server and return the response, or an error if one occurred. The returned
47+
// Response should have its Body closed. If the RoundTrip method returns an
48+
// error, the Client's Get, Head, Post, and PostForm methods return the same
49+
// error.
50+
func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
51+
req.Header.Set("Authorization", "Bearer "+t.Token)
52+
return t.Fallback.RoundTrip(req)
53+
}

stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func (c *Client) CreateCompletionStream(
5555
emptyMessagesLimit: c.config.EmptyMessagesLimit,
5656
reader: bufio.NewReader(resp.Body),
5757
response: resp,
58-
errAccumulator: newErrorAccumulator(),
58+
errAccumulator: utils.NewErrorAccumulator(),
5959
unmarshaler: &utils.JSONUnmarshaler{},
6060
},
6161
}

0 commit comments

Comments
 (0)