Skip to content

Commit 226ff32

Browse files
authored
Add form builder (#235)
* add form builder * cover VariImage * test for closing errors * simplify tests * add audio api test coverage * don't leak authToken when printed * rename api->client * fix test
1 parent 2f3700f commit 226ff32

File tree

8 files changed

+270
-70
lines changed

8 files changed

+270
-70
lines changed

audio.go

Lines changed: 10 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"bytes"
55
"context"
66
"fmt"
7-
"io"
8-
"mime/multipart"
97
"net/http"
108
"os"
119
)
@@ -55,9 +53,9 @@ func (c *Client) callAudioAPI(
5553
endpointSuffix string,
5654
) (response AudioResponse, err error) {
5755
var formBody bytes.Buffer
58-
w := multipart.NewWriter(&formBody)
56+
builder := c.createFormBuilder(&formBody)
5957

60-
if err = audioMultipartForm(request, w); err != nil {
58+
if err = audioMultipartForm(request, builder); err != nil {
6159
return
6260
}
6361

@@ -66,81 +64,55 @@ func (c *Client) callAudioAPI(
6664
if err != nil {
6765
return
6866
}
69-
req.Header.Add("Content-Type", w.FormDataContentType())
67+
req.Header.Add("Content-Type", builder.formDataContentType())
7068

7169
err = c.sendRequest(req, &response)
7270
return
7371
}
7472

7573
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
7674
// audio processing.
77-
func audioMultipartForm(request AudioRequest, w *multipart.Writer) error {
75+
func audioMultipartForm(request AudioRequest, b formBuilder) error {
7876
f, err := os.Open(request.FilePath)
7977
if err != nil {
8078
return fmt.Errorf("opening audio file: %w", err)
8179
}
8280
defer f.Close()
8381

84-
fw, err := w.CreateFormFile("file", f.Name())
82+
err = b.createFormFile("file", f)
8583
if err != nil {
8684
return fmt.Errorf("creating form file: %w", err)
8785
}
8886

89-
if _, err = io.Copy(fw, f); err != nil {
90-
return fmt.Errorf("reading from opened audio file: %w", err)
91-
}
92-
93-
fw, err = w.CreateFormField("model")
87+
err = b.writeField("model", request.Model)
9488
if err != nil {
95-
return fmt.Errorf("creating form field: %w", err)
96-
}
97-
98-
modelName := bytes.NewReader([]byte(request.Model))
99-
if _, err = io.Copy(fw, modelName); err != nil {
10089
return fmt.Errorf("writing model name: %w", err)
10190
}
10291

10392
// Create a form field for the prompt (if provided)
10493
if request.Prompt != "" {
105-
fw, err = w.CreateFormField("prompt")
94+
err = b.writeField("prompt", request.Prompt)
10695
if err != nil {
107-
return fmt.Errorf("creating form field: %w", err)
108-
}
109-
110-
prompt := bytes.NewReader([]byte(request.Prompt))
111-
if _, err = io.Copy(fw, prompt); err != nil {
11296
return fmt.Errorf("writing prompt: %w", err)
11397
}
11498
}
11599

116100
// Create a form field for the temperature (if provided)
117101
if request.Temperature != 0 {
118-
fw, err = w.CreateFormField("temperature")
102+
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
119103
if err != nil {
120-
return fmt.Errorf("creating form field: %w", err)
121-
}
122-
123-
temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature)))
124-
if _, err = io.Copy(fw, temperature); err != nil {
125104
return fmt.Errorf("writing temperature: %w", err)
126105
}
127106
}
128107

129108
// Create a form field for the language (if provided)
130109
if request.Language != "" {
131-
fw, err = w.CreateFormField("language")
110+
err = b.writeField("language", request.Language)
132111
if err != nil {
133-
return fmt.Errorf("creating form field: %w", err)
134-
}
135-
136-
language := bytes.NewReader([]byte(request.Language))
137-
if _, err = io.Copy(fw, language); err != nil {
138112
return fmt.Errorf("writing language: %w", err)
139113
}
140114
}
141115

142116
// Close the multipart writer
143-
w.Close()
144-
145-
return nil
117+
return b.close()
146118
}

