Skip to content

Commit 07a9aa3

Browse files
authored
fix: dispatch tools (#24)
* fix: dispatch tools * chore: skip
1 parent 7c4cc75 commit 07a9aa3

File tree

3 files changed

+49
-16
lines changed

3 files changed

+49
-16
lines changed

dispatch.go

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ type OpenAIDispatcherConfig struct {
4848
// For any type T and prompt, it will generate and parse the response into T.
4949
type OpenAIDispatcher[T any] struct {
5050
*OpenAIDispatcherConfig
51-
completer ChatCompleter
52-
ti openai.Tool
53-
parser Parser[T]
51+
completer ChatCompleter
52+
ti openai.Tool
53+
systemPrompt string
54+
parser Parser[T]
5455
}
5556

56-
func NewOpenAIDispatcher[T any](name, description string, completer ChatCompleter, cfg *OpenAIDispatcherConfig) *OpenAIDispatcher[T] {
57+
func NewOpenAIDispatcher[T any](name, description, systemPrompt string, completer ChatCompleter, cfg *OpenAIDispatcherConfig) *OpenAIDispatcher[T] {
5758
// note: name must not have spaces - valid json
5859
// we won't check here but the openai client will throw an error
5960
var t T
@@ -65,12 +66,13 @@ func NewOpenAIDispatcher[T any](name, description string, completer ChatComplete
6566
completer: completer,
6667
ti: ti,
6768
parser: parser,
69+
systemPrompt: systemPrompt,
6870
}
6971
}
7072

7173
func (d *OpenAIDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, error) {
7274
var output T
73-
model := openai.GPT3Dot5Turbo0613
75+
model := openai.GPT3Dot5Turbo1106
7476
temperature := float32(0.0)
7577
maxTokens := 512
7678
if d.OpenAIDispatcherConfig != nil {
@@ -85,24 +87,35 @@ func (d *OpenAIDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, err
8587
}
8688
}
8789

88-
resp, err := d.completer.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
90+
req := openai.ChatCompletionRequest{
8991
Model: model,
9092
Messages: []openai.ChatCompletionMessage{
93+
{
94+
Role: openai.ChatMessageRoleSystem,
95+
Content: d.systemPrompt,
96+
},
9197
{
9298
Role: openai.ChatMessageRoleUser,
9399
Content: prompt,
94100
},
95101
},
96-
Tools: []openai.Tool{d.ti},
97-
ToolChoice: d.ti.Function.Name,
102+
Tools: []openai.Tool{d.ti},
103+
ToolChoice: openai.ToolChoice{
104+
Type: "function",
105+
Function: openai.ToolFunction{
106+
Name: d.ti.Function.Name,
107+
}},
98108
Temperature: temperature,
99109
MaxTokens: maxTokens,
100-
})
110+
}
111+
112+
resp, err := d.completer.CreateChatCompletion(ctx, req)
101113
if err != nil {
102114
return output, err
103115
}
104116

105-
output, err = d.parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
117+
toolOutput := resp.Choices[0].Message.ToolCalls[0].Function.Arguments
118+
output, err = d.parser.Parse(ctx, []byte(toolOutput))
106119
if err != nil {
107120
return output, err
108121
}

dispatch_test.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package gollum_test
22

33
import (
44
"context"
5+
"os"
56
"testing"
67
"text/template"
78

@@ -21,6 +22,10 @@ type templateInput struct {
2122
Topic string
2223
}
2324

25+
type wordCountOutput struct {
26+
Count int `json:"count" jsonschema:"required" jsonschema_description:"The number of words in the sentence"`
27+
}
28+
2429
func TestDummyDispatcher(t *testing.T) {
2530
d := gollum.NewDummyDispatcher[testInput]()
2631

@@ -46,7 +51,8 @@ func TestDummyDispatcher(t *testing.T) {
4651
func TestOpenAIDispatcher(t *testing.T) {
4752
ctrl := gomock.NewController(t)
4853
completer := mock_gollum.NewMockChatCompleter(ctrl)
49-
d := gollum.NewOpenAIDispatcher[testInput]("random_conversation", "Given a topic, return random words", completer, nil)
54+
systemPrompt := "When prompted, use the tool."
55+
d := gollum.NewOpenAIDispatcher[testInput]("random_conversation", "Given a topic, return random words", systemPrompt, completer, nil)
5056

5157
ctx := context.Background()
5258
expected := testInput{
@@ -58,15 +64,19 @@ func TestOpenAIDispatcher(t *testing.T) {
5864
fi := openai.FunctionDefinition(gollum.StructToJsonSchema("random_conversation", "Given a topic, return random words", testInput{}))
5965
ti := openai.Tool{Type: "function", Function: fi}
6066
expectedRequest := openai.ChatCompletionRequest{
61-
Model: openai.GPT3Dot5Turbo0613,
67+
Model: openai.GPT3Dot5Turbo1106,
6268
Messages: []openai.ChatCompletionMessage{
6369
{
64-
Role: openai.ChatMessageRoleSystem,
70+
Role: openai.ChatMessageRoleUser,
6571
Content: "Tell me about dinosaurs",
6672
},
6773
},
68-
Tools: []openai.Tool{ti},
69-
ToolChoice: fi.Name,
74+
Tools: []openai.Tool{ti},
75+
ToolChoice: openai.ToolChoice{
76+
Type: "function",
77+
Function: openai.ToolFunction{
78+
Name: "random_conversation",
79+
}},
7080
MaxTokens: 512,
7181
Temperature: 0.0,
7282
}
@@ -127,3 +137,13 @@ func TestOpenAIDispatcher(t *testing.T) {
127137
assert.Equal(t, expected, output)
128138
})
129139
}
140+
141+
func TestDispatchIntegration(t *testing.T) {
142+
t.Skip("Skipping integration test")
143+
completer := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
144+
systemPrompt := "When prompted, use the tool on the user's input."
145+
d := gollum.NewOpenAIDispatcher[wordCountOutput]("wordCounter", "count the number of words in a sentence", systemPrompt, completer, nil)
146+
output, err := d.Prompt(context.Background(), "I like dinosaurs")
147+
assert.NoError(t, err)
148+
assert.Equal(t, 3, output.Count)
149+
}

functions_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func TestEndToEnd(t *testing.T) {
7272
fi := gollum.StructToJsonSchema("weather", "Get the current weather in a given location", getWeatherInput{})
7373

7474
chatRequest := openai.ChatCompletionRequest{
75-
Model: "gpt-3.5-turbo-0613",
75+
Model: openai.GPT3Dot5Turbo1106,
7676
Messages: []openai.ChatCompletionMessage{
7777
{
7878
Role: "user",

0 commit comments

Comments
 (0)