Skip to content

Commit e22a29d

Browse files
authored
Check if the model param is valid for moderations endpoint (#437)
* chore: check for models before sending moderation requets to openai endpoint * chore: table driven tests to include more model cases for moderations endpoint
1 parent 39b2acb commit e22a29d

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

moderation.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai
22

33
import (
44
"context"
5+
"errors"
56
"net/http"
67
)
78

@@ -15,9 +16,19 @@ import (
1516
const (
1617
ModerationTextStable = "text-moderation-stable"
1718
ModerationTextLatest = "text-moderation-latest"
18-
ModerationText001 = "text-moderation-001"
19+
// Deprecated: use ModerationTextStable and ModerationTextLatest instead.
20+
ModerationText001 = "text-moderation-001"
1921
)
2022

23+
var (
24+
ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll
25+
)
26+
27+
var validModerationModel = map[string]struct{}{
28+
ModerationTextStable: {},
29+
ModerationTextLatest: {},
30+
}
31+
2132
// ModerationRequest represents a request structure for moderation API.
2233
type ModerationRequest struct {
2334
Input string `json:"input,omitempty"`
@@ -63,6 +74,10 @@ type ModerationResponse struct {
6374
// Moderations — perform a moderation api call over a string.
6475
// Input can be an array or slice but a string will reduce the complexity.
6576
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
77+
if _, ok := validModerationModel[request.Model]; len(request.Model) > 0 && !ok {
78+
err = ErrModerationInvalidModel
79+
return
80+
}
6681
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request))
6782
if err != nil {
6883
return

moderation_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,41 @@ func TestModerations(t *testing.T) {
2727
checks.NoError(t, err, "Moderation error")
2828
}
2929

30+
// TestModerationsWithIncorrectModel Tests passing valid and invalid models to moderations endpoint.
31+
func TestModerationsWithDifferentModelOptions(t *testing.T) {
32+
var modelOptions []struct {
33+
model string
34+
expect error
35+
}
36+
modelOptions = append(modelOptions,
37+
getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel),
38+
getModerationModelTestOption(ModerationTextStable, nil),
39+
getModerationModelTestOption(ModerationTextLatest, nil),
40+
getModerationModelTestOption("", nil),
41+
)
42+
client, server, teardown := setupOpenAITestServer()
43+
defer teardown()
44+
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
45+
for _, modelTest := range modelOptions {
46+
_, err := client.Moderations(context.Background(), ModerationRequest{
47+
Model: modelTest.model,
48+
Input: "I want to kill them.",
49+
})
50+
checks.ErrorIs(t, err, modelTest.expect,
51+
fmt.Sprintf("Moderations(..) expects err: %v, actual err:%v", modelTest.expect, err))
52+
}
53+
}
54+
55+
func getModerationModelTestOption(model string, expect error) struct {
56+
model string
57+
expect error
58+
} {
59+
return struct {
60+
model string
61+
expect error
62+
}{model: model, expect: expect}
63+
}
64+
3065
// handleModerationEndpoint Handles the moderation endpoint by the test server.
3166
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
3267
var err error

0 commit comments

Comments
 (0)