Skip to content

Commit 53d195c

Browse files
authored
add testable json marshaller (#161)
1 parent ba77a64 commit 53d195c

File tree

10 files changed

+102
-18
lines changed

10 files changed

+102
-18
lines changed

api.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,22 @@ import (
1111
// Client is OpenAI GPT-3 API client.
1212
type Client struct {
1313
config ClientConfig
14+
15+
marshaller marshaller
1416
}
1517

1618
// NewClient creates new OpenAI API client.
1719
func NewClient(authToken string) *Client {
1820
config := DefaultConfig(authToken)
19-
return &Client{config}
21+
return NewClientWithConfig(config)
2022
}
2123

2224
// NewClientWithConfig creates new OpenAI API client for specified config.
2325
func NewClientWithConfig(config ClientConfig) *Client {
24-
return &Client{config}
26+
return &Client{
27+
config: config,
28+
marshaller: &jsonMarshaller{},
29+
}
2530
}
2631

2732
// NewOrgClient creates new OpenAI API client for specified Organization ID.
@@ -30,7 +35,7 @@ func NewClientWithConfig(config ClientConfig) *Client {
3035
func NewOrgClient(authToken, org string) *Client {
3136
config := DefaultConfig(authToken)
3237
config.OrgID = org
33-
return &Client{config}
38+
return NewClientWithConfig(config)
3439
}
3540

3641
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
@@ -90,7 +95,7 @@ func (c *Client) newStreamRequest(
9095
var reqBody []byte
9196
if body != nil {
9297
var err error
93-
reqBody, err = json.Marshal(body)
98+
reqBody, err = c.marshaller.marshal(body)
9499
if err != nil {
95100
return nil, err
96101
}

chat.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"errors"
87
"net/http"
98
)
@@ -74,7 +73,7 @@ func (c *Client) CreateChatCompletion(
7473
}
7574

7675
var reqBytes []byte
77-
reqBytes, err = json.Marshal(request)
76+
reqBytes, err = c.marshaller.marshal(request)
7877
if err != nil {
7978
return
8079
}

completion.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"errors"
87
"net/http"
98
)
@@ -107,7 +106,7 @@ func (c *Client) CreateCompletion(
107106
}
108107

109108
var reqBytes []byte
110-
reqBytes, err = json.Marshal(request)
109+
reqBytes, err = c.marshaller.marshal(request)
111110
if err != nil {
112111
return
113112
}

edits.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"net/http"
87
)
98

@@ -34,7 +33,7 @@ type EditsResponse struct {
3433
// Perform an API call to the Edits endpoint.
3534
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
3635
var reqBytes []byte
37-
reqBytes, err = json.Marshal(request)
36+
reqBytes, err = c.marshaller.marshal(request)
3837
if err != nil {
3938
return
4039
}

embeddings.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"net/http"
87
)
98

@@ -135,7 +134,7 @@ type EmbeddingRequest struct {
135134
// https://beta.openai.com/docs/api-reference/embeddings/create
136135
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
137136
var reqBytes []byte
138-
reqBytes, err = json.Marshal(request)
137+
reqBytes, err = c.marshaller.marshal(request)
139138
if err != nil {
140139
return
141140
}

fine_tunes.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"fmt"
87
"net/http"
98
)
@@ -70,7 +69,7 @@ type FineTuneDeleteResponse struct {
7069

7170
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
7271
var reqBytes []byte
73-
reqBytes, err = json.Marshal(request)
72+
reqBytes, err = c.marshaller.marshal(request)
7473
if err != nil {
7574
return
7675
}

image.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"io"
87
"mime/multipart"
98
"net/http"
@@ -47,7 +46,7 @@ type ImageResponseDataInner struct {
4746
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
4847
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
4948
var reqBytes []byte
50-
reqBytes, err = json.Marshal(request)
49+
reqBytes, err = c.marshaller.marshal(request)
5150
if err != nil {
5251
return
5352
}

marshaller.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package openai
2+
3+
import (
4+
"encoding/json"
5+
)
6+
7+
type marshaller interface {
8+
marshal(value any) ([]byte, error)
9+
}
10+
11+
type jsonMarshaller struct{}
12+
13+
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
14+
return json.Marshal(value)
15+
}

marshaller_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package openai //nolint:testpackage // testing private field
2+
3+
import (
4+
"github.com/sashabaranov/go-openai/internal/test"
5+
6+
"context"
7+
"errors"
8+
"testing"
9+
)
10+
11+
type failingMarshaller struct{}
12+
13+
var errTestMarshallerFailed = errors.New("test marshaller failed")
14+
15+
func (jm *failingMarshaller) marshal(value any) ([]byte, error) {
16+
return []byte{}, errTestMarshallerFailed
17+
}
18+
19+
func TestClientReturnMarshallerErrors(t *testing.T) {
20+
var err error
21+
ts := test.NewTestServer().OpenAITestServer()
22+
ts.Start()
23+
defer ts.Close()
24+
25+
config := DefaultConfig(test.GetTestToken())
26+
config.BaseURL = ts.URL + "/v1"
27+
client := NewClientWithConfig(config)
28+
client.marshaller = &failingMarshaller{}
29+
30+
ctx := context.Background()
31+
32+
_, err = client.CreateCompletion(ctx, CompletionRequest{})
33+
if !errors.Is(err, errTestMarshallerFailed) {
34+
t.Fatalf("Did not return error when marshaller failed: %v", err)
35+
}
36+
37+
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
38+
if !errors.Is(err, errTestMarshallerFailed) {
39+
t.Fatalf("Did not return error when marshaller failed: %v", err)
40+
}
41+
42+
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
43+
if !errors.Is(err, errTestMarshallerFailed) {
44+
t.Fatalf("Did not return error when marshaller failed: %v", err)
45+
}
46+
47+
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
48+
if !errors.Is(err, errTestMarshallerFailed) {
49+
t.Fatalf("Did not return error when marshaller failed: %v", err)
50+
}
51+
52+
_, err = client.Moderations(ctx, ModerationRequest{})
53+
if !errors.Is(err, errTestMarshallerFailed) {
54+
t.Fatalf("Did not return error when marshaller failed: %v", err)
55+
}
56+
57+
_, err = client.Edits(ctx, EditsRequest{})
58+
if !errors.Is(err, errTestMarshallerFailed) {
59+
t.Fatalf("Did not return error when marshaller failed: %v", err)
60+
}
61+
62+
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
63+
if !errors.Is(err, errTestMarshallerFailed) {
64+
t.Fatalf("Did not return error when marshaller failed: %v", err)
65+
}
66+
67+
_, err = client.CreateImage(ctx, ImageRequest{})
68+
if !errors.Is(err, errTestMarshallerFailed) {
69+
t.Fatalf("Did not return error when marshaller failed: %v", err)
70+
}
71+
}

moderation.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"net/http"
87
)
98

@@ -53,7 +52,7 @@ type ModerationResponse struct {
5352
// Input can be an array or slice but a string will reduce the complexity.
5453
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
5554
var reqBytes []byte
56-
reqBytes, err = json.Marshal(request)
55+
reqBytes, err = c.marshaller.marshal(request)
5756
if err != nil {
5857
return
5958
}

0 commit comments

Comments
 (0)