@@ -129,6 +129,7 @@ func TestEdits(t *testing.T) {
129
129
t .Fatalf ("edits does not properly return the correct number of choices" )
130
130
}
131
131
}
132
+
132
133
func TestEmbedding (t * testing.T ) {
133
134
embeddedModels := []EmbeddingModel {
134
135
AdaSimilarity ,
@@ -269,6 +270,41 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
269
270
fmt .Fprintln (w , string (resBytes ))
270
271
}
271
272
273
+ // handleImageEndpoint Handles the images endpoint by the test server.
274
+ func handleImageEndpoint (w http.ResponseWriter , r * http.Request ) {
275
+ var err error
276
+ var resBytes []byte
277
+
278
+ // imagess only accepts POST requests
279
+ if r .Method != "POST" {
280
+ http .Error (w , "Method not allowed" , http .StatusMethodNotAllowed )
281
+ }
282
+ var imageReq ImageRequest
283
+ if imageReq , err = getImageBody (r ); err != nil {
284
+ http .Error (w , "could not read request" , http .StatusInternalServerError )
285
+ return
286
+ }
287
+ res := ImageResponse {
288
+ Created : uint64 (time .Now ().Unix ()),
289
+ }
290
+ for i := 0 ; i < imageReq .N ; i ++ {
291
+ imageData := ImageResponseDataInner {}
292
+ switch imageReq .ResponseFormat {
293
+ case CreateImageResponseFormatURL , "" :
294
+ imageData .URL = "https://example.com/image.png"
295
+ case CreateImageResponseFormatB64JSON :
296
+ // This decodes to "{}" in base64.
297
+ imageData .B64JSON = "e30K"
298
+ default :
299
+ http .Error (w , "invalid response format" , http .StatusBadRequest )
300
+ return
301
+ }
302
+ res .Data = append (res .Data , imageData )
303
+ }
304
+ resBytes , _ = json .Marshal (res )
305
+ fmt .Fprintln (w , string (resBytes ))
306
+ }
307
+
272
308
// getCompletionBody Returns the body of the request to create a completion.
273
309
func getCompletionBody (r * http.Request ) (CompletionRequest , error ) {
274
310
completion := CompletionRequest {}
@@ -284,6 +320,21 @@ func getCompletionBody(r *http.Request) (CompletionRequest, error) {
284
320
return completion , nil
285
321
}
286
322
323
+ // getImageBody Returns the body of the request to create a image.
324
+ func getImageBody (r * http.Request ) (ImageRequest , error ) {
325
+ image := ImageRequest {}
326
+ // read the request body
327
+ reqBody , err := ioutil .ReadAll (r .Body )
328
+ if err != nil {
329
+ return ImageRequest {}, err
330
+ }
331
+ err = json .Unmarshal (reqBody , & image )
332
+ if err != nil {
333
+ return ImageRequest {}, err
334
+ }
335
+ return image , nil
336
+ }
337
+
287
338
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
288
339
// This function approximates based on the rule of thumb stated by OpenAI:
289
340
// https://beta.openai.com/tokenizer
@@ -293,6 +344,25 @@ func numTokens(s string) int {
293
344
return int (float32 (len (s )) / 4 )
294
345
}
295
346
347
+ func TestImages (t * testing.T ) {
348
+ // create the test server
349
+ var err error
350
+ ts := OpenAITestServer ()
351
+ ts .Start ()
352
+ defer ts .Close ()
353
+
354
+ client := NewClient (testAPIToken )
355
+ ctx := context .Background ()
356
+ client .BaseURL = ts .URL + "/v1"
357
+
358
+ req := ImageRequest {}
359
+ req .Prompt = "Lorem ipsum"
360
+ _ , err = client .CreateImage (ctx , req )
361
+ if err != nil {
362
+ t .Fatalf ("CreateImage error: %v" , err )
363
+ }
364
+ }
365
+
296
366
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
297
367
func OpenAITestServer () * httptest.Server {
298
368
return httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -312,6 +382,8 @@ func OpenAITestServer() *httptest.Server {
312
382
case "/v1/completions" :
313
383
handleCompletionEndpoint (w , r )
314
384
return
385
+ case "/v1/images/generations" :
386
+ handleImageEndpoint (w , r )
315
387
// TODO: implement the other endpoints
316
388
default :
317
389
// the endpoint doesn't exist
0 commit comments