Skip to content

Commit 794a551

Browse files
sm2642smehta12
andauthored
-Added moderation endpoint test (#56)
-Rearrange some code Co-authored-by: Shalin <[email protected]>
1 parent 81b5788 commit 794a551

File tree

1 file changed

+118
-33
lines changed

1 file changed

+118
-33
lines changed

api_test.go

Lines changed: 118 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,30 @@ func TestEdits(t *testing.T) {
121121
}
122122
}
123123

124+
// TestModeration Tests the moderations endpoint of the API using the mocked server.
125+
func TestModerations(t *testing.T) {
126+
// create the test server
127+
var err error
128+
ts := OpenAITestServer()
129+
ts.Start()
130+
defer ts.Close()
131+
132+
client := NewClient(testAPIToken)
133+
ctx := context.Background()
134+
client.BaseURL = ts.URL + "/v1"
135+
136+
// create an edit request
137+
model := "text-moderation-stable"
138+
moderationReq := ModerationRequest{
139+
Model: &model,
140+
Input: "I want to kill them.",
141+
}
142+
_, err = client.Moderations(ctx, moderationReq)
143+
if err != nil {
144+
t.Fatalf("Moderation error: %v", err)
145+
}
146+
}
147+
124148
func TestEmbedding(t *testing.T) {
125149
embeddedModels := []EmbeddingModel{
126150
AdaSimilarity,
@@ -160,6 +184,25 @@ func TestEmbedding(t *testing.T) {
160184
}
161185
}
162186

187+
func TestImages(t *testing.T) {
188+
// create the test server
189+
var err error
190+
ts := OpenAITestServer()
191+
ts.Start()
192+
defer ts.Close()
193+
194+
client := NewClient(testAPIToken)
195+
ctx := context.Background()
196+
client.BaseURL = ts.URL + "/v1"
197+
198+
req := ImageRequest{}
199+
req.Prompt = "Lorem ipsum"
200+
_, err = client.CreateImage(ctx, req)
201+
if err != nil {
202+
t.Fatalf("CreateImage error: %v", err)
203+
}
204+
}
205+
163206
// getEditBody Returns the body of the request to create an edit.
164207
func getEditBody(r *http.Request) (EditsRequest, error) {
165208
edit := EditsRequest{}
@@ -261,6 +304,21 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
261304
fmt.Fprintln(w, string(resBytes))
262305
}
263306

307+
// getCompletionBody Returns the body of the request to create a completion.
308+
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
309+
completion := CompletionRequest{}
310+
// read the request body
311+
reqBody, err := ioutil.ReadAll(r.Body)
312+
if err != nil {
313+
return CompletionRequest{}, err
314+
}
315+
err = json.Unmarshal(reqBody, &completion)
316+
if err != nil {
317+
return CompletionRequest{}, err
318+
}
319+
return completion, nil
320+
}
321+
264322
// handleImageEndpoint Handles the images endpoint by the test server.
265323
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
266324
var err error
@@ -296,34 +354,78 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
296354
fmt.Fprintln(w, string(resBytes))
297355
}
298356

299-
// getCompletionBody Returns the body of the request to create a completion.
300-
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
301-
completion := CompletionRequest{}
357+
// getImageBody Returns the body of the request to create a image.
358+
func getImageBody(r *http.Request) (ImageRequest, error) {
359+
image := ImageRequest{}
302360
// read the request body
303361
reqBody, err := ioutil.ReadAll(r.Body)
304362
if err != nil {
305-
return CompletionRequest{}, err
363+
return ImageRequest{}, err
306364
}
307-
err = json.Unmarshal(reqBody, &completion)
365+
err = json.Unmarshal(reqBody, &image)
308366
if err != nil {
309-
return CompletionRequest{}, err
367+
return ImageRequest{}, err
310368
}
311-
return completion, nil
369+
return image, nil
312370
}
313371

314-
// getImageBody Returns the body of the request to create a image.
315-
func getImageBody(r *http.Request) (ImageRequest, error) {
316-
image := ImageRequest{}
372+
// handleModerationEndpoint Handles the moderation endpoint by the test server.
373+
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
374+
var err error
375+
var resBytes []byte
376+
377+
// completions only accepts POST requests
378+
if r.Method != "POST" {
379+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
380+
}
381+
var moderationReq ModerationRequest
382+
if moderationReq, err = getModerationBody(r); err != nil {
383+
http.Error(w, "could not read request", http.StatusInternalServerError)
384+
return
385+
}
386+
387+
resCat := ResultCategories{}
388+
resCatScore := ResultCategoryScores{}
389+
switch {
390+
case strings.Contains(moderationReq.Input, "kill"):
391+
resCat = ResultCategories{Violence: true}
392+
resCatScore = ResultCategoryScores{Violence: 1}
393+
case strings.Contains(moderationReq.Input, "hate"):
394+
resCat = ResultCategories{Hate: true}
395+
resCatScore = ResultCategoryScores{Hate: 1}
396+
case strings.Contains(moderationReq.Input, "suicide"):
397+
resCat = ResultCategories{SelfHarm: true}
398+
resCatScore = ResultCategoryScores{SelfHarm: 1}
399+
case strings.Contains(moderationReq.Input, "porn"):
400+
resCat = ResultCategories{Sexual: true}
401+
resCatScore = ResultCategoryScores{Sexual: 1}
402+
}
403+
404+
result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}
405+
406+
res := ModerationResponse{
407+
ID: strconv.Itoa(int(time.Now().Unix())),
408+
Model: *moderationReq.Model,
409+
}
410+
res.Results = append(res.Results, result)
411+
412+
resBytes, _ = json.Marshal(res)
413+
fmt.Fprintln(w, string(resBytes))
414+
}
415+
416+
// getModerationBody Returns the body of the request to do a moderation.
417+
func getModerationBody(r *http.Request) (ModerationRequest, error) {
418+
moderation := ModerationRequest{}
317419
// read the request body
318420
reqBody, err := ioutil.ReadAll(r.Body)
319421
if err != nil {
320-
return ImageRequest{}, err
422+
return ModerationRequest{}, err
321423
}
322-
err = json.Unmarshal(reqBody, &image)
424+
err = json.Unmarshal(reqBody, &moderation)
323425
if err != nil {
324-
return ImageRequest{}, err
426+
return ModerationRequest{}, err
325427
}
326-
return image, nil
428+
return moderation, nil
327429
}
328430

329431
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
@@ -335,25 +437,6 @@ func numTokens(s string) int {
335437
return int(float32(len(s)) / 4)
336438
}
337439

338-
func TestImages(t *testing.T) {
339-
// create the test server
340-
var err error
341-
ts := OpenAITestServer()
342-
ts.Start()
343-
defer ts.Close()
344-
345-
client := NewClient(testAPIToken)
346-
ctx := context.Background()
347-
client.BaseURL = ts.URL + "/v1"
348-
349-
req := ImageRequest{}
350-
req.Prompt = "Lorem ipsum"
351-
_, err = client.CreateImage(ctx, req)
352-
if err != nil {
353-
t.Fatalf("CreateImage error: %v", err)
354-
}
355-
}
356-
357440
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
358441
func OpenAITestServer() *httptest.Server {
359442
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -373,6 +456,8 @@ func OpenAITestServer() *httptest.Server {
373456
case "/v1/completions":
374457
handleCompletionEndpoint(w, r)
375458
return
459+
case "/v1/moderations":
460+
handleModerationEndpoint(w, r)
376461
case "/v1/images/generations":
377462
handleImageEndpoint(w, r)
378463
// TODO: implement the other endpoints

0 commit comments

Comments
 (0)