Skip to content

Commit 294a307

Browse files
cryo-zdrootfs
andauthored
test: add test for ToolsDatabase (#284)
Signed-off-by: cryo <[email protected]> Co-authored-by: Huamin Chen <[email protected]>
1 parent e9186e7 commit 294a307

File tree

1 file changed

+295
-0
lines changed

1 file changed

+295
-0
lines changed
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
package tools_test
2+
3+
import (
4+
"encoding/json"
5+
"os"
6+
"path/filepath"
7+
"testing"
8+
9+
"github.com/openai/openai-go"
10+
"github.com/openai/openai-go/packages/param"
11+
candle_binding "github.com/vllm-project/semantic-router/candle-binding"
12+
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/tools"
13+
14+
. "github.com/onsi/ginkgo/v2"
15+
. "github.com/onsi/gomega"
16+
)
17+
18+
func TestTools(t *testing.T) {
19+
RegisterFailHandler(Fail)
20+
RunSpecs(t, "Tools Suite")
21+
}
22+
23+
var _ = BeforeSuite(func() {
24+
// Initialize BERT model once for all cache tests (Linux only)
25+
err := candle_binding.InitModel("sentence-transformers/all-MiniLM-L6-v2", true)
26+
Expect(err).NotTo(HaveOccurred())
27+
})
28+
29+
var _ = Describe("ToolsDatabase", func() {
30+
Describe("NewToolsDatabase", func() {
31+
It("should create enabled and disabled databases", func() {
32+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
33+
SimilarityThreshold: 0.8,
34+
Enabled: true,
35+
})
36+
Expect(db).NotTo(BeNil())
37+
Expect(db.IsEnabled()).To(BeTrue())
38+
39+
db2 := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
40+
SimilarityThreshold: 0.8,
41+
Enabled: false,
42+
})
43+
Expect(db2).NotTo(BeNil())
44+
Expect(db2.IsEnabled()).To(BeFalse())
45+
})
46+
})
47+
48+
Describe("LoadToolsFromFile", func() {
49+
var (
50+
tempDir string
51+
toolFilePath string
52+
)
53+
54+
BeforeEach(func() {
55+
var err error
56+
tempDir, err = os.MkdirTemp("", "tools_test")
57+
Expect(err).NotTo(HaveOccurred())
58+
59+
toolFilePath = filepath.Join(tempDir, "tools.json")
60+
toolsData := []tools.ToolEntry{
61+
{
62+
Tool: openai.ChatCompletionToolParam{
63+
Type: "function",
64+
Function: openai.FunctionDefinitionParam{
65+
Name: "weather",
66+
Description: param.NewOpt("Get weather info"),
67+
},
68+
},
69+
Description: "Get weather info",
70+
Tags: []string{"weather", "info"},
71+
Category: "utility",
72+
},
73+
{
74+
Tool: openai.ChatCompletionToolParam{
75+
Type: "function",
76+
Function: openai.FunctionDefinitionParam{
77+
Name: "news",
78+
Description: param.NewOpt("Get latest news"),
79+
},
80+
},
81+
Description: "Get latest news",
82+
Tags: []string{"news"},
83+
Category: "information",
84+
},
85+
}
86+
data, err := json.Marshal(toolsData)
87+
Expect(err).NotTo(HaveOccurred())
88+
err = os.WriteFile(toolFilePath, data, 0o644)
89+
Expect(err).NotTo(HaveOccurred())
90+
})
91+
92+
AfterEach(func() {
93+
os.RemoveAll(tempDir)
94+
})
95+
96+
It("should load tools from file when enabled", func() {
97+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
98+
SimilarityThreshold: 0.7,
99+
Enabled: true,
100+
})
101+
err := db.LoadToolsFromFile(toolFilePath)
102+
Expect(err).NotTo(HaveOccurred())
103+
Expect(db.GetToolCount()).To(Equal(2))
104+
toolsList := db.GetAllTools()
105+
Expect(toolsList).To(HaveLen(2))
106+
Expect(toolsList[0].Function.Name).To(Equal("weather"))
107+
Expect(toolsList[1].Function.Name).To(Equal("news"))
108+
})
109+
110+
It("should do nothing if disabled", func() {
111+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
112+
SimilarityThreshold: 0.7,
113+
Enabled: false,
114+
})
115+
err := db.LoadToolsFromFile(toolFilePath)
116+
Expect(err).NotTo(HaveOccurred())
117+
Expect(db.GetToolCount()).To(Equal(0))
118+
})
119+
120+
It("should return error if file does not exist", func() {
121+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
122+
SimilarityThreshold: 0.7,
123+
Enabled: true,
124+
})
125+
err := db.LoadToolsFromFile("/nonexistent/tools.json")
126+
Expect(err).To(HaveOccurred())
127+
Expect(err.Error()).To(ContainSubstring("failed to read tools file"))
128+
})
129+
130+
It("should return error if file is invalid JSON", func() {
131+
badFile := filepath.Join(tempDir, "bad.json")
132+
Expect(os.WriteFile(badFile, []byte("{invalid json"), 0o644)).To(Succeed())
133+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
134+
SimilarityThreshold: 0.7,
135+
Enabled: true,
136+
})
137+
err := db.LoadToolsFromFile(badFile)
138+
Expect(err).To(HaveOccurred())
139+
Expect(err.Error()).To(ContainSubstring("failed to parse tools JSON"))
140+
})
141+
})
142+
143+
Describe("AddTool", func() {
144+
It("should add tool when enabled", func() {
145+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
146+
SimilarityThreshold: 0.8,
147+
Enabled: true,
148+
})
149+
tool := openai.ChatCompletionToolParam{
150+
Type: "function",
151+
Function: openai.FunctionDefinitionParam{
152+
Name: "calculator",
153+
Description: param.NewOpt("Simple calculator"),
154+
},
155+
}
156+
err := db.AddTool(tool, "Simple calculator", "utility", []string{"math"})
157+
Expect(err).NotTo(HaveOccurred())
158+
Expect(db.GetToolCount()).To(Equal(1))
159+
allTools := db.GetAllTools()
160+
Expect(allTools[0].Function.Name).To(Equal("calculator"))
161+
})
162+
163+
It("should do nothing if disabled", func() {
164+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
165+
SimilarityThreshold: 0.8,
166+
Enabled: false,
167+
})
168+
tool := openai.ChatCompletionToolParam{
169+
Type: "function",
170+
Function: openai.FunctionDefinitionParam{
171+
Name: "calculator",
172+
Description: param.NewOpt("Simple calculator"),
173+
},
174+
}
175+
err := db.AddTool(tool, "Simple calculator", "utility", []string{"math"})
176+
Expect(err).NotTo(HaveOccurred())
177+
Expect(db.GetToolCount()).To(Equal(0))
178+
})
179+
})
180+
181+
Describe("FindSimilarTools", func() {
182+
var db *tools.ToolsDatabase
183+
184+
BeforeEach(func() {
185+
db = tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
186+
SimilarityThreshold: 0.7,
187+
Enabled: true,
188+
})
189+
_ = db.AddTool(openai.ChatCompletionToolParam{
190+
Type: "function",
191+
Function: openai.FunctionDefinitionParam{
192+
Name: "weather",
193+
Description: param.NewOpt("Get weather info"),
194+
},
195+
}, "Get weather info", "utility", []string{"weather", "info"})
196+
_ = db.AddTool(openai.ChatCompletionToolParam{
197+
Type: "function",
198+
Function: openai.FunctionDefinitionParam{
199+
Name: "news",
200+
Description: param.NewOpt("Get latest news"),
201+
},
202+
}, "Get latest news", "information", []string{"news"})
203+
_ = db.AddTool(openai.ChatCompletionToolParam{
204+
Type: "function",
205+
Function: openai.FunctionDefinitionParam{
206+
Name: "calculator",
207+
Description: param.NewOpt("Simple calculator"),
208+
},
209+
}, "Simple calculator", "utility", []string{"math"})
210+
})
211+
212+
It("should find similar tools for a relevant query", func() {
213+
results, err := db.FindSimilarTools("weather", 2)
214+
Expect(err).NotTo(HaveOccurred())
215+
Expect(results).NotTo(BeEmpty())
216+
Expect(results[0].Function.Name).To(Equal("weather"))
217+
})
218+
219+
It("should return at most topK results", func() {
220+
results, err := db.FindSimilarTools("info", 1)
221+
Expect(err).NotTo(HaveOccurred())
222+
Expect(len(results)).To(BeNumerically("<=", 1))
223+
})
224+
225+
It("should return empty if disabled", func() {
226+
db2 := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
227+
SimilarityThreshold: 0.7,
228+
Enabled: false,
229+
})
230+
results, err := db2.FindSimilarTools("weather", 2)
231+
Expect(err).NotTo(HaveOccurred())
232+
Expect(results).To(BeEmpty())
233+
})
234+
})
235+
236+
Describe("GetAllTools", func() {
237+
It("should return all tools when enabled", func() {
238+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
239+
SimilarityThreshold: 0.8,
240+
Enabled: true,
241+
})
242+
_ = db.AddTool(openai.ChatCompletionToolParam{
243+
Type: "function",
244+
Function: openai.FunctionDefinitionParam{
245+
Name: "weather",
246+
Description: param.NewOpt("Get weather info"),
247+
},
248+
}, "Get weather info", "utility", []string{"weather"})
249+
_ = db.AddTool(openai.ChatCompletionToolParam{
250+
Type: "function",
251+
Function: openai.FunctionDefinitionParam{
252+
Name: "news",
253+
Description: param.NewOpt("Get latest news"),
254+
},
255+
}, "Get latest news", "information", []string{"news"})
256+
allTools := db.GetAllTools()
257+
Expect(allTools).To(HaveLen(2))
258+
})
259+
260+
It("should return empty if disabled", func() {
261+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
262+
SimilarityThreshold: 0.8,
263+
Enabled: false,
264+
})
265+
allTools := db.GetAllTools()
266+
Expect(allTools).To(BeEmpty())
267+
})
268+
})
269+
270+
Describe("GetToolCount", func() {
271+
It("should return correct count when enabled", func() {
272+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
273+
SimilarityThreshold: 0.8,
274+
Enabled: true,
275+
})
276+
Expect(db.GetToolCount()).To(Equal(0))
277+
_ = db.AddTool(openai.ChatCompletionToolParam{
278+
Type: "function",
279+
Function: openai.FunctionDefinitionParam{
280+
Name: "weather",
281+
Description: param.NewOpt("Get weather info"),
282+
},
283+
}, "Get weather info", "utility", []string{"weather"})
284+
Expect(db.GetToolCount()).To(Equal(1))
285+
})
286+
287+
It("should return zero if disabled", func() {
288+
db := tools.NewToolsDatabase(tools.ToolsDatabaseOptions{
289+
SimilarityThreshold: 0.8,
290+
Enabled: false,
291+
})
292+
Expect(db.GetToolCount()).To(Equal(0))
293+
})
294+
})
295+
})

0 commit comments

Comments
 (0)