Skip to content

Commit 2c55a49

Browse files
authored
Add Image generation API (#48)
1 parent 1c20931 commit 2c55a49

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

api_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ func TestEdits(t *testing.T) {
129129
t.Fatalf("edits does not properly return the correct number of choices")
130130
}
131131
}
132+
132133
func TestEmbedding(t *testing.T) {
133134
embeddedModels := []EmbeddingModel{
134135
AdaSimilarity,
@@ -269,6 +270,41 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
269270
fmt.Fprintln(w, string(resBytes))
270271
}
271272

273+
// handleImageEndpoint Handles the images endpoint by the test server.
274+
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
275+
var err error
276+
var resBytes []byte
277+
278+
// imagess only accepts POST requests
279+
if r.Method != "POST" {
280+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
281+
}
282+
var imageReq ImageRequest
283+
if imageReq, err = getImageBody(r); err != nil {
284+
http.Error(w, "could not read request", http.StatusInternalServerError)
285+
return
286+
}
287+
res := ImageResponse{
288+
Created: uint64(time.Now().Unix()),
289+
}
290+
for i := 0; i < imageReq.N; i++ {
291+
imageData := ImageResponseDataInner{}
292+
switch imageReq.ResponseFormat {
293+
case CreateImageResponseFormatURL, "":
294+
imageData.URL = "https://example.com/image.png"
295+
case CreateImageResponseFormatB64JSON:
296+
// This decodes to "{}" in base64.
297+
imageData.B64JSON = "e30K"
298+
default:
299+
http.Error(w, "invalid response format", http.StatusBadRequest)
300+
return
301+
}
302+
res.Data = append(res.Data, imageData)
303+
}
304+
resBytes, _ = json.Marshal(res)
305+
fmt.Fprintln(w, string(resBytes))
306+
}
307+
272308
// getCompletionBody Returns the body of the request to create a completion.
273309
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
274310
completion := CompletionRequest{}
@@ -284,6 +320,21 @@ func getCompletionBody(r *http.Request) (CompletionRequest, error) {
284320
return completion, nil
285321
}
286322

323+
// getImageBody Returns the body of the request to create a image.
324+
func getImageBody(r *http.Request) (ImageRequest, error) {
325+
image := ImageRequest{}
326+
// read the request body
327+
reqBody, err := ioutil.ReadAll(r.Body)
328+
if err != nil {
329+
return ImageRequest{}, err
330+
}
331+
err = json.Unmarshal(reqBody, &image)
332+
if err != nil {
333+
return ImageRequest{}, err
334+
}
335+
return image, nil
336+
}
337+
287338
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
288339
// This function approximates based on the rule of thumb stated by OpenAI:
289340
// https://beta.openai.com/tokenizer
@@ -293,6 +344,25 @@ func numTokens(s string) int {
293344
return int(float32(len(s)) / 4)
294345
}
295346

347+
func TestImages(t *testing.T) {
348+
// create the test server
349+
var err error
350+
ts := OpenAITestServer()
351+
ts.Start()
352+
defer ts.Close()
353+
354+
client := NewClient(testAPIToken)
355+
ctx := context.Background()
356+
client.BaseURL = ts.URL + "/v1"
357+
358+
req := ImageRequest{}
359+
req.Prompt = "Lorem ipsum"
360+
_, err = client.CreateImage(ctx, req)
361+
if err != nil {
362+
t.Fatalf("CreateImage error: %v", err)
363+
}
364+
}
365+
296366
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
297367
func OpenAITestServer() *httptest.Server {
298368
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -312,6 +382,8 @@ func OpenAITestServer() *httptest.Server {
312382
case "/v1/completions":
313383
handleCompletionEndpoint(w, r)
314384
return
385+
case "/v1/images/generations":
386+
handleImageEndpoint(w, r)
315387
// TODO: implement the other endpoints
316388
default:
317389
// the endpoint doesn't exist

image.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package gogpt
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"net/http"
8+
)
9+
10+
// Image sizes defined by the OpenAI API.
11+
const (
12+
CreateImageSize256x256 = "256x256"
13+
CreateImageSize512x512 = "512x512"
14+
CreateImageSize1024x1024 = "1024x1024"
15+
)
16+
17+
const (
18+
CreateImageResponseFormatURL = "url"
19+
CreateImageResponseFormatB64JSON = "b64_json"
20+
)
21+
22+
// ImageRequest represents the request structure for the image API.
23+
type ImageRequest struct {
24+
Prompt string `json:"prompt,omitempty"`
25+
N int `json:"n,omitempty"`
26+
Size string `json:"size,omitempty"`
27+
ResponseFormat string `json:"response_format,omitempty"`
28+
User string `json:"user,omitempty"`
29+
}
30+
31+
// ImageResponse represents a response structure for image API.
32+
type ImageResponse struct {
33+
Created uint64 `json:"created,omitempty"`
34+
Data []ImageResponseDataInner `json:"data,omitempty"`
35+
}
36+
37+
// ImageResponseData represents a response data structure for image API.
38+
type ImageResponseDataInner struct {
39+
URL string `json:"url,omitempty"`
40+
B64JSON string `json:"b64_json,omitempty"`
41+
}
42+
43+
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
44+
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
45+
var reqBytes []byte
46+
reqBytes, err = json.Marshal(request)
47+
if err != nil {
48+
return
49+
}
50+
51+
urlSuffix := "/images/generations"
52+
req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
53+
if err != nil {
54+
return
55+
}
56+
57+
req = req.WithContext(ctx)
58+
err = c.sendRequest(req, &response)
59+
return
60+
}

0 commit comments

Comments
 (0)