Skip to content

Commit f1b6696

Browse files
authored
refactor: refactoring http request creation and sending (#395)
* refactoring http request creation and sending * fix lint error * increase the test coverage of client.go * refactor: Change the style of HTTPRequestBuilder.Build func to one-argument-per-line.
1 parent 157de06 commit f1b6696

20 files changed

+209
-126
lines changed

api_internal_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) {
9494
az.OrgID = c.OrgID
9595

9696
cli := NewClientWithConfig(az)
97-
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "")
97+
req, err := cli.newRequest(context.Background(), "POST", "/chat/completions")
9898
if err != nil {
9999
t.Errorf("Failed to create request: %v", err)
100100
}

audio.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ func (c *Client) callAudioAPI(
9595
}
9696

9797
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
98-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody)
98+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
99+
withBody(&formBody), withContentType(builder.FormDataContentType()))
99100
if err != nil {
100101
return AudioResponse{}, err
101102
}
102-
req.Header.Add("Content-Type", builder.FormDataContentType())
103103

104104
if request.HasJSONResponse() {
105105
err = c.sendRequest(req, &response)

chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ func (c *Client) CreateChatCompletion(
152152
return
153153
}
154154

155-
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
155+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
156156
if err != nil {
157157
return
158158
}

chat_stream.go

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package openai
22

33
import (
4-
"bufio"
54
"context"
6-
7-
utils "github.com/sashabaranov/go-openai/internal"
5+
"net/http"
86
)
97

108
type ChatCompletionStreamChoiceDelta struct {
@@ -48,27 +46,17 @@ func (c *Client) CreateChatCompletionStream(
4846
}
4947

5048
request.Stream = true
51-
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
49+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
5250
if err != nil {
53-
return
51+
return nil, err
5452
}
5553

56-
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
54+
resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req)
5755
if err != nil {
5856
return
5957
}
60-
if isFailureStatusCode(resp) {
61-
return nil, c.handleErrorResp(resp)
62-
}
63-
6458
stream = &ChatCompletionStream{
65-
streamReader: &streamReader[ChatCompletionStreamResponse]{
66-
emptyMessagesLimit: c.config.EmptyMessagesLimit,
67-
reader: bufio.NewReader(resp.Body),
68-
response: resp,
69-
errAccumulator: utils.NewErrorAccumulator(),
70-
unmarshaler: &utils.JSONUnmarshaler{},
71-
},
59+
streamReader: resp,
7260
}
7361
return
7462
}

client.go

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package openai
22

33
import (
4+
"bufio"
45
"context"
56
"encoding/json"
67
"fmt"
@@ -45,6 +46,42 @@ func NewOrgClient(authToken, org string) *Client {
4546
return NewClientWithConfig(config)
4647
}
4748

49+
type requestOptions struct {
50+
body any
51+
header http.Header
52+
}
53+
54+
type requestOption func(*requestOptions)
55+
56+
func withBody(body any) requestOption {
57+
return func(args *requestOptions) {
58+
args.body = body
59+
}
60+
}
61+
62+
func withContentType(contentType string) requestOption {
63+
return func(args *requestOptions) {
64+
args.header.Set("Content-Type", contentType)
65+
}
66+
}
67+
68+
func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) {
69+
// Default Options
70+
args := &requestOptions{
71+
body: nil,
72+
header: make(http.Header),
73+
}
74+
for _, setter := range setters {
75+
setter(args)
76+
}
77+
req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
78+
if err != nil {
79+
return nil, err
80+
}
81+
c.setCommonHeaders(req)
82+
return req, nil
83+
}
84+
4885
func (c *Client) sendRequest(req *http.Request, v any) error {
4986
req.Header.Set("Accept", "application/json; charset=utf-8")
5087

@@ -55,8 +92,6 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
5592
req.Header.Set("Content-Type", "application/json; charset=utf-8")
5693
}
5794

58-
c.setCommonHeaders(req)
59-
6095
res, err := c.config.HTTPClient.Do(req)
6196
if err != nil {
6297
return err
@@ -71,6 +106,41 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
71106
return decodeResponse(res.Body, v)
72107
}
73108

