@@ -121,6 +121,30 @@ func TestEdits(t *testing.T) {
121
121
}
122
122
}
123
123
124
+ // TestModeration Tests the moderations endpoint of the API using the mocked server.
125
+ func TestModerations (t * testing.T ) {
126
+ // create the test server
127
+ var err error
128
+ ts := OpenAITestServer ()
129
+ ts .Start ()
130
+ defer ts .Close ()
131
+
132
+ client := NewClient (testAPIToken )
133
+ ctx := context .Background ()
134
+ client .BaseURL = ts .URL + "/v1"
135
+
136
+ // create an edit request
137
+ model := "text-moderation-stable"
138
+ moderationReq := ModerationRequest {
139
+ Model : & model ,
140
+ Input : "I want to kill them." ,
141
+ }
142
+ _ , err = client .Moderations (ctx , moderationReq )
143
+ if err != nil {
144
+ t .Fatalf ("Moderation error: %v" , err )
145
+ }
146
+ }
147
+
124
148
func TestEmbedding (t * testing.T ) {
125
149
embeddedModels := []EmbeddingModel {
126
150
AdaSimilarity ,
@@ -160,6 +184,25 @@ func TestEmbedding(t *testing.T) {
160
184
}
161
185
}
162
186
187
+ func TestImages (t * testing.T ) {
188
+ // create the test server
189
+ var err error
190
+ ts := OpenAITestServer ()
191
+ ts .Start ()
192
+ defer ts .Close ()
193
+
194
+ client := NewClient (testAPIToken )
195
+ ctx := context .Background ()
196
+ client .BaseURL = ts .URL + "/v1"
197
+
198
+ req := ImageRequest {}
199
+ req .Prompt = "Lorem ipsum"
200
+ _ , err = client .CreateImage (ctx , req )
201
+ if err != nil {
202
+ t .Fatalf ("CreateImage error: %v" , err )
203
+ }
204
+ }
205
+
163
206
// getEditBody Returns the body of the request to create an edit.
164
207
func getEditBody (r * http.Request ) (EditsRequest , error ) {
165
208
edit := EditsRequest {}
@@ -261,6 +304,21 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
261
304
fmt .Fprintln (w , string (resBytes ))
262
305
}
263
306
307
+ // getCompletionBody Returns the body of the request to create a completion.
308
+ func getCompletionBody (r * http.Request ) (CompletionRequest , error ) {
309
+ completion := CompletionRequest {}
310
+ // read the request body
311
+ reqBody , err := ioutil .ReadAll (r .Body )
312
+ if err != nil {
313
+ return CompletionRequest {}, err
314
+ }
315
+ err = json .Unmarshal (reqBody , & completion )
316
+ if err != nil {
317
+ return CompletionRequest {}, err
318
+ }
319
+ return completion , nil
320
+ }
321
+
264
322
// handleImageEndpoint Handles the images endpoint by the test server.
265
323
func handleImageEndpoint (w http.ResponseWriter , r * http.Request ) {
266
324
var err error
@@ -296,34 +354,78 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
296
354
fmt .Fprintln (w , string (resBytes ))
297
355
}
298
356
299
- // getCompletionBody Returns the body of the request to create a completion .
300
- func getCompletionBody (r * http.Request ) (CompletionRequest , error ) {
301
- completion := CompletionRequest {}
357
+ // getImageBody Returns the body of the request to create a image .
358
+ func getImageBody (r * http.Request ) (ImageRequest , error ) {
359
+ image := ImageRequest {}
302
360
// read the request body
303
361
reqBody , err := ioutil .ReadAll (r .Body )
304
362
if err != nil {
305
- return CompletionRequest {}, err
363
+ return ImageRequest {}, err
306
364
}
307
- err = json .Unmarshal (reqBody , & completion )
365
+ err = json .Unmarshal (reqBody , & image )
308
366
if err != nil {
309
- return CompletionRequest {}, err
367
+ return ImageRequest {}, err
310
368
}
311
- return completion , nil
369
+ return image , nil
312
370
}
313
371
314
- // getImageBody Returns the body of the request to create a image.
315
- func getImageBody (r * http.Request ) (ImageRequest , error ) {
316
- image := ImageRequest {}
372
+ // handleModerationEndpoint Handles the moderation endpoint by the test server.
373
+ func handleModerationEndpoint (w http.ResponseWriter , r * http.Request ) {
374
+ var err error
375
+ var resBytes []byte
376
+
377
+ // completions only accepts POST requests
378
+ if r .Method != "POST" {
379
+ http .Error (w , "Method not allowed" , http .StatusMethodNotAllowed )
380
+ }
381
+ var moderationReq ModerationRequest
382
+ if moderationReq , err = getModerationBody (r ); err != nil {
383
+ http .Error (w , "could not read request" , http .StatusInternalServerError )
384
+ return
385
+ }
386
+
387
+ resCat := ResultCategories {}
388
+ resCatScore := ResultCategoryScores {}
389
+ switch {
390
+ case strings .Contains (moderationReq .Input , "kill" ):
391
+ resCat = ResultCategories {Violence : true }
392
+ resCatScore = ResultCategoryScores {Violence : 1 }
393
+ case strings .Contains (moderationReq .Input , "hate" ):
394
+ resCat = ResultCategories {Hate : true }
395
+ resCatScore = ResultCategoryScores {Hate : 1 }
396
+ case strings .Contains (moderationReq .Input , "suicide" ):
397
+ resCat = ResultCategories {SelfHarm : true }
398
+ resCatScore = ResultCategoryScores {SelfHarm : 1 }
399
+ case strings .Contains (moderationReq .Input , "porn" ):
400
+ resCat = ResultCategories {Sexual : true }
401
+ resCatScore = ResultCategoryScores {Sexual : 1 }
402
+ }
403
+
404
+ result := Result {Categories : resCat , CategoryScores : resCatScore , Flagged : true }
405
+
406
+ res := ModerationResponse {
407
+ ID : strconv .Itoa (int (time .Now ().Unix ())),
408
+ Model : * moderationReq .Model ,
409
+ }
410
+ res .Results = append (res .Results , result )
411
+
412
+ resBytes , _ = json .Marshal (res )
413
+ fmt .Fprintln (w , string (resBytes ))
414
+ }
415
+
416
+ // getModerationBody Returns the body of the request to do a moderation.
417
+ func getModerationBody (r * http.Request ) (ModerationRequest , error ) {
418
+ moderation := ModerationRequest {}
317
419
// read the request body
318
420
reqBody , err := ioutil .ReadAll (r .Body )
319
421
if err != nil {
320
- return ImageRequest {}, err
422
+ return ModerationRequest {}, err
321
423
}
322
- err = json .Unmarshal (reqBody , & image )
424
+ err = json .Unmarshal (reqBody , & moderation )
323
425
if err != nil {
324
- return ImageRequest {}, err
426
+ return ModerationRequest {}, err
325
427
}
326
- return image , nil
428
+ return moderation , nil
327
429
}
328
430
329
431
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
@@ -335,25 +437,6 @@ func numTokens(s string) int {
335
437
return int (float32 (len (s )) / 4 )
336
438
}
337
439
338
- func TestImages (t * testing.T ) {
339
- // create the test server
340
- var err error
341
- ts := OpenAITestServer ()
342
- ts .Start ()
343
- defer ts .Close ()
344
-
345
- client := NewClient (testAPIToken )
346
- ctx := context .Background ()
347
- client .BaseURL = ts .URL + "/v1"
348
-
349
- req := ImageRequest {}
350
- req .Prompt = "Lorem ipsum"
351
- _ , err = client .CreateImage (ctx , req )
352
- if err != nil {
353
- t .Fatalf ("CreateImage error: %v" , err )
354
- }
355
- }
356
-
357
440
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
358
441
func OpenAITestServer () * httptest.Server {
359
442
return httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -373,6 +456,8 @@ func OpenAITestServer() *httptest.Server {
373
456
case "/v1/completions" :
374
457
handleCompletionEndpoint (w , r )
375
458
return
459
+ case "/v1/moderations" :
460
+ handleModerationEndpoint (w , r )
376
461
case "/v1/images/generations" :
377
462
handleImageEndpoint (w , r )
378
463
// TODO: implement the other endpoints
0 commit comments