Skip to content

Commit 2ebb265

Browse files
authored
refactor: Refactor endpoint and model compatibility check (#180)
* Add model check for chat stream * Sync model checks * Fix typo * Fix function * refactor: Refactor endpoint and model compatibility check * apply review suggestions * minor fix * invert return boolean flag * fix test
1 parent 4288394 commit 2ebb265

File tree

8 files changed

+94
-13
lines changed

8 files changed

+94
-13
lines changed

chat.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ const (
1414
)
1515

1616
var (
17-
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") //nolint:lll
18-
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
17+
ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll
18+
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
1919
)
2020

2121
type ChatCompletionMessage struct {
@@ -71,14 +71,12 @@ func (c *Client) CreateChatCompletion(
7171
return
7272
}
7373

74-
switch request.Model {
75-
case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K:
76-
default:
74+
urlSuffix := "/chat/completions"
75+
if !checkEndpointSupportsModel(urlSuffix, request.Model) {
7776
err = ErrChatCompletionInvalidModel
7877
return
7978
}
8079

81-
urlSuffix := "/chat/completions"
8280
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
8381
if err != nil {
8482
return

chat_stream.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@ func (c *Client) CreateChatCompletionStream(
3737
ctx context.Context,
3838
request ChatCompletionRequest,
3939
) (stream *ChatCompletionStream, err error) {
40+
urlSuffix := "/chat/completions"
41+
if !checkEndpointSupportsModel(urlSuffix, request.Model) {
42+
err = ErrChatCompletionInvalidModel
43+
return
44+
}
45+
4046
request.Stream = true
41-
req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request)
47+
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
4248
if err != nil {
4349
return
4450
}

chat_stream_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,28 @@ import (
1313
"testing"
1414
)
1515

16+
func TestChatCompletionsStreamWrongModel(t *testing.T) {
17+
config := DefaultConfig("whatever")
18+
config.BaseURL = "http://localhost/v1"
19+
client := NewClientWithConfig(config)
20+
ctx := context.Background()
21+
22+
req := ChatCompletionRequest{
23+
MaxTokens: 5,
24+
Model: "ada",
25+
Messages: []ChatCompletionMessage{
26+
{
27+
Role: ChatMessageRoleUser,
28+
Content: "Hello!",
29+
},
30+
},
31+
}
32+
_, err := client.CreateChatCompletionStream(ctx, req)
33+
if !errors.Is(err, ErrChatCompletionInvalidModel) {
34+
t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err)
35+
}
36+
}
37+
1638
func TestCreateChatCompletionStream(t *testing.T) {
1739
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1840
w.Header().Set("Content-Type", "text/event-stream")

chat_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestChatCompletionsWrongModel(t *testing.T) {
3434
}
3535
_, err := client.CreateChatCompletion(ctx, req)
3636
if !errors.Is(err, ErrChatCompletionInvalidModel) {
37-
t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err)
37+
t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err)
3838
}
3939
}
4040

completion.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,38 @@ const (
4545
CodexCodeDavinci001 = "code-davinci-001"
4646
)
4747

48+
var disabledModelsForEndpoints = map[string]map[string]bool{
49+
"/completions": {
50+
GPT3Dot5Turbo: true,
51+
GPT3Dot5Turbo0301: true,
52+
GPT4: true,
53+
GPT40314: true,
54+
GPT432K: true,
55+
GPT432K0314: true,
56+
},
57+
"/chat/completions": {
58+
CodexCodeDavinci002: true,
59+
CodexCodeCushman001: true,
60+
CodexCodeDavinci001: true,
61+
GPT3TextDavinci003: true,
62+
GPT3TextDavinci002: true,
63+
GPT3TextCurie001: true,
64+
GPT3TextBabbage001: true,
65+
GPT3TextAda001: true,
66+
GPT3TextDavinci001: true,
67+
GPT3DavinciInstructBeta: true,
68+
GPT3Davinci: true,
69+
GPT3CurieInstructBeta: true,
70+
GPT3Curie: true,
71+
GPT3Ada: true,
72+
GPT3Babbage: true,
73+
},
74+
}
75+
76+
func checkEndpointSupportsModel(endpoint, model string) bool {
77+
return !disabledModelsForEndpoints[endpoint][model]
78+
}
79+
4880
// CompletionRequest represents a request structure for completion API.
4981
type CompletionRequest struct {
5082
Model string `json:"model"`
@@ -105,12 +137,12 @@ func (c *Client) CreateCompletion(
105137
return
106138
}
107139

108-
if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo {
140+
urlSuffix := "/completions"
141+
if !checkEndpointSupportsModel(urlSuffix, request.Model) {
109142
err = ErrCompletionUnsupportedModel
110143
return
111144
}
112145

113-
urlSuffix := "/completions"
114146
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
115147
if err != nil {
116148
return

request_builder_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
6161
t.Fatalf("Did not return error when request builder failed: %v", err)
6262
}
6363

64-
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
64+
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
6565
if !errors.Is(err, errTestRequestBuilderFailed) {
6666
t.Fatalf("Did not return error when request builder failed: %v", err)
6767
}

stream.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,14 @@ func (c *Client) CreateCompletionStream(
2222
ctx context.Context,
2323
request CompletionRequest,
2424
) (stream *CompletionStream, err error) {
25+
urlSuffix := "/completions"
26+
if !checkEndpointSupportsModel(urlSuffix, request.Model) {
27+
err = ErrCompletionUnsupportedModel
28+
return
29+
}
30+
2531
request.Stream = true
26-
req, err := c.newStreamRequest(ctx, "POST", "/completions", request)
32+
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
2733
if err != nil {
2834
return
2935
}

stream_test.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,23 @@ import (
1212
"testing"
1313
)
1414

15+
func TestCompletionsStreamWrongModel(t *testing.T) {
16+
config := DefaultConfig("whatever")
17+
config.BaseURL = "http://localhost/v1"
18+
client := NewClientWithConfig(config)
19+
20+
_, err := client.CreateCompletionStream(
21+
context.Background(),
22+
CompletionRequest{
23+
MaxTokens: 5,
24+
Model: GPT3Dot5Turbo,
25+
},
26+
)
27+
if !errors.Is(err, ErrCompletionUnsupportedModel) {
28+
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
29+
}
30+
}
31+
1532
func TestCreateCompletionStream(t *testing.T) {
1633
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1734
w.Header().Set("Content-Type", "text/event-stream")
@@ -140,7 +157,7 @@ func TestCreateCompletionStreamError(t *testing.T) {
140157

141158
request := CompletionRequest{
142159
MaxTokens: 5,
143-
Model: GPT3Dot5Turbo,
160+
Model: GPT3TextDavinci003,
144161
Prompt: "Hello!",
145162
Stream: true,
146163
}

0 commit comments

Comments
 (0)