109+
func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
110+
resp, err := c.config.HTTPClient.Do(req)
111+
if err != nil {
112+
return
113+
}
114+
115+
if isFailureStatusCode(resp) {
116+
err = c.handleErrorResp(resp)
117+
return
118+
}
119+
return resp.Body, nil
120+
}
121+
122+
func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
123+
req.Header.Set("Content-Type", "application/json")
124+
req.Header.Set("Accept", "text/event-stream")
125+
req.Header.Set("Cache-Control", "no-cache")
126+
req.Header.Set("Connection", "keep-alive")
127+
128+
resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
129+
if err != nil {
130+
return new(streamReader[T]), err
131+
}
132+
if isFailureStatusCode(resp) {
133+
return new(streamReader[T]), client.handleErrorResp(resp)
134+
}
135+
return &streamReader[T]{
136+
emptyMessagesLimit: client.config.EmptyMessagesLimit,
137+
reader: bufio.NewReader(resp.Body),
138+
response: resp,
139+
errAccumulator: utils.NewErrorAccumulator(),
140+
unmarshaler: &utils.JSONUnmarshaler{},
141+
}, nil
142+
}
143+
74144
func (c *Client) setCommonHeaders(req *http.Request) {
75145
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
76146
// Azure API Key authentication
@@ -138,26 +208,6 @@ func (c *Client) fullURL(suffix string, args ...any) string {
138208
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
139209
}
140210

141-
func (c *Client) newStreamRequest(
142-
ctx context.Context,
143-
method string,
144-
urlSuffix string,
145-
body any,
146-
model string) (*http.Request, error) {
147-
req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body)
148-
if err != nil {
149-
return nil, err
150-
}
151-
152-
req.Header.Set("Content-Type", "application/json")
153-
req.Header.Set("Accept", "text/event-stream")
154-
req.Header.Set("Cache-Control", "no-cache")
155-
req.Header.Set("Connection", "keep-alive")
156-
157-
c.setCommonHeaders(req)
158-
return req, nil
159-
}
160-
161211
func (c *Client) handleErrorResp(resp *http.Response) error {
162212
var errRes ErrorResponse
163213
err := json.NewDecoder(resp.Body).Decode(&errRes)

client_test.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed")
1616

1717
type failingRequestBuilder struct{}
1818

19-
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
19+
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) {
2020
return nil, errTestRequestBuilderFailed
2121
}
2222

@@ -41,9 +41,10 @@ func TestDecodeResponse(t *testing.T) {
4141
stringInput := ""
4242

4343
testCases := []struct {
44-
name string
45-
value interface{}
46-
body io.Reader
44+
name string
45+
value interface{}
46+
body io.Reader
47+
hasError bool
4748
}{
4849
{
4950
name: "nil input",
@@ -60,18 +61,32 @@ func TestDecodeResponse(t *testing.T) {
6061
value: &map[string]interface{}{},
6162
body: bytes.NewReader([]byte(`{"test": "test"}`)),
6263
},
64+
{
65+
name: "reader return error",
66+
value: &stringInput,
67+
body: &errorReader{err: errors.New("dummy")},
68+
hasError: true,
69+
},
6370
}
6471

6572
for _, tc := range testCases {
6673
t.Run(tc.name, func(t *testing.T) {
6774
err := decodeResponse(tc.body, tc.value)
68-
if err != nil {
75+
if (err != nil) != tc.hasError {
6976
t.Errorf("Unexpected error: %v", err)
7077
}
7178
})
7279
}
7380
}
7481

82+
type errorReader struct {
83+
err error
84+
}
85+
86+
func (e *errorReader) Read(_ []byte) (n int, err error) {
87+
return 0, e.err
88+
}
89+
7590
func TestHandleErrorResp(t *testing.T) {
7691
// var errRes *ErrorResponse
7792
var errRes ErrorResponse

completion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func (c *Client) CreateCompletion(
165165
return
166166
}
167167

168-
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
168+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
169169
if err != nil {
170170
return
171171
}

edits.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type EditsResponse struct {
3232

3333
// Perform an API call to the Edits endpoint.
3434
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
35-
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
35+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request))
3636
if err != nil {
3737
return
3838
}

embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ type EmbeddingRequest struct {
132132
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
133133
// https://beta.openai.com/docs/api-reference/embeddings/create
134134
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
135-
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
135+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request))
136136
if err != nil {
137137
return
138138
}

engines.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type EnginesList struct {
2222
// ListEngines Lists the currently available engines, and provides basic
2323
// information about each option such as the owner and availability.
2424
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
25-
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
25+
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines"))
2626
if err != nil {
2727
return
2828
}
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
3838
engineID string,
3939
) (engine Engine, err error) {
4040
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
41-
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
41+
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
4242
if err != nil {
4343
return
4444
}

0 commit comments

Comments
 (0)