1
- package openai_test
1
+ package openai //nolint:testpackage // testing private field
2
2
3
3
import (
4
- . "github.com/sashabaranov/go-openai"
4
+ utils "github.com/sashabaranov/go-openai/internal "
5
5
"github.com/sashabaranov/go-openai/internal/test"
6
6
"github.com/sashabaranov/go-openai/internal/test/checks"
7
7
@@ -63,9 +63,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
63
63
// Client portion of the test
64
64
config := DefaultConfig (test .GetTestToken ())
65
65
config .BaseURL = server .URL + "/v1"
66
- config .HTTPClient .Transport = & tokenRoundTripper {
67
- test .GetTestToken (),
68
- http .DefaultTransport ,
66
+ config .HTTPClient .Transport = & test. TokenRoundTripper {
67
+ Token : test .GetTestToken (),
68
+ Fallback : http .DefaultTransport ,
69
69
}
70
70
71
71
client := NewClientWithConfig (config )
@@ -170,9 +170,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
170
170
// Client portion of the test
171
171
config := DefaultConfig (test .GetTestToken ())
172
172
config .BaseURL = server .URL + "/v1"
173
- config .HTTPClient .Transport = & tokenRoundTripper {
174
- test .GetTestToken (),
175
- http .DefaultTransport ,
173
+ config .HTTPClient .Transport = & test. TokenRoundTripper {
174
+ Token : test .GetTestToken (),
175
+ Fallback : http .DefaultTransport ,
176
176
}
177
177
178
178
client := NewClientWithConfig (config )
@@ -227,9 +227,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
227
227
// Client portion of the test
228
228
config := DefaultConfig (test .GetTestToken ())
229
229
config .BaseURL = ts .URL + "/v1"
230
- config .HTTPClient .Transport = & tokenRoundTripper {
231
- test .GetTestToken (),
232
- http .DefaultTransport ,
230
+ config .HTTPClient .Transport = & test. TokenRoundTripper {
231
+ Token : test .GetTestToken (),
232
+ Fallback : http .DefaultTransport ,
233
233
}
234
234
235
235
client := NewClientWithConfig (config )
@@ -255,6 +255,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
255
255
t .Logf ("%+v\n " , apiErr )
256
256
}
257
257
258
+ func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors (t * testing.T ) {
259
+ var err error
260
+ server := test .NewTestServer ()
261
+ server .RegisterHandler ("/v1/chat/completions" , func (w http.ResponseWriter , r * http.Request ) {
262
+ http .Error (w , "error" , 200 )
263
+ })
264
+ ts := server .OpenAITestServer ()
265
+ ts .Start ()
266
+ defer ts .Close ()
267
+
268
+ config := DefaultConfig (test .GetTestToken ())
269
+ config .BaseURL = ts .URL + "/v1"
270
+ client := NewClientWithConfig (config )
271
+
272
+ ctx := context .Background ()
273
+
274
+ stream , err := client .CreateChatCompletionStream (ctx , ChatCompletionRequest {})
275
+ checks .NoError (t , err )
276
+
277
+ stream .errAccumulator = & utils.DefaultErrorAccumulator {
278
+ Buffer : & test.FailingErrorBuffer {},
279
+ }
280
+
281
+ _ , err = stream .Recv ()
282
+ checks .ErrorIs (t , err , test .ErrTestErrorAccumulatorWriteFailed , "Did not return error when Write failed" , err .Error ())
283
+ }
284
+
258
285
// Helper funcs.
259
286
func compareChatResponses (r1 , r2 ChatCompletionStreamResponse ) bool {
260
287
if r1 .ID != r2 .ID || r1 .Object != r2 .Object || r1 .Created != r2 .Created || r1 .Model != r2 .Model {
0 commit comments