audio_test.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
package openai_test
1+
package openai //nolint:testpackage // testing private field
22

33
import (
44
"bytes"
55
"errors"
6+
"fmt"
67
"io"
78
"mime"
89
"mime/multipart"
@@ -11,7 +12,6 @@ import (
1112
"path/filepath"
1213
"strings"
1314

14-
. "github.com/sashabaranov/go-openai"
1515
"github.com/sashabaranov/go-openai/internal/test"
1616
"github.com/sashabaranov/go-openai/internal/test/checks"
1717

@@ -188,3 +188,47 @@ func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
188188
return
189189
}
190190
}
191+
192+
func TestAudioWithFailingFormBuilder(t *testing.T) {
193+
dir, cleanup := createTestDirectory(t)
194+
defer cleanup()
195+
path := filepath.Join(dir, "fake.mp3")
196+
createTestFile(t, path)
197+
198+
req := AudioRequest{
199+
FilePath: path,
200+
Prompt: "test",
201+
Temperature: 0.5,
202+
Language: "en",
203+
}
204+
205+
mockFailedErr := fmt.Errorf("mock form builder fail")
206+
mockBuilder := &mockFormBuilder{}
207+
208+
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
209+
return mockFailedErr
210+
}
211+
err := audioMultipartForm(req, mockBuilder)
212+
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
213+
214+
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
215+
return nil
216+
}
217+
218+
var failForField string
219+
mockBuilder.mockWriteField = func(fieldname, value string) error {
220+
if fieldname == failForField {
221+
return mockFailedErr
222+
}
223+
return nil
224+
}
225+
226+
failOn := []string{"model", "prompt", "temperature", "language"}
227+
for _, failingField := range failOn {
228+
failForField = failingField
229+
mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField)
230+
231+
err = audioMultipartForm(req, mockBuilder)
232+
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
233+
}
234+
}

api.go renamed to client.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"io"
78
"net/http"
89
"strings"
910
)
@@ -12,7 +13,8 @@ import (
1213
type Client struct {
1314
config ClientConfig
1415

15-
requestBuilder requestBuilder
16+
requestBuilder requestBuilder
17+
createFormBuilder func(io.Writer) formBuilder
1618
}
1719

1820
// NewClient creates new OpenAI API client.
@@ -26,6 +28,9 @@ func NewClientWithConfig(config ClientConfig) *Client {
2628
return &Client{
2729
config: config,
2830
requestBuilder: newRequestBuilder(),
31+
createFormBuilder: func(body io.Writer) formBuilder {
32+
return newFormBuilder(body)
33+
},
2934
}
3035
}
3136

config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,7 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
6464
EmptyMessagesLimit: defaultEmptyMessagesLimit,
6565
}
6666
}
67+
68+
func (ClientConfig) String() string {
69+
return "<OpenAI API ClientConfig>"
70+
}

