Skip to content

Commit 0d3e887

Browse files
authored
Add whisper 1 support (#117)
* Add whisper 1 support * Resolve linting issues for audio source files
1 parent d668221 commit 0d3e887

File tree

2 files changed

+243
-0
lines changed

2 files changed

+243
-0
lines changed

audio.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package gogpt
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"fmt"
7+
"io"
8+
"mime/multipart"
9+
"net/http"
10+
"os"
11+
)
12+
13+
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
14+
const (
15+
Whisper1 = "whisper-1"
16+
)
17+
18+
// AudioRequest represents a request structure for audio API.
19+
type AudioRequest struct {
20+
Model string
21+
FilePath string
22+
}
23+
24+
// AudioResponse represents a response structure for audio API.
25+
type AudioResponse struct {
26+
Text string `json:"text"`
27+
}
28+
29+
// CreateTranscription — API call to create a transcription. Returns transcribed text.
30+
func (c *Client) CreateTranscription(
31+
ctx context.Context,
32+
request AudioRequest,
33+
) (response AudioResponse, err error) {
34+
response, err = c.callAudioAPI(ctx, request, "transcriptions")
35+
return
36+
}
37+
38+
// CreateTranscription — API call to create a transcription. Returns transcribed text.
39+
func (c *Client) CreateTranslation(
40+
ctx context.Context,
41+
request AudioRequest,
42+
) (response AudioResponse, err error) {
43+
response, err = c.callAudioAPI(ctx, request, "translations")
44+
return
45+
}
46+
47+
// callAudioAPI — API call to an audio endpoint.
48+
func (c *Client) callAudioAPI(
49+
ctx context.Context,
50+
request AudioRequest,
51+
endpointSuffix string,
52+
) (response AudioResponse, err error) {
53+
var formBody bytes.Buffer
54+
w := multipart.NewWriter(&formBody)
55+
56+
if err = audioMultipartForm(request, w); err != nil {
57+
return
58+
}
59+
60+
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
61+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody)
62+
if err != nil {
63+
return
64+
}
65+
req.Header.Add("Content-Type", w.FormDataContentType())
66+
67+
err = c.sendRequest(req, &response)
68+
return
69+
}
70+
71+
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
72+
// audio processing.
73+
func audioMultipartForm(request AudioRequest, w *multipart.Writer) error {
74+
f, err := os.Open(request.FilePath)
75+
if err != nil {
76+
return fmt.Errorf("opening audio file: %w", err)
77+
}
78+
79+
fw, err := w.CreateFormFile("file", f.Name())
80+
if err != nil {
81+
return fmt.Errorf("creating form file: %w", err)
82+
}
83+
84+
if _, err = io.Copy(fw, f); err != nil {
85+
return fmt.Errorf("reading from opened audio file: %w", err)
86+
}
87+
88+
fw, err = w.CreateFormField("model")
89+
if err != nil {
90+
return fmt.Errorf("creating form field: %w", err)
91+
}
92+
93+
modelName := bytes.NewReader([]byte(request.Model))
94+
if _, err = io.Copy(fw, modelName); err != nil {
95+
return fmt.Errorf("writing model name: %w", err)
96+
}
97+
w.Close()
98+
99+
return nil
100+
}

audio_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package gogpt_test
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"io"
7+
"mime"
8+
"mime/multipart"
9+
"net/http"
10+
"os"
11+
"path/filepath"
12+
"strings"
13+
14+
. "github.com/sashabaranov/go-gpt3"
15+
"github.com/sashabaranov/go-gpt3/internal/test"
16+
17+
"context"
18+
"testing"
19+
)
20+
21+
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
22+
func TestAudio(t *testing.T) {
23+
server := test.NewTestServer()
24+
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
25+
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
26+
// create the test server
27+
var err error
28+
ts := server.OpenAITestServer()
29+
ts.Start()
30+
defer ts.Close()
31+
32+
config := DefaultConfig(test.GetTestToken())
33+
config.BaseURL = ts.URL + "/v1"
34+
client := NewClientWithConfig(config)
35+
36+
testcases := []struct {
37+
name string
38+
createFn func(context.Context, AudioRequest) (AudioResponse, error)
39+
}{
40+
{
41+
"transcribe",
42+
client.CreateTranscription,
43+
},
44+
{
45+
"translate",
46+
client.CreateTranslation,
47+
},
48+
}
49+
50+
ctx := context.Background()
51+
52+
dir, cleanup := createTestDirectory(t)
53+
defer cleanup()
54+
55+
for _, tc := range testcases {
56+
t.Run(tc.name, func(t *testing.T) {
57+
path := filepath.Join(dir, "fake.mp3")
58+
createTestFile(t, path)
59+
60+
req := AudioRequest{
61+
FilePath: path,
62+
Model: "whisper-3",
63+
}
64+
_, err = tc.createFn(ctx, req)
65+
if err != nil {
66+
t.Fatalf("audio API error: %v", err)
67+
}
68+
})
69+
}
70+
}
71+
72+
// createTestFile creates a fake file with "hello" as the content.
73+
func createTestFile(t *testing.T, path string) {
74+
file, err := os.Create(path)
75+
if err != nil {
76+
t.Fatalf("failed to create file %v", err)
77+
}
78+
if _, err = file.WriteString("hello"); err != nil {
79+
t.Fatalf("failed to write to file %v", err)
80+
}
81+
file.Close()
82+
}
83+
84+
// createTestDirectory creates a temporary folder which will be deleted when cleanup is called.
85+
func createTestDirectory(t *testing.T) (path string, cleanup func()) {
86+
t.Helper()
87+
88+
path, err := os.MkdirTemp(os.TempDir(), "")
89+
if err != nil {
90+
t.Fatal(err)
91+
}
92+
93+
return path, func() { os.RemoveAll(path) }
94+
}
95+
96+
// handleAudioEndpoint Handles the completion endpoint by the test server.
97+
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
98+
var err error
99+
100+
// audio endpoints only accept POST requests
101+
if r.Method != "POST" {
102+
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
103+
}
104+
105+
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
106+
if err != nil {
107+
http.Error(w, "failed to parse media type", http.StatusBadRequest)
108+
return
109+
}
110+
111+
if !strings.HasPrefix(mediaType, "multipart") {
112+
http.Error(w, "request is not multipart", http.StatusBadRequest)
113+
}
114+
115+
boundary, ok := params["boundary"]
116+
if !ok {
117+
http.Error(w, "no boundary in params", http.StatusBadRequest)
118+
return
119+
}
120+
121+
fileData := &bytes.Buffer{}
122+
mr := multipart.NewReader(r.Body, boundary)
123+
part, err := mr.NextPart()
124+
if err != nil && errors.Is(err, io.EOF) {
125+
http.Error(w, "error accessing file", http.StatusBadRequest)
126+
return
127+
}
128+
if _, err = io.Copy(fileData, part); err != nil {
129+
http.Error(w, "failed to copy file", http.StatusInternalServerError)
130+
return
131+
}
132+
133+
if len(fileData.Bytes()) == 0 {
134+
w.WriteHeader(http.StatusInternalServerError)
135+
http.Error(w, "received empty file data", http.StatusBadRequest)
136+
return
137+
}
138+
139+
if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
140+
http.Error(w, "failed to write body", http.StatusInternalServerError)
141+
return
142+
}
143+
}

0 commit comments

Comments
 (0)