Skip to content

Commit ecdea45

Browse files
authored
Adds support for audio captioning with Whisper (#267)
* Add speech to text example in docs * Add caption formats for audio transcription * Add caption example to README * Address sanity check errors * Add tests for decodeResponse * Use typechecker for audio response format * Decoding response refactors
1 parent d6ab1b3 commit ecdea45

File tree

5 files changed

+129
-8
lines changed

5 files changed

+129
-8
lines changed

README.md

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,47 @@ func main() {
223223
```
224224
</details>
225225

226+
<details>
227+
<summary>Audio Captions</summary>
228+
229+
```go
230+
package main
231+
232+
import (
233+
"context"
234+
"fmt"
235+
"os"
236+
237+
openai "github.com/sashabaranov/go-openai"
238+
)
239+
240+
func main() {
241+
c := openai.NewClient(os.Getenv("OPENAI_KEY"))
242+
243+
req := openai.AudioRequest{
244+
Model: openai.Whisper1,
245+
FilePath: os.Args[1],
246+
Format: openai.AudioResponseFormatSRT,
247+
}
248+
resp, err := c.CreateTranscription(context.Background(), req)
249+
if err != nil {
250+
fmt.Printf("Transcription error: %v\n", err)
251+
return
252+
}
253+
f, err := os.Create(os.Args[1] + ".srt")
254+
if err != nil {
255+
fmt.Printf("Could not open file: %v\n", err)
256+
return
257+
}
258+
defer f.Close()
259+
if _, err := f.WriteString(resp.Text); err != nil {
260+
fmt.Printf("Error writing to file: %v\n", err)
261+
return
262+
}
263+
}
264+
```
265+
</details>
266+
226267
<details>
227268
<summary>DALL-E 2 image generation</summary>
228269

@@ -420,4 +461,4 @@ func main() {
420461
fmt.Println(resp.Choices[0].Message.Content)
421462
}
422463
```
423-
</details>
464+
</details>

audio.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ const (
1313
Whisper1 = "whisper-1"
1414
)
1515

16+
// Response formats; Whisper uses AudioResponseFormatJSON by default.
17+
type AudioResponseFormat string
18+
19+
const (
20+
AudioResponseFormatJSON AudioResponseFormat = "json"
21+
AudioResponseFormatSRT AudioResponseFormat = "srt"
22+
AudioResponseFormatVTT AudioResponseFormat = "vtt"
23+
)
24+
1625
// AudioRequest represents a request structure for audio API.
1726
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
1827
type AudioRequest struct {
@@ -21,6 +30,7 @@ type AudioRequest struct {
2130
Prompt string // For translation, it should be in English
2231
Temperature float32
2332
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
33+
Format AudioResponseFormat
2434
}
2535

2636
// AudioResponse represents a response structure for audio API.
@@ -66,10 +76,19 @@ func (c *Client) callAudioAPI(
6676
}
6777
req.Header.Add("Content-Type", builder.formDataContentType())
6878

69-
err = c.sendRequest(req, &response)
79+
if request.HasJSONResponse() {
80+
err = c.sendRequest(req, &response)
81+
} else {
82+
err = c.sendRequest(req, &response.Text)
83+
}
7084
return
7185
}
7286

87+
// HasJSONResponse returns true if the response format is JSON.
88+
func (r AudioRequest) HasJSONResponse() bool {
89+
return r.Format == "" || r.Format == AudioResponseFormatJSON
90+
}
91+
7392
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
7493
// audio processing.
7594
func audioMultipartForm(request AudioRequest, b formBuilder) error {
@@ -97,6 +116,14 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
97116
}
98117
}
99118

119+
// Create a form field for the format (if provided)
120+
if request.Format != "" {
121+
err = b.writeField("response_format", string(request.Format))
122+
if err != nil {
123+
return fmt.Errorf("writing format: %w", err)
124+
}
125+
}
126+
100127
// Create a form field for the temperature (if provided)
101128
if request.Temperature != 0 {
102129
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))

audio_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
112112
Prompt: "用简体中文",
113113
Temperature: 0.5,
114114
Language: "zh",
115+
Format: AudioResponseFormatSRT,
115116
}
116117
_, err = tc.createFn(ctx, req)
117118
checks.NoError(t, err, "audio API error")
@@ -179,6 +180,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
179180
Prompt: "test",
180181
Temperature: 0.5,
181182
Language: "en",
183+
Format: AudioResponseFormatSRT,
182184
}
183185

184186
mockFailedErr := fmt.Errorf("mock form builder fail")
@@ -202,7 +204,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
202204
return nil
203205
}
204206

205-
failOn := []string{"model", "prompt", "temperature", "language"}
207+
failOn := []string{"model", "prompt", "temperature", "language", "response_format"}
206208
for _, failingField := range failOn {
207209
failForField = failingField
208210
mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField)

client.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func NewOrgClient(authToken, org string) *Client {
4343
return NewClientWithConfig(config)
4444
}
4545

46-
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
46+
func (c *Client) sendRequest(req *http.Request, v any) error {
4747
req.Header.Set("Accept", "application/json; charset=utf-8")
4848
// Azure API Key authentication
4949
if c.config.APIType == APITypeAzure {
@@ -75,12 +75,26 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
7575
return c.handleErrorResp(res)
7676
}
7777

78-
if v != nil {
79-
if err = json.NewDecoder(res.Body).Decode(v); err != nil {
80-
return err
81-
}
78+
return decodeResponse(res.Body, v)
79+
}
80+
81+
func decodeResponse(body io.Reader, v any) error {
82+
if v == nil {
83+
return nil
8284
}
8385

86+
if result, ok := v.(*string); ok {
87+
return decodeString(body, result)
88+
}
89+
return json.NewDecoder(body).Decode(v)
90+
}
91+
92+
func decodeString(body io.Reader, output *string) error {
93+
b, err := io.ReadAll(body)
94+
if err != nil {
95+
return err
96+
}
97+
*output = string(b)
8498
return nil
8599
}
86100

client_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package openai //nolint:testpackage // testing private field
22

33
import (
4+
"bytes"
5+
"io"
46
"testing"
57
)
68

@@ -20,3 +22,38 @@ func TestClient(t *testing.T) {
2022
t.Errorf("Client does not contain proper orgID")
2123
}
2224
}
25+
26+
func TestDecodeResponse(t *testing.T) {
27+
stringInput := ""
28+
29+
testCases := []struct {
30+
name string
31+
value interface{}
32+
body io.Reader
33+
}{
34+
{
35+
name: "nil input",
36+
value: nil,
37+
body: bytes.NewReader([]byte("")),
38+
},
39+
{
40+
name: "string input",
41+
value: &stringInput,
42+
body: bytes.NewReader([]byte("test")),
43+
},
44+
{
45+
name: "map input",
46+
value: &map[string]interface{}{},
47+
body: bytes.NewReader([]byte(`{"test": "test"}`)),
48+
},
49+
}
50+
51+
for _, tc := range testCases {
52+
t.Run(tc.name, func(t *testing.T) {
53+
err := decodeResponse(tc.body, tc.value)
54+
if err != nil {
55+
t.Errorf("Unexpected error: %v", err)
56+
}
57+
})
58+
}
59+
}

0 commit comments

Comments
 (0)