Skip to content

Commit fa694c6

Browse files
authored
Implement optional io.Reader in AudioRequest (#303) (#265) (#331)
* Implement optional io.Reader in AudioRequest (#303) (#265) * Fix err shadowing * Add test to cover AudioRequest io.Reader usage * Add additional test cases to cover AudioRequest io.Reader usage * Add test to cover opening the file specified in an AudioRequest
1 parent 61ba5f3 commit fa694c6

File tree

4 files changed

+124
-18
lines changed

4 files changed

+124
-18
lines changed

audio.go

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"fmt"
7+
"io"
78
"net/http"
89
"os"
910

@@ -27,8 +28,14 @@ const (
2728
// AudioRequest represents a request structure for audio API.
2829
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
2930
type AudioRequest struct {
30-
Model string
31-
FilePath string
31+
Model string
32+
33+
// FilePath is either an existing file in your filesystem or a filename representing the contents of Reader.
34+
FilePath string
35+
36+
// Reader is an optional io.Reader when you do not want to use an existing file.
37+
Reader io.Reader
38+
3239
Prompt string // For translation, it should be in English
3340
Temperature float32
3441
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
@@ -95,15 +102,9 @@ func (r AudioRequest) HasJSONResponse() bool {
95102
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
96103
// audio processing.
97104
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
98-
f, err := os.Open(request.FilePath)
99-
if err != nil {
100-
return fmt.Errorf("opening audio file: %w", err)
101-
}
102-
defer f.Close()
103-
104-
err = b.CreateFormFile("file", f)
105+
err := createFileField(request, b)
105106
if err != nil {
106-
return fmt.Errorf("creating form file: %w", err)
107+
return err
107108
}
108109

109110
err = b.WriteField("model", request.Model)
@@ -146,3 +147,27 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
146147
// Close the multipart writer
147148
return b.Close()
148149
}
150+
151+
// createFileField creates the "file" form field from either an existing file or by using the reader.
152+
func createFileField(request AudioRequest, b utils.FormBuilder) error {
153+
if request.Reader != nil {
154+
err := b.CreateFormFileReader("file", request.Reader, request.FilePath)
155+
if err != nil {
156+
return fmt.Errorf("creating form using reader: %w", err)
157+
}
158+
return nil
159+
}
160+
161+
f, err := os.Open(request.FilePath)
162+
if err != nil {
163+
return fmt.Errorf("opening audio file: %w", err)
164+
}
165+
defer f.Close()
166+
167+
err = b.CreateFormFile("file", f)
168+
if err != nil {
169+
return fmt.Errorf("creating form file: %w", err)
170+
}
171+
172+
return nil
173+
}

audio_test.go

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field
22

33
import (
44
"bytes"
5+
"context"
56
"errors"
67
"fmt"
78
"io"
@@ -11,12 +12,10 @@ import (
1112
"os"
1213
"path/filepath"
1314
"strings"
15+
"testing"
1416

1517
"github.com/sashabaranov/go-openai/internal/test"
1618
"github.com/sashabaranov/go-openai/internal/test/checks"
17-
18-
"context"
19-
"testing"
2019
)
2120

2221
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
@@ -65,6 +64,16 @@ func TestAudio(t *testing.T) {
6564
_, err = tc.createFn(ctx, req)
6665
checks.NoError(t, err, "audio API error")
6766
})
67+
68+
t.Run(tc.name+" (with reader)", func(t *testing.T) {
69+
req := AudioRequest{
70+
FilePath: "fake.webm",
71+
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
72+
Model: "whisper-3",
73+
}
74+
_, err = tc.createFn(ctx, req)
75+
checks.NoError(t, err, "audio API error")
76+
})
6877
}
6978
}
7079

@@ -213,3 +222,54 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
213222
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
214223
}
215224
}
225+
226+
func TestCreateFileField(t *testing.T) {
227+
t.Run("createFileField failing file", func(t *testing.T) {
228+
dir, cleanup := test.CreateTestDirectory(t)
229+
defer cleanup()
230+
path := filepath.Join(dir, "fake.mp3")
231+
test.CreateTestFile(t, path)
232+
233+
req := AudioRequest{
234+
FilePath: path,
235+
}
236+
237+
mockFailedErr := fmt.Errorf("mock form builder fail")
238+
mockBuilder := &mockFormBuilder{
239+
mockCreateFormFile: func(string, *os.File) error {
240+
return mockFailedErr
241+
},
242+
}
243+
244+
err := createFileField(req, mockBuilder)
245+
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails")
246+
})
247+
248+
t.Run("createFileField failing reader", func(t *testing.T) {
249+
req := AudioRequest{
250+
FilePath: "test.wav",
251+
Reader: bytes.NewBuffer([]byte(`wav test contents`)),
252+
}
253+
254+
mockFailedErr := fmt.Errorf("mock form builder fail")
255+
mockBuilder := &mockFormBuilder{
256+
mockCreateFormFileReader: func(string, io.Reader, string) error {
257+
return mockFailedErr
258+
},
259+
}
260+
261+
err := createFileField(req, mockBuilder)
262+
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails")
263+
})
264+
265+
t.Run("createFileField failing open", func(t *testing.T) {
266+
req := AudioRequest{
267+
FilePath: "non_existing_file.wav",
268+
}
269+
270+
mockBuilder := &mockFormBuilder{}
271+
272+
err := createFileField(req, mockBuilder)
273+
checks.HasError(t, err, "createFileField using file should return error when open file fails")
274+
})
275+
}

image_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,15 +264,20 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
264264
}
265265

266266
type mockFormBuilder struct {
267-
mockCreateFormFile func(string, *os.File) error
268-
mockWriteField func(string, string) error
269-
mockClose func() error
267+
mockCreateFormFile func(string, *os.File) error
268+
mockCreateFormFileReader func(string, io.Reader, string) error
269+
mockWriteField func(string, string) error
270+
mockClose func() error
270271
}
271272

272273
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
273274
return fb.mockCreateFormFile(fieldname, file)
274275
}
275276

277+
func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
278+
return fb.mockCreateFormFileReader(fieldname, r, filename)
279+
}
280+
276281
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
277282
return fb.mockWriteField(fieldname, value)
278283
}

internal/form_builder.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package openai
22

33
import (
4+
"fmt"
45
"io"
56
"mime/multipart"
67
"os"
8+
"path"
79
)
810

911
type FormBuilder interface {
1012
CreateFormFile(fieldname string, file *os.File) error
13+
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
1114
WriteField(fieldname, value string) error
1215
Close() error
1316
FormDataContentType() string
@@ -24,15 +27,28 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
2427
}
2528

2629
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
27-
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
30+
return fb.createFormFile(fieldname, file, file.Name())
31+
}
32+
33+
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
34+
return fb.createFormFile(fieldname, r, path.Base(filename))
35+
}
36+
37+
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
38+
if filename == "" {
39+
return fmt.Errorf("filename cannot be empty")
40+
}
41+
42+
fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
2843
if err != nil {
2944
return err
3045
}
3146

32-
_, err = io.Copy(fieldWriter, file)
47+
_, err = io.Copy(fieldWriter, r)
3348
if err != nil {
3449
return err
3550
}
51+
3652
return nil
3753
}
3854

0 commit comments

Comments
 (0)