Skip to content

Commit f22da8a

Browse files
stillmaticJames
andauthored
feat: allow more input types to functions, fix tests (#377)
* feat: use json.rawMessage, test functions * chore: lint * fix: tests the ChatCompletion mock server doesn't actually run otherwise. N=0 is the default request but the server will treat it as n=1 * fix: tests should default to n=1 completions * chore: add back removed interfaces, custom marshal * chore: lint * chore: lint * chore: add some tests * chore: appease lint * clean up JSON schema + tests * chore: lint * feat: remove backwards compatible functions for illustrative purposes * fix: revert params change * chore: use interface{} * chore: add test * chore: add back FunctionDefine * chore: /s/interface{}/any * chore: add back jsonschemadefinition * chore: testcov * chore: lint * chore: remove pointers * chore: update comment * chore: address CR added test for compatibility as well --------- Co-authored-by: James <[email protected]>
1 parent e948150 commit f22da8a

File tree

3 files changed

+180
-21
lines changed

3 files changed

+180
-21
lines changed

chat.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,23 @@ type ChatCompletionRequest struct {
5454
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
5555
LogitBias map[string]int `json:"logit_bias,omitempty"`
5656
User string `json:"user,omitempty"`
57-
Functions []*FunctionDefine `json:"functions,omitempty"`
58-
FunctionCall string `json:"function_call,omitempty"`
57+
Functions []FunctionDefinition `json:"functions,omitempty"`
58+
FunctionCall any `json:"function_call,omitempty"`
5959
}
6060

61-
type FunctionDefine struct {
61+
type FunctionDefinition struct {
6262
Name string `json:"name"`
6363
Description string `json:"description,omitempty"`
64-
// it's required in function call
65-
Parameters *FunctionParams `json:"parameters"`
64+
// Parameters is an object describing the function.
65+
// You can pass a raw byte array describing the schema,
66+
// or you can pass in a struct which serializes to the proper JSONSchema.
67+
// The JSONSchemaDefinition struct is provided for convenience, but you should
68+
// consider another specialized library for more complex schemas.
69+
Parameters any `json:"parameters"`
6670
}
6771

68-
type FunctionParams struct {
69-
// the Type must be JSONSchemaTypeObject
70-
Type JSONSchemaType `json:"type"`
71-
Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"`
72-
Required []string `json:"required,omitempty"`
73-
}
72+
// Deprecated: use FunctionDefinition instead.
73+
type FunctionDefine = FunctionDefinition
7474

7575
type JSONSchemaType string
7676

@@ -83,22 +83,26 @@ const (
8383
JSONSchemaTypeBoolean JSONSchemaType = "boolean"
8484
)
8585

86-
// JSONSchemaDefine is a struct for JSON Schema.
87-
type JSONSchemaDefine struct {
86+
// JSONSchemaDefinition is a struct for JSON Schema.
87+
// It is fairly limited and you may have better luck using a third-party library.
88+
type JSONSchemaDefinition struct {
8889
// Type is a type of JSON Schema.
8990
Type JSONSchemaType `json:"type,omitempty"`
9091
// Description is a description of JSON Schema.
9192
Description string `json:"description,omitempty"`
9293
// Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString.
9394
Enum []string `json:"enum,omitempty"`
9495
// Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject.
95-
Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"`
96+
Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"`
9697
// Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject.
9798
Required []string `json:"required,omitempty"`
9899
// Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray.
99-
Items *JSONSchemaDefine `json:"items,omitempty"`
100+
Items *JSONSchemaDefinition `json:"items,omitempty"`
100101
}
101102

103+
// Deprecated: use JSONSchemaDefinition instead.
104+
type JSONSchemaDefine = JSONSchemaDefinition
105+
102106
type FinishReason string
103107

104108
const (

chat_test.go

Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,130 @@ func TestChatCompletions(t *testing.T) {
6767
checks.NoError(t, err, "CreateChatCompletion error")
6868
}
6969

70+
// TestChatCompletionsFunctions tests including a function call.
71+
func TestChatCompletionsFunctions(t *testing.T) {
72+
client, server, teardown := setupOpenAITestServer()
73+
defer teardown()
74+
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
75+
t.Run("bytes", func(t *testing.T) {
76+
//nolint:lll
77+
msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`)
78+
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
79+
MaxTokens: 5,
80+
Model: GPT3Dot5Turbo0613,
81+
Messages: []ChatCompletionMessage{
82+
{
83+
Role: ChatMessageRoleUser,
84+
Content: "Hello!",
85+
},
86+
},
87+
Functions: []FunctionDefine{{
88+
Name: "test",
89+
Parameters: &msg,
90+
}},
91+
})
92+
checks.NoError(t, err, "CreateChatCompletion with functions error")
93+
})
94+
t.Run("struct", func(t *testing.T) {
95+
type testMessage struct {
96+
Count int `json:"count"`
97+
Words []string `json:"words"`
98+
}
99+
msg := testMessage{
100+
Count: 2,
101+
Words: []string{"hello", "world"},
102+
}
103+
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
104+
MaxTokens: 5,
105+
Model: GPT3Dot5Turbo0613,
106+
Messages: []ChatCompletionMessage{
107+
{
108+
Role: ChatMessageRoleUser,
109+
Content: "Hello!",
110+
},
111+
},
112+
Functions: []FunctionDefinition{{
113+
Name: "test",
114+
Parameters: &msg,
115+
}},
116+
})
117+
checks.NoError(t, err, "CreateChatCompletion with functions error")
118+
})
119+
t.Run("JSONSchemaDefine", func(t *testing.T) {
120+
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
121+
MaxTokens: 5,
122+
Model: GPT3Dot5Turbo0613,
123+
Messages: []ChatCompletionMessage{
124+
{
125+
Role: ChatMessageRoleUser,
126+
Content: "Hello!",
127+
},
128+
},
129+
Functions: []FunctionDefinition{{
130+
Name: "test",
131+
Parameters: &JSONSchemaDefinition{
132+
Type: JSONSchemaTypeObject,
133+
Properties: map[string]JSONSchemaDefinition{
134+
"count": {
135+
Type: JSONSchemaTypeNumber,
136+
Description: "total number of words in sentence",
137+
},
138+
"words": {
139+
Type: JSONSchemaTypeArray,
140+
Description: "list of words in sentence",
141+
Items: &JSONSchemaDefinition{
142+
Type: JSONSchemaTypeString,
143+
},
144+
},
145+
"enumTest": {
146+
Type: JSONSchemaTypeString,
147+
Enum: []string{"hello", "world"},
148+
},
149+
},
150+
},
151+
}},
152+
})
153+
checks.NoError(t, err, "CreateChatCompletion with functions error")
154+
})
155+
t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) {
156+
// this is a compatibility check
157+
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
158+
MaxTokens: 5,
159+
Model: GPT3Dot5Turbo0613,
160+
Messages: []ChatCompletionMessage{
161+
{
162+
Role: ChatMessageRoleUser,
163+
Content: "Hello!",
164+
},
165+
},
166+
Functions: []FunctionDefine{{
167+
Name: "test",
168+
Parameters: &JSONSchemaDefine{
169+
Type: JSONSchemaTypeObject,
170+
Properties: map[string]JSONSchemaDefine{
171+
"count": {
172+
Type: JSONSchemaTypeNumber,
173+
Description: "total number of words in sentence",
174+
},
175+
"words": {
176+
Type: JSONSchemaTypeArray,
177+
Description: "list of words in sentence",
178+
Items: &JSONSchemaDefine{
179+
Type: JSONSchemaTypeString,
180+
},
181+
},
182+
"enumTest": {
183+
Type: JSONSchemaTypeString,
184+
Enum: []string{"hello", "world"},
185+
},
186+
},
187+
},
188+
}},
189+
})
190+
checks.NoError(t, err, "CreateChatCompletion with functions error")
191+
})
192+
}
193+
70194
func TestAzureChatCompletions(t *testing.T) {
71195
client, server, teardown := setupAzureTestServer()
72196
defer teardown()
@@ -109,7 +233,34 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
109233
Model: completionReq.Model,
110234
}
111235
// create completions
112-
for i := 0; i < completionReq.N; i++ {
236+
n := completionReq.N
237+
if n == 0 {
238+
n = 1
239+
}
240+
for i := 0; i < n; i++ {
241+
// if there are functions, include them
242+
if len(completionReq.Functions) > 0 {
243+
var fcb []byte
244+
b := completionReq.Functions[0].Parameters
245+
fcb, err = json.Marshal(b)
246+
if err != nil {
247+
http.Error(w, "could not marshal function parameters", http.StatusInternalServerError)
248+
return
249+
}
250+
251+
res.Choices = append(res.Choices, ChatCompletionChoice{
252+
Message: ChatCompletionMessage{
253+
Role: ChatMessageRoleFunction,
254+
// this is valid json so it should be fine
255+
FunctionCall: &FunctionCall{
256+
Name: completionReq.Functions[0].Name,
257+
Arguments: string(fcb),
258+
},
259+
},
260+
Index: i,
261+
})
262+
continue
263+
}
113264
// generate a random string of length completionReq.Length
114265
completionStr := strings.Repeat("a", completionReq.MaxTokens)
115266

@@ -121,8 +272,8 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
121272
Index: i,
122273
})
123274
}
124-
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
125-
completionTokens := completionReq.MaxTokens * completionReq.N
275+
inputTokens := numTokens(completionReq.Messages[0].Content) * n
276+
completionTokens := completionReq.MaxTokens * n
126277
res.Usage = Usage{
127278
PromptTokens: inputTokens,
128279
CompletionTokens: completionTokens,

completion_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
8383
Model: completionReq.Model,
8484
}
8585
// create completions
86-
for i := 0; i < completionReq.N; i++ {
86+
n := completionReq.N
87+
if n == 0 {
88+
n = 1
89+
}
90+
for i := 0; i < n; i++ {
8791
// generate a random string of length completionReq.Length
8892
completionStr := strings.Repeat("a", completionReq.MaxTokens)
8993
if completionReq.Echo {
@@ -94,8 +98,8 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
9498
Index: i,
9599
})
96100
}
97-
inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N
98-
completionTokens := completionReq.MaxTokens * completionReq.N
101+
inputTokens := numTokens(completionReq.Prompt.(string)) * n
102+
completionTokens := completionReq.MaxTokens * n
99103
res.Usage = Usage{
100104
PromptTokens: inputTokens,
101105
CompletionTokens: completionTokens,

0 commit comments

Comments
 (0)