Skip to content

Commit 8e3a046

Browse files
rexposadasRex Posadas
andauthored
Refactor/internal testing (#194)
* added NoError check * corrected NoError * has error checks * replace more checks * Used checks test helper * Used checks test helper * remove duplicate import * fixed lint issues regarding length of messages --------- Co-authored-by: Rex Posadas <[email protected]>
1 parent 479dab3 commit 8e3a046

15 files changed

+115
-140
lines changed

api_test.go

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai_test
22

33
import (
44
. "github.com/sashabaranov/go-openai"
5+
"github.com/sashabaranov/go-openai/internal/test/checks"
56

67
"context"
78
"errors"
@@ -20,25 +21,17 @@ func TestAPI(t *testing.T) {
2021
c := NewClient(apiToken)
2122
ctx := context.Background()
2223
_, err = c.ListEngines(ctx)
23-
if err != nil {
24-
t.Fatalf("ListEngines error: %v", err)
25-
}
24+
checks.NoError(t, err, "ListEngines error")
2625

2726
_, err = c.GetEngine(ctx, "davinci")
28-
if err != nil {
29-
t.Fatalf("GetEngine error: %v", err)
30-
}
27+
checks.NoError(t, err, "GetEngine error")
3128

3229
fileRes, err := c.ListFiles(ctx)
33-
if err != nil {
34-
t.Fatalf("ListFiles error: %v", err)
35-
}
30+
checks.NoError(t, err, "ListFiles error")
3631

3732
if len(fileRes.Files) > 0 {
3833
_, err = c.GetFile(ctx, fileRes.Files[0].ID)
39-
if err != nil {
40-
t.Fatalf("GetFile error: %v", err)
41-
}
34+
checks.NoError(t, err, "GetFile error")
4235
} // else skip
4336

4437
embeddingReq := EmbeddingRequest{
@@ -49,9 +42,7 @@ func TestAPI(t *testing.T) {
4942
Model: AdaSearchQuery,
5043
}
5144
_, err = c.CreateEmbeddings(ctx, embeddingReq)
52-
if err != nil {
53-
t.Fatalf("Embedding error: %v", err)
54-
}
45+
checks.NoError(t, err, "Embedding error")
5546

5647
_, err = c.CreateChatCompletion(
5748
ctx,
@@ -66,9 +57,7 @@ func TestAPI(t *testing.T) {
6657
},
6758
)
6859

69-
if err != nil {
70-
t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
71-
}
60+
checks.NoError(t, err, "CreateChatCompletion (without name) returned error")
7261

7362
_, err = c.CreateChatCompletion(
7463
ctx,
@@ -83,20 +72,15 @@ func TestAPI(t *testing.T) {
8372
},
8473
},
8574
)
86-
87-
if err != nil {
88-
t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
89-
}
75+
checks.NoError(t, err, "CreateChatCompletion (with name) returned error")
9076

9177
stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
9278
Prompt: "Ex falso quodlibet",
9379
Model: GPT3Ada,
9480
MaxTokens: 5,
9581
Stream: true,
9682
})
97-
if err != nil {
98-
t.Errorf("CreateCompletionStream returned error: %v", err)
99-
}
83+
checks.NoError(t, err, "CreateCompletionStream returned error")
10084
defer stream.Close()
10185

10286
counter := 0
@@ -126,9 +110,7 @@ func TestAPIError(t *testing.T) {
126110
c := NewClient(apiToken + "_invalid")
127111
ctx := context.Background()
128112
_, err = c.ListEngines(ctx)
129-
if err == nil {
130-
t.Fatal("ListEngines did not fail")
131-
}
113+
checks.NoError(t, err, "ListEngines did not fail")
132114

133115
var apiErr *APIError
134116
if !errors.As(err, &apiErr) {
@@ -154,9 +136,7 @@ func TestRequestError(t *testing.T) {
154136
c := NewClientWithConfig(config)
155137
ctx := context.Background()
156138
_, err = c.ListEngines(ctx)
157-
if err == nil {
158-
t.Fatal("ListEngines request did not fail")
159-
}
139+
checks.HasError(t, err, "ListEngines did not fail")
160140

161141
var reqErr *RequestError
162142
if !errors.As(err, &reqErr) {

audio_test.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313

1414
. "github.com/sashabaranov/go-openai"
1515
"github.com/sashabaranov/go-openai/internal/test"
16+
"github.com/sashabaranov/go-openai/internal/test/checks"
1617

1718
"context"
1819
"testing"
@@ -62,9 +63,7 @@ func TestAudio(t *testing.T) {
6263
Model: "whisper-3",
6364
}
6465
_, err = tc.createFn(ctx, req)
65-
if err != nil {
66-
t.Fatalf("audio API error: %v", err)
67-
}
66+
checks.NoError(t, err, "audio API error")
6867
})
6968
}
7069
}
@@ -115,19 +114,16 @@ func TestAudioWithOptionalArgs(t *testing.T) {
115114
Language: "zh",
116115
}
117116
_, err = tc.createFn(ctx, req)
118-
if err != nil {
119-
t.Fatalf("audio API error: %v", err)
120-
}
117+
checks.NoError(t, err, "audio API error")
121118
})
122119
}
123120
}
124121