files_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func TestFileUpload(t *testing.T) {
3030

3131
req := FileRequest{
3232
FileName: "test.go",
33-
FilePath: "api.go",
33+
FilePath: "client.go",
3434
Purpose: "fine-tune",
3535
}
3636
_, err = client.CreateFile(ctx, req)

form_builder.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package openai
2+
3+
import (
4+
"io"
5+
"mime/multipart"
6+
"os"
7+
)
8+
9+
type formBuilder interface {
10+
createFormFile(fieldname string, file *os.File) error
11+
writeField(fieldname, value string) error
12+
close() error
13+
formDataContentType() string
14+
}
15+
16+
type defaultFormBuilder struct {
17+
writer *multipart.Writer
18+
}
19+
20+
func newFormBuilder(body io.Writer) *defaultFormBuilder {
21+
return &defaultFormBuilder{
22+
writer: multipart.NewWriter(body),
23+
}
24+
}
25+
26+
func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error {
27+
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
28+
if err != nil {
29+
return err
30+
}
31+
32+
_, err = io.Copy(fieldWriter, file)
33+
if err != nil {
34+
return err
35+
}
36+
return nil
37+
}
38+
39+
func (fb *defaultFormBuilder) writeField(fieldname, value string) error {
40+
return fb.writer.WriteField(fieldname, value)
41+
}
42+
43+
func (fb *defaultFormBuilder) close() error {
44+
return fb.writer.Close()
45+
}
46+
47+
func (fb *defaultFormBuilder) formDataContentType() string {
48+
return fb.writer.FormDataContentType()
49+
}

image.go

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"io"
7-
"mime/multipart"
86
"net/http"
97
"os"
108
"strconv"
@@ -67,50 +65,46 @@ type ImageEditRequest struct {
6765
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
6866
func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) (response ImageResponse, err error) {
6967
body := &bytes.Buffer{}
70-
writer := multipart.NewWriter(body)
68+
builder := c.createFormBuilder(body)
7169

7270
// image
73-
image, err := writer.CreateFormFile("image", request.Image.Name())
74-
if err != nil {
75-
return
76-
}
77-
_, err = io.Copy(image, request.Image)
71+
err = builder.createFormFile("image", request.Image)
7872
if err != nil {
7973
return
8074
}
8175

8276
// mask, it is optional
8377
if request.Mask != nil {
84-
mask, err2 := writer.CreateFormFile("mask", request.Mask.Name())
85-
if err2 != nil {
86-
return
87-
}
88-
_, err = io.Copy(mask, request.Mask)
78+
err = builder.createFormFile("mask", request.Mask)
8979
if err != nil {
9080
return
9181
}
9282
}
9383

94-
err = writer.WriteField("prompt", request.Prompt)
84+
err = builder.writeField("prompt", request.Prompt)
85+
if err != nil {
86+
return
87+
}
88+
err = builder.writeField("n", strconv.Itoa(request.N))
9589
if err != nil {
9690
return
9791
}
98-
err = writer.WriteField("n", strconv.Itoa(request.N))
92+
err = builder.writeField("size", request.Size)
9993
if err != nil {
10094
return
10195
}
102-
err = writer.WriteField("size", request.Size)
96+
err = builder.close()
10397
if err != nil {
10498
return
10599
}
106-
writer.Close()
100+
107101
urlSuffix := "/images/edits"
108102
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
109103
if err != nil {
110104
return
111105
}
112106

113-
req.Header.Set("Content-Type", writer.FormDataContentType())
107+
req.Header.Set("Content-Type", builder.formDataContentType())
114108
err = c.sendRequest(req, &response)
115109
return
116110
}
@@ -126,35 +120,35 @@ type ImageVariRequest struct {
126120
// Use abbreviations(vari for variation) because ci-lint has a single-line length limit ...
127121
func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) {
128122
body := &bytes.Buffer{}
129-
writer := multipart.NewWriter(body)
123+
builder := c.createFormBuilder(body)
130124

131125
// image
132-
image, err := writer.CreateFormFile("image", request.Image.Name())
126+
err = builder.createFormFile("image", request.Image)
133127
if err != nil {
134128
return
135129
}
136-
_, err = io.Copy(image, request.Image)
130+
131+
err = builder.writeField("n", strconv.Itoa(request.N))
137132
if err != nil {
138133
return
139134
}
140-
141-
err = writer.WriteField("n", strconv.Itoa(request.N))
135+
err = builder.writeField("size", request.Size)
142136
if err != nil {
143137
return
144138
}
145-
err = writer.WriteField("size", request.Size)
139+
err = builder.close()
146140
if err != nil {
147141
return
148142
}
149-
writer.Close()
143+
150144
//https://platform.openai.com/docs/api-reference/images/create-variation
151145
urlSuffix := "/images/variations"
152146
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
153147
if err != nil {
154148
return
155149
}
156150

157-
req.Header.Set("Content-Type", writer.FormDataContentType())
151+
req.Header.Set("Content-Type", builder.formDataContentType())
158152
err = c.sendRequest(req, &response)
159153
return
160154
}

0 commit comments

Comments
 (0)