@@ -2,6 +2,7 @@ package gollum_test
22
33import (
44 "context"
5+ "os"
56 "testing"
67 "text/template"
78
@@ -21,6 +22,10 @@ type templateInput struct {
2122 Topic string
2223}
2324
25+ type wordCountOutput struct {
26+ Count int `json:"count" jsonschema:"required" jsonschema_description:"The number of words in the sentence"`
27+ }
28+
2429func TestDummyDispatcher (t * testing.T ) {
2530 d := gollum .NewDummyDispatcher [testInput ]()
2631
@@ -46,7 +51,8 @@ func TestDummyDispatcher(t *testing.T) {
4651func TestOpenAIDispatcher (t * testing.T ) {
4752 ctrl := gomock .NewController (t )
4853 completer := mock_gollum .NewMockChatCompleter (ctrl )
49- d := gollum .NewOpenAIDispatcher [testInput ]("random_conversation" , "Given a topic, return random words" , completer , nil )
54+ systemPrompt := "When prompted, use the tool."
55+ d := gollum .NewOpenAIDispatcher [testInput ]("random_conversation" , "Given a topic, return random words" , systemPrompt , completer , nil )
5056
5157 ctx := context .Background ()
5258 expected := testInput {
@@ -58,15 +64,19 @@ func TestOpenAIDispatcher(t *testing.T) {
5864 fi := openai .FunctionDefinition (gollum .StructToJsonSchema ("random_conversation" , "Given a topic, return random words" , testInput {}))
5965 ti := openai.Tool {Type : "function" , Function : fi }
6066 expectedRequest := openai.ChatCompletionRequest {
61- Model : openai .GPT3Dot5Turbo0613 ,
67+ Model : openai .GPT3Dot5Turbo1106 ,
6268 Messages : []openai.ChatCompletionMessage {
6369 {
64- Role : openai .ChatMessageRoleSystem ,
70+ Role : openai .ChatMessageRoleUser ,
6571 Content : "Tell me about dinosaurs" ,
6672 },
6773 },
68- Tools : []openai.Tool {ti },
69- ToolChoice : fi .Name ,
74+ Tools : []openai.Tool {ti },
75+ ToolChoice : openai.ToolChoice {
76+ Type : "function" ,
77+ Function : openai.ToolFunction {
78+ Name : "random_conversation" ,
79+ }},
7080 MaxTokens : 512 ,
7181 Temperature : 0.0 ,
7282 }
@@ -127,3 +137,13 @@ func TestOpenAIDispatcher(t *testing.T) {
127137 assert .Equal (t , expected , output )
128138 })
129139}
140+
141+ func TestDispatchIntegration (t * testing.T ) {
142+ t .Skip ("Skipping integration test" )
143+ completer := openai .NewClient (os .Getenv ("OPENAI_API_KEY" ))
144+ systemPrompt := "When prompted, use the tool on the user's input."
145+ d := gollum .NewOpenAIDispatcher [wordCountOutput ]("wordCounter" , "count the number of words in a sentence" , systemPrompt , completer , nil )
146+ output , err := d .Prompt (context .Background (), "I like dinosaurs" )
147+ assert .NoError (t , err )
148+ assert .Equal (t , 3 , output .Count )
149+ }
0 commit comments