Skip to content

Commit 21eef5b

Browse files
Move form_builder into internal pkg. (#311)
* Move form_uilder into internal pkg. * Fix import of audio.go * Reorganize. * Fix import. * Fix --------- Co-authored-by: JoyShi <[email protected]>
1 parent 83d03fc commit 21eef5b

File tree

9 files changed

+96
-90
lines changed

9 files changed

+96
-90
lines changed

audio.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"fmt"
77
"net/http"
88
"os"
9+
10+
utils "github.com/sashabaranov/go-openai/internal"
911
)
1012

1113
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
@@ -72,7 +74,7 @@ func (c *Client) callAudioAPI(
7274
if err != nil {
7375
return AudioResponse{}, err
7476
}
75-
req.Header.Add("Content-Type", builder.formDataContentType())
77+
req.Header.Add("Content-Type", builder.FormDataContentType())
7678

7779
if request.HasJSONResponse() {
7880
err = c.sendRequest(req, &response)
@@ -92,55 +94,55 @@ func (r AudioRequest) HasJSONResponse() bool {
9294

9395
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
9496
// audio processing.
95-
func audioMultipartForm(request AudioRequest, b formBuilder) error {
97+
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
9698
f, err := os.Open(request.FilePath)
9799
if err != nil {
98100
return fmt.Errorf("opening audio file: %w", err)
99101
}
100102
defer f.Close()
101103

102-
err = b.createFormFile("file", f)
104+
err = b.CreateFormFile("file", f)
103105
if err != nil {
104106
return fmt.Errorf("creating form file: %w", err)
105107
}
106108

107-
err = b.writeField("model", request.Model)
109+
err = b.WriteField("model", request.Model)
108110
if err != nil {
109111
return fmt.Errorf("writing model name: %w", err)
110112
}
111113

112114
// Create a form field for the prompt (if provided)
113115
if request.Prompt != "" {
114-
err = b.writeField("prompt", request.Prompt)
116+
err = b.WriteField("prompt", request.Prompt)
115117
if err != nil {
116118
return fmt.Errorf("writing prompt: %w", err)
117119
}
118120
}
119121

120122
// Create a form field for the format (if provided)
121123
if request.Format != "" {
122-
err = b.writeField("response_format", string(request.Format))
124+
err = b.WriteField("response_format", string(request.Format))
123125
if err != nil {
124126
return fmt.Errorf("writing format: %w", err)
125127
}
126128
}
127129

128130
// Create a form field for the temperature (if provided)
129131
if request.Temperature != 0 {
130-
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
132+
err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature))
131133
if err != nil {
132134
return fmt.Errorf("writing temperature: %w", err)
133135
}
134136
}
135137

136138
// Create a form field for the language (if provided)
137139
if request.Language != "" {
138-
err = b.writeField("language", request.Language)
140+
err = b.WriteField("language", request.Language)
139141
if err != nil {
140142
return fmt.Errorf("writing language: %w", err)
141143
}
142144
}
143145

144146
// Close the multipart writer
145-
return b.close()
147+
return b.Close()
146148
}

client.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@ import (
77
"io"
88
"net/http"
99
"strings"
10+
11+
utils "github.com/sashabaranov/go-openai/internal"
1012
)
1113

1214
// Client is OpenAI GPT-3 API client.
1315
type Client struct {
1416
config ClientConfig
1517

1618
requestBuilder requestBuilder
17-
createFormBuilder func(io.Writer) formBuilder
19+
createFormBuilder func(io.Writer) utils.FormBuilder
1820
}
1921

2022
// NewClient creates new OpenAI API client.
@@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client {
2830
return &Client{
2931
config: config,
3032
requestBuilder: newRequestBuilder(),
31-
createFormBuilder: func(body io.Writer) formBuilder {
32-
return newFormBuilder(body)
33+
createFormBuilder: func(body io.Writer) utils.FormBuilder {
34+
return utils.NewFormBuilder(body)
3335
},
3436
}
3537
}

files.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
3636
var b bytes.Buffer
3737
builder := c.createFormBuilder(&b)
3838

39-
err = builder.writeField("purpose", request.Purpose)
39+
err = builder.WriteField("purpose", request.Purpose)
4040
if err != nil {
4141
return
4242
}
@@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
4646
return
4747
}
4848

49-
err = builder.createFormFile("file", fileData)
49+
err = builder.CreateFormFile("file", fileData)
5050
if err != nil {
5151
return
5252
}
5353

54-
err = builder.close()
54+
err = builder.Close()
5555
if err != nil {
5656
return
5757
}
@@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
6161
return
6262
}
6363

64-
req.Header.Set("Content-Type", builder.formDataContentType())
64+
req.Header.Set("Content-Type", builder.FormDataContentType())
6565

6666
err = c.sendRequest(req, &file)
6767

files_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package openai //nolint:testpackage // testing private field
22

