@@ -16,100 +16,106 @@ func TestClassifier(t *testing.T) {
1616 RunSpecs (t , "Classifier Suite" )
1717}
1818
19- // MockModelInference implements ModelInference interface for testing
20- type MockModelInference struct {
21- classifyTextResult candle_binding.ClassResult
22- classifyTextError error
23- classifyModernBertResult candle_binding.ClassResult
24- classifyModernBertError error
19+ type MockCategoryInference struct {
20+ classifyResult candle_binding.ClassResult
21+ classifyError error
2522}
2623
27- func (m * MockModelInference ) ClassifyText (text string ) (candle_binding.ClassResult , error ) {
28- return m .classifyTextResult , m .classifyTextError
24+ func (m * MockCategoryInference ) Classify (text string ) (candle_binding.ClassResult , error ) {
25+ return m .classifyResult , m .classifyError
2926}
3027
31- func (m * MockModelInference ) ClassifyModernBertText (text string ) (candle_binding.ClassResult , error ) {
32- return m .classifyModernBertResult , m .classifyModernBertError
33- }
34-
35- var _ = Describe ("Classifier" , func () {
28+ var _ = Describe ("ClassifyCategory" , func () {
3629 var (
37- classifier * Classifier
38- mockModel * MockModelInference
30+ classifier * Classifier
31+ mockCategoryModel * MockCategoryInference
3932 )
4033
4134 BeforeEach (func () {
42- mockModel = & MockModelInference {}
35+ mockCategoryModel = & MockCategoryInference {}
4336 cfg := & config.RouterConfig {}
4437 cfg .Classifier .CategoryModel .Threshold = 0.5 // Set threshold for testing
4538
4639 classifier = & Classifier {
47- modelInference : mockModel ,
48- Config : cfg ,
40+ categoryInference : mockCategoryModel ,
41+ Config : cfg ,
4942 CategoryMapping : & CategoryMapping {
5043 CategoryToIdx : map [string ]int {"technology" : 0 , "sports" : 1 , "politics" : 2 },
5144 IdxToCategory : map [string ]string {"0" : "technology" , "1" : "sports" , "2" : "politics" },
5245 },
5346 }
5447 })
5548
56- Describe ("ClassifyCategory" , func () {
57- Context ("when classification succeeds with high confidence" , func () {
58- It ("should return the correct category" , func () {
59- mockModel .classifyTextResult = candle_binding.ClassResult {
60- Class : 2 ,
61- Confidence : 0.95 ,
62- }
49+ Context ("when classification succeeds with high confidence" , func () {
50+ It ("should return the correct category" , func () {
51+ mockCategoryModel .classifyResult = candle_binding.ClassResult {
52+ Class : 2 ,
53+ Confidence : 0.95 ,
54+ }
55+
56+ category , score , err := classifier .ClassifyCategory ("This is about politics" )
57+
58+ Expect (err ).To (BeNil ())
59+ Expect (category ).To (Equal ("politics" ))
60+ Expect (score ).To (BeNumerically ("~" , 0.95 , 0.001 ))
61+ })
62+ })
63+
64+ Context ("when classification has low confidence below threshold" , func () {
65+ It ("should return empty category" , func () {
66+ mockCategoryModel .classifyResult = candle_binding.ClassResult {
67+ Class : 0 ,
68+ Confidence : 0.3 ,
69+ }
6370
64- category , score , err := classifier .ClassifyCategory ("This is about politics " )
71+ category , score , err := classifier .ClassifyCategory ("Ambiguous text " )
6572
66- Expect (err ).To (BeNil ())
67- Expect (category ).To (Equal ("politics" ))
68- Expect (score ).To (BeNumerically ("~" , 0.95 , 0.001 ))
69- })
73+ Expect (err ).To (BeNil ())
74+ Expect (category ).To (Equal ("" ))
75+ Expect (score ).To (BeNumerically ("~" , 0.3 , 0.001 ))
7076 })
77+ })
7178
72- Context ("when classification has low confidence below threshold" , func () {
73- It ("should return empty category" , func () {
74- mockModel .classifyTextResult = candle_binding.ClassResult {
75- Class : 0 ,
76- Confidence : 0.3 ,
77- }
79+ Context ("when BERT model returns error" , func () {
80+ It ("should return empty category with zero score" , func () {
81+ mockCategoryModel .classifyError = errors .New ("model inference failed" )
7882
79- category , score , err := classifier .ClassifyCategory ("Ambiguous text" )
83+ category , score , err := classifier .ClassifyCategory ("Some text" )
8084
81- Expect (err ).To (BeNil ())
82- Expect (category ).To (Equal ("" ))
83- Expect (score ).To (BeNumerically ("~" , 0.3 , 0.001 ))
84- })
85+ Expect (err ).ToNot (BeNil ())
86+ Expect (category ).To (Equal ("" ))
87+ Expect (score ).To (BeNumerically ("~" , 0.0 , 0.001 ))
8588 })
89+ })
8690
87- Context ("when BERT model returns error" , func () {
88- It ("should return unknown category with zero score" , func () {
89- mockModel .classifyTextError = errors .New ("model inference failed" )
91+ Context ("when input is empty or invalid" , func () {
92+ It ("should handle empty text gracefully" , func () {
93+ mockCategoryModel .classifyResult = candle_binding.ClassResult {
94+ Class : 0 ,
95+ Confidence : 0.8 ,
96+ }
9097
91- category , score , err := classifier .ClassifyCategory ("Some text " )
98+ category , score , err := classifier .ClassifyCategory ("" )
9299
93- Expect ( err ). ToNot ( BeNil ())
94- Expect (category ).To (Equal ( "" ))
95- Expect (score ).To (BeNumerically ( "~" , 0.0 , 0.001 ))
96- } )
100+ // Should still attempt classification
101+ Expect (err ).To (BeNil ( ))
102+ Expect (category ).To (Equal ( "technology" ))
103+ Expect ( score ). To ( BeNumerically ( "~" , 0.8 , 0.001 ) )
97104 })
105+ })
98106
99- Context ("when input is empty or invalid" , func () {
100- It ("should handle empty text gracefully" , func () {
101- mockModel . classifyTextResult = candle_binding.ClassResult {
102- Class : 0 ,
103- Confidence : 0.8 ,
104- }
107+ Context ("when category mapping is invalid" , func () {
108+ It ("should handle invalid category mapping gracefully" , func () {
109+ mockCategoryModel . classifyResult = candle_binding.ClassResult {
110+ Class : 9 ,
111+ Confidence : 0.8 ,
112+ }
105113
106- category , score , err := classifier .ClassifyCategory ("" )
114+ category , score , err := classifier .ClassifyCategory ("Some text " )
107115
108- // Should still attempt classification
109- Expect (err ).To (BeNil ())
110- Expect (category ).To (Equal ("technology" ))
111- Expect (score ).To (BeNumerically ("~" , 0.8 , 0.001 ))
112- })
116+ Expect (err ).To (BeNil ())
117+ Expect (category ).To (Equal ("" ))
118+ Expect (score ).To (BeNumerically ("~" , 0.8 , 0.001 ))
113119 })
114120 })
115121})
0 commit comments