125122
// createTestFile creates a fake file with "hello" as the content.
126123
func createTestFile(t *testing.T, path string) {
127124
file, err := os.Create(path)
128-
if err != nil {
129-
t.Fatalf("failed to create file %v", err)
130-
}
125+
checks.NoError(t, err, "failed to create file")
126+
131127
if _, err = file.WriteString("hello"); err != nil {
132128
t.Fatalf("failed to write to file %v", err)
133129
}
@@ -139,9 +135,7 @@ func createTestDirectory(t *testing.T) (path string, cleanup func()) {
139135
t.Helper()
140136

141137
path, err := os.MkdirTemp(os.TempDir(), "")
142-
if err != nil {
143-
t.Fatal(err)
144-
}
138+
checks.NoError(t, err)
145139

146140
return path, func() { os.RemoveAll(path) }
147141
}

chat_stream_test.go

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai_test
33
import (
44
. "github.com/sashabaranov/go-openai"
55
"github.com/sashabaranov/go-openai/internal/test"
6+
"github.com/sashabaranov/go-openai/internal/test/checks"
67

78
"context"
89
"encoding/json"
@@ -55,9 +56,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
5556
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
5657

5758
_, err := w.Write(dataBytes)
58-
if err != nil {
59-
t.Errorf("Write error: %s", err)
60-
}
59+
checks.NoError(t, err, "Write error")
6160
}))
6261
defer server.Close()
6362

@@ -85,9 +84,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
8584
}
8685

8786
stream, err := client.CreateChatCompletionStream(ctx, request)
88-
if err != nil {
89-
t.Errorf("CreateCompletionStream returned error: %v", err)
90-
}
87+
checks.NoError(t, err, "CreateCompletionStream returned error")
9188
defer stream.Close()
9289

9390
expectedResponses := []ChatCompletionStreamResponse{
@@ -126,9 +123,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
126123
t.Logf("%d: %s", ix, string(b))
127124

128125
receivedResponse, streamErr := stream.Recv()
129-
if streamErr != nil {
130-
t.Errorf("stream.Recv() failed: %v", streamErr)
131-
}
126+
checks.NoError(t, streamErr, "stream.Recv() failed")
132127
if !compareChatResponses(expectedResponse, receivedResponse) {
133128
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
134129
}
@@ -140,6 +135,8 @@ func TestCreateChatCompletionStream(t *testing.T) {
140135
}
141136

142137
_, streamErr = stream.Recv()
138+
139+
checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished")
143140
if !errors.Is(streamErr, io.EOF) {
144141
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
145142
}
@@ -166,9 +163,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
166163
}
167164

168165
_, err := w.Write(dataBytes)
169-
if err != nil {
170-
t.Errorf("Write error: %s", err)
171-
}
166+
checks.NoError(t, err, "Write error")
172167
}))
173168
defer server.Close()
174169

@@ -196,15 +191,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
196191
}
197192

198193
stream, err := client.CreateChatCompletionStream(ctx, request)
199-
if err != nil {
200-
t.Errorf("CreateCompletionStream returned error: %v", err)
201-
}
194+
checks.NoError(t, err, "CreateCompletionStream returned error")
202195
defer stream.Close()
203196

204197
_, streamErr := stream.Recv()
205-
if streamErr == nil {
206-
t.Errorf("stream.Recv() did not return error")
207-
}
198+
checks.HasError(t, streamErr, "stream.Recv() did not return error")
199+
208200
var apiErr *APIError
209201
if !errors.As(streamErr, &apiErr) {
210202
t.Errorf("stream.Recv() did not return APIError")

chat_test.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ package openai_test
33
import (
44
. "github.com/sashabaranov/go-openai"
55
"github.com/sashabaranov/go-openai/internal/test"
6+
"github.com/sashabaranov/go-openai/internal/test/checks"
67

78
"context"
89
"encoding/json"
9-
"errors"
1010
"fmt"
1111
"io"
1212
"net/http"
@@ -33,9 +33,8 @@ func TestChatCompletionsWrongModel(t *testing.T) {
3333
},
3434
}
3535
_, err := client.CreateChatCompletion(ctx, req)
36-
if !errors.Is(err, ErrChatCompletionInvalidModel) {
37-
t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err)
38-
}
36+
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
37+
checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg)
3938
}
4039