33
import (
4+
. "github.com/sashabaranov/go-openai/internal"
45
"github.com/sashabaranov/go-openai/internal/test"
56
"github.com/sashabaranov/go-openai/internal/test/checks"
67

@@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
8586
config.BaseURL = ""
8687
client := NewClientWithConfig(config)
8788
mockBuilder := &mockFormBuilder{}
88-
client.createFormBuilder = func(io.Writer) formBuilder {
89+
client.createFormBuilder = func(io.Writer) FormBuilder {
8990
return mockBuilder
9091
}
9192

form_builder.go

Lines changed: 0 additions & 49 deletions
This file was deleted.

image.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
6969
builder := c.createFormBuilder(body)
7070

7171
// image
72-
err = builder.createFormFile("image", request.Image)
72+
err = builder.CreateFormFile("image", request.Image)
7373
if err != nil {
7474
return
7575
}
7676

7777
// mask, it is optional
7878
if request.Mask != nil {
79-
err = builder.createFormFile("mask", request.Mask)
79+
err = builder.CreateFormFile("mask", request.Mask)
8080
if err != nil {
8181
return
8282
}
8383
}
8484

85-
err = builder.writeField("prompt", request.Prompt)
85+
err = builder.WriteField("prompt", request.Prompt)
8686
if err != nil {
8787
return
8888
}
8989

90-
err = builder.writeField("n", strconv.Itoa(request.N))
90+
err = builder.WriteField("n", strconv.Itoa(request.N))
9191
if err != nil {
9292
return
9393
}
9494

95-
err = builder.writeField("size", request.Size)
95+
err = builder.WriteField("size", request.Size)
9696
if err != nil {
9797
return
9898
}
9999

100-
err = builder.writeField("response_format", request.ResponseFormat)
100+
err = builder.WriteField("response_format", request.ResponseFormat)
101101
if err != nil {
102102
return
103103
}
104104

105-
err = builder.close()
105+
err = builder.Close()
106106
if err != nil {
107107
return
108108
}
@@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
113113
return
114114
}
115115

116-
req.Header.Set("Content-Type", builder.formDataContentType())
116+
req.Header.Set("Content-Type", builder.FormDataContentType())
117117
err = c.sendRequest(req, &response)
118118
return
119119
}
@@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
133133
builder := c.createFormBuilder(body)
134134

135135
// image
136-
err = builder.createFormFile("image", request.Image)
136+
err = builder.CreateFormFile("image", request.Image)
137137
if err != nil {
138138
return
139139
}
140140

141-
err = builder.writeField("n", strconv.Itoa(request.N))
141+
err = builder.WriteField("n", strconv.Itoa(request.N))
142142
if err != nil {
143143
return
144144
}
145145

146-
err = builder.writeField("size", request.Size)
146+
err = builder.WriteField("size", request.Size)
147147
if err != nil {
148148
return
149149
}
150150

151-
err = builder.writeField("response_format", request.ResponseFormat)
151+
err = builder.WriteField("response_format", request.ResponseFormat)
152152
if err != nil {
153153
return
154154
}
155155

156-
err = builder.close()
156+
err = builder.Close()
157157
if err != nil {
158158
return
159159
}
@@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
165165
return
166166
}
167167

168-
req.Header.Set("Content-Type", builder.formDataContentType())
168+
req.Header.Set("Content-Type", builder.FormDataContentType())
169169
err = c.sendRequest(req, &response)
170170
return
171171
}

image_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package openai //nolint:testpackage // testing private field
22

33
import (
4+
utils "github.com/sashabaranov/go-openai/internal"
45
"github.com/sashabaranov/go-openai/internal/test"
56
"github.com/sashabaranov/go-openai/internal/test/checks"
67

@@ -268,19 +269,19 @@ type mockFormBuilder struct {
268269
mockClose func() error
269270
}
270271

271-
func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error {
272+
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
272273
return fb.mockCreateFormFile(fieldname, file)
273274
}
274275

275-
func (fb *mockFormBuilder) writeField(fieldname, value string) error {
276+
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
276277
return fb.mockWriteField(fieldname, value)
277278
}
278279

279-
func (fb *mockFormBuilder) close() error {
280+
func (fb *mockFormBuilder) Close() error {
280281
return fb.mockClose()
281282
}
282283

283-
func (fb *mockFormBuilder) formDataContentType() string {
284+
func (fb *mockFormBuilder) FormDataContentType() string {
284285
return ""
285286
}
286287

@@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) {
290291
client := NewClientWithConfig(config)
291292

292293
mockBuilder := &mockFormBuilder{}
293-
client.createFormBuilder = func(io.Writer) formBuilder {
294+
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
294295
return mockBuilder
295296
}
296297
ctx := context.Background()
@@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
357358
client := NewClientWithConfig(config)
358359

359360
mockBuilder := &mockFormBuilder{}
360-
client.createFormBuilder = func(io.Writer) formBuilder {
361+
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
361362
return mockBuilder
362363
}
363364
ctx := context.Background()

0 commit comments

Comments
 (0)