4140
func TestChatCompletionsWithStream(t *testing.T) {
@@ -48,9 +47,7 @@ func TestChatCompletionsWithStream(t *testing.T) {
4847
Stream: true,
4948
}
5049
_, err := client.CreateChatCompletion(ctx, req)
51-
if !errors.Is(err, ErrChatCompletionStreamNotSupported) {
52-
t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
53-
}
50+
checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error")
5451
}
5552

5653
// TestCompletions Tests the completions endpoint of the API using the mocked server.
@@ -79,9 +76,7 @@ func TestChatCompletions(t *testing.T) {
7976
},
8077
}
8178
_, err = client.CreateChatCompletion(ctx, req)
82-
if err != nil {
83-
t.Fatalf("CreateChatCompletion error: %v", err)
84-
}
79+
checks.NoError(t, err, "CreateChatCompletion error")
8580
}
8681

8782
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.

completion_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai_test
33
import (
44
. "github.com/sashabaranov/go-openai"
55
"github.com/sashabaranov/go-openai/internal/test"
6+
"github.com/sashabaranov/go-openai/internal/test/checks"
67

78
"context"
89
"encoding/json"
@@ -66,9 +67,7 @@ func TestCompletions(t *testing.T) {
6667
}
6768
req.Prompt = "Lorem ipsum"
6869
_, err = client.CreateCompletion(ctx, req)
69-
if err != nil {
70-
t.Fatalf("CreateCompletion error: %v", err)
71-
}
70+
checks.NoError(t, err, "CreateCompletion error")
7271
}
7372

7473
// handleCompletionEndpoint Handles the completion endpoint by the test server.

edits_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai_test
33
import (
44
. "github.com/sashabaranov/go-openai"
55
"github.com/sashabaranov/go-openai/internal/test"
6+
"github.com/sashabaranov/go-openai/internal/test/checks"
67

78
"context"
89
"encoding/json"
@@ -40,9 +41,7 @@ func TestEdits(t *testing.T) {
4041
N: 3,
4142
}
4243
response, err := client.Edits(ctx, editReq)
43-
if err != nil {
44-
t.Fatalf("Edits error: %v", err)
45-
}
44+
checks.NoError(t, err, "Edits error")
4645
if len(response.Choices) != editReq.N {
4746
t.Fatalf("edits does not properly return the correct number of choices")
4847
}

embeddings_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai_test
22

33
import (
44
. "github.com/sashabaranov/go-openai"
5+
"github.com/sashabaranov/go-openai/internal/test/checks"
56

67
"bytes"
78
"encoding/json"
@@ -38,9 +39,7 @@ func TestEmbedding(t *testing.T) {
3839
// marshal embeddingReq to JSON and confirm that the model field equals
3940
// the AdaSearchQuery type
4041
marshaled, err := json.Marshal(embeddingReq)
41-
if err != nil {
42-
t.Fatalf("Could not marshal embedding request: %v", err)
43-
}
42+
checks.NoError(t, err, "Could not marshal embedding request")
4443
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
4544
t.Fatalf("Expected embedding request to contain model field")
4645
}

error_accumulator_test.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"testing"
88

99
"github.com/sashabaranov/go-openai/internal/test"
10+
"github.com/sashabaranov/go-openai/internal/test/checks"
1011
)
1112

1213
var (
@@ -81,16 +82,13 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
8182
ctx := context.Background()
8283

8384
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
84-
if err != nil {
85-
t.Fatal(err)
86-
}
85+
checks.NoError(t, err)
86+
8787
stream.errAccumulator = &defaultErrorAccumulator{
8888
buffer: &failingErrorBuffer{},
8989
unmarshaler: &jsonUnmarshaler{},
9090
}
9191

9292
_, err = stream.Recv()
93-
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
94-
t.Fatalf("Did not return error when write failed: %v", err)
95-
}
93+
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
9694
}

files_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai_test
33
import (
44
. "github.com/sashabaranov/go-openai"
55
"github.com/sashabaranov/go-openai/internal/test"
6+
"github.com/sashabaranov/go-openai/internal/test/checks"
67

78
"context"
89
"encoding/json"
@@ -33,9 +34,7 @@ func TestFileUpload(t *testing.T) {
3334
Purpose: "fine-tune",
3435
}
3536
_, err = client.CreateFile(ctx, req)
36-
if err != nil {
37-
t.Fatalf("CreateFile error: %v", err)
38-
}
37+
checks.NoError(t, err, "CreateFile erro")
3938
}
4039

4140
// handleCreateFile Handles the images endpoint by the test server.

0 commit comments

Comments
 (0)