Skip to content

Commit 196961c

Browse files
authored
Merge branch 'main' into gha
2 parents b82dea1 + cc2417d commit 196961c

File tree

9 files changed

+804
-34
lines changed

9 files changed

+804
-34
lines changed

e2e-tests/03-classification-api-test.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,29 +189,80 @@ def test_batch_classification(self):
189189
response_json = response.json()
190190
results = response_json.get("results", [])
191191

192+
# Extract actual categories from results
193+
actual_categories = []
194+
correct_classifications = 0
195+
196+
for i, result in enumerate(results):
197+
if isinstance(result, dict):
198+
actual_category = result.get("category", "unknown")
199+
else:
200+
actual_category = "unknown"
201+
202+
actual_categories.append(actual_category)
203+
204+
if (
205+
i < len(expected_categories)
206+
and actual_category == expected_categories[i]
207+
):
208+
correct_classifications += 1
209+
210+
# Calculate accuracy
211+
accuracy = (
212+
(correct_classifications / len(expected_categories)) * 100
213+
if expected_categories
214+
else 0
215+
)
216+
192217
self.print_response_info(
193218
response,
194219
{
195220
"Total Texts": len(texts),
196221
"Results Count": len(results),
197222
"Processing Time (ms)": response_json.get("processing_time_ms", 0),
223+
"Accuracy": f"{accuracy:.1f}% ({correct_classifications}/{len(expected_categories)})",
198224
},
199225
)
200226

201-
passed = response.status_code == 200 and len(results) == len(texts)
227+
# Print detailed classification results
228+
print("\n📊 Detailed Classification Results:")
229+
for i, (text, expected, actual) in enumerate(
230+
zip(texts, expected_categories, actual_categories)
231+
):
232+
status = "✅" if expected == actual else "❌"
233+
print(f" {i+1}. {status} Expected: {expected:<15} | Actual: {actual:<15}")
234+
print(f" Text: {text[:60]}...")
235+
236+
# Check basic requirements first
237+
basic_checks_passed = response.status_code == 200 and len(results) == len(texts)
238+
239+
# Check classification accuracy (should be high for a working system)
240+
accuracy_threshold = 75.0 # Expect at least 75% accuracy
241+
accuracy_passed = accuracy >= accuracy_threshold
242+
243+
overall_passed = basic_checks_passed and accuracy_passed
202244

203245
self.print_test_result(
204-
passed=passed,
246+
passed=overall_passed,
205247
message=(
206-
f"Successfully classified {len(results)} texts"
207-
if passed
208-
else f"Batch classification failed or returned wrong count"
248+
f"Successfully classified {len(results)} texts with {accuracy:.1f}% accuracy"
249+
if overall_passed
250+
else f"Batch classification issues: Basic checks: {basic_checks_passed}, Accuracy: {accuracy:.1f}% (threshold: {accuracy_threshold}%)"
209251
),
210252
)
211253

254+
# Basic checks
212255
self.assertEqual(response.status_code, 200, "Batch request failed")
213256
self.assertEqual(len(results), len(texts), "Result count mismatch")
214257

258+
# NEW: Validate classification accuracy
259+
self.assertGreaterEqual(
260+
accuracy,
261+
accuracy_threshold,
262+
f"Classification accuracy too low: {accuracy:.1f}% < {accuracy_threshold}%. "
263+
f"Expected: {expected_categories}, Actual: {actual_categories}",
264+
)
265+
215266

216267
if __name__ == "__main__":
217268
unittest.main()

src/semantic-router/cmd/main.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ import (
1515
func main() {
1616
// Parse command-line flags
1717
var (
18-
configPath = flag.String("config", "config/config.yaml", "Path to the configuration file")
19-
port = flag.Int("port", 50051, "Port to listen on for gRPC ExtProc")
20-
apiPort = flag.Int("api-port", 8080, "Port to listen on for Classification API")
21-
metricsPort = flag.Int("metrics-port", 9190, "Port for Prometheus metrics")
22-
enableAPI = flag.Bool("enable-api", true, "Enable Classification API server")
23-
secure = flag.Bool("secure", false, "Enable secure gRPC server with TLS")
24-
certPath = flag.String("cert-path", "", "Path to TLS certificate directory (containing tls.crt and tls.key)")
18+
configPath = flag.String("config", "config/config.yaml", "Path to the configuration file")
19+
port = flag.Int("port", 50051, "Port to listen on for gRPC ExtProc")
20+
apiPort = flag.Int("api-port", 8080, "Port to listen on for Classification API")
21+
metricsPort = flag.Int("metrics-port", 9190, "Port for Prometheus metrics")
22+
enableAPI = flag.Bool("enable-api", true, "Enable Classification API server")
23+
enableSystemPromptAPI = flag.Bool("enable-system-prompt-api", false, "Enable system prompt configuration endpoints (SECURITY: only enable in trusted environments)")
24+
secure = flag.Bool("secure", false, "Enable secure gRPC server with TLS")
25+
certPath = flag.String("cert-path", "", "Path to TLS certificate directory (containing tls.crt and tls.key)")
2526
)
2627
flag.Parse()
2728

@@ -58,7 +59,7 @@ func main() {
5859
if *enableAPI {
5960
go func() {
6061
observability.Infof("Starting Classification API server on port %d", *apiPort)
61-
if err := api.StartClassificationAPI(*configPath, *apiPort); err != nil {
62+
if err := api.StartClassificationAPI(*configPath, *apiPort, *enableSystemPromptAPI); err != nil {
6263
observability.Errorf("Classification API server error: %v", err)
6364
}
6465
}()

src/semantic-router/pkg/api/server.go

Lines changed: 175 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ import (
1717

1818
// ClassificationAPIServer holds the server state and dependencies
1919
type ClassificationAPIServer struct {
20-
classificationSvc *services.ClassificationService
21-
config *config.RouterConfig
20+
classificationSvc *services.ClassificationService
21+
config *config.RouterConfig
22+
enableSystemPromptAPI bool
2223
}
2324

2425
// ModelsInfoResponse represents the response for models info endpoint
@@ -101,7 +102,7 @@ type ClassificationOptions struct {
101102
}
102103

103104
// StartClassificationAPI starts the Classification API server
104-
func StartClassificationAPI(configPath string, port int) error {
105+
func StartClassificationAPI(configPath string, port int, enableSystemPromptAPI bool) error {
105106
// Load configuration
106107
cfg, err := config.LoadConfig(configPath)
107108
if err != nil {
@@ -139,8 +140,9 @@ func StartClassificationAPI(configPath string, port int) error {
139140

140141
// Create server instance
141142
apiServer := &ClassificationAPIServer{
142-
classificationSvc: classificationSvc,
143-
config: cfg,
143+
classificationSvc: classificationSvc,
144+
config: cfg,
145+
enableSystemPromptAPI: enableSystemPromptAPI,
144146
}
145147

146148
// Create HTTP server with routes
@@ -203,6 +205,15 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux {
203205
mux.HandleFunc("GET /config/classification", s.handleGetConfig)
204206
mux.HandleFunc("PUT /config/classification", s.handleUpdateConfig)
205207

208+
// System prompt configuration endpoints (only if explicitly enabled)
209+
if s.enableSystemPromptAPI {
210+
observability.Infof("System prompt configuration endpoints enabled")
211+
mux.HandleFunc("GET /config/system-prompts", s.handleGetSystemPrompts)
212+
mux.HandleFunc("PUT /config/system-prompts", s.handleUpdateSystemPrompts)
213+
} else {
214+
observability.Infof("System prompt configuration endpoints disabled for security")
215+
}
216+
206217
return mux
207218
}
208219

@@ -221,7 +232,16 @@ func (s *ClassificationAPIServer) handleIntentClassification(w http.ResponseWrit
221232
return
222233
}
223234

224-
response, err := s.classificationSvc.ClassifyIntent(req)
235+
// Use unified classifier if available, otherwise fall back to legacy
236+
var response *services.IntentResponse
237+
var err error
238+
239+
if s.classificationSvc.HasUnifiedClassifier() {
240+
response, err = s.classificationSvc.ClassifyIntentUnified(req)
241+
} else {
242+
response, err = s.classificationSvc.ClassifyIntent(req)
243+
}
244+
225245
if err != nil {
226246
s.writeErrorResponse(w, http.StatusInternalServerError, "CLASSIFICATION_ERROR", err.Error())
227247
return
@@ -705,3 +725,152 @@ func (s *ClassificationAPIServer) calculateUnifiedStatistics(unifiedResults *ser
705725
LowConfidenceCount: lowConfidenceCount,
706726
}
707727
}
728+
729+
// SystemPromptInfo represents system prompt information for a category
730+
type SystemPromptInfo struct {
731+
Category string `json:"category"`
732+
Prompt string `json:"prompt"`
733+
Enabled bool `json:"enabled"`
734+
Mode string `json:"mode"` // "replace" or "insert"
735+
}
736+
737+
// SystemPromptsResponse represents the response for GET /config/system-prompts
738+
type SystemPromptsResponse struct {
739+
SystemPrompts []SystemPromptInfo `json:"system_prompts"`
740+
}
741+
742+
// SystemPromptUpdateRequest represents a request to update system prompt settings
743+
type SystemPromptUpdateRequest struct {
744+
Category string `json:"category,omitempty"` // If empty, applies to all categories
745+
Enabled *bool `json:"enabled,omitempty"` // true to enable, false to disable
746+
Mode string `json:"mode,omitempty"` // "replace" or "insert"
747+
}
748+
749+
// handleGetSystemPrompts handles GET /config/system-prompts
750+
func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, r *http.Request) {
751+
cfg := s.config
752+
if cfg == nil {
753+
http.Error(w, "Configuration not available", http.StatusInternalServerError)
754+
return
755+
}
756+
757+
var systemPrompts []SystemPromptInfo
758+
for _, category := range cfg.Categories {
759+
systemPrompts = append(systemPrompts, SystemPromptInfo{
760+
Category: category.Name,
761+
Prompt: category.SystemPrompt,
762+
Enabled: category.IsSystemPromptEnabled(),
763+
Mode: category.GetSystemPromptMode(),
764+
})
765+
}
766+
767+
response := SystemPromptsResponse{
768+
SystemPrompts: systemPrompts,
769+
}
770+
771+
w.Header().Set("Content-Type", "application/json")
772+
if err := json.NewEncoder(w).Encode(response); err != nil {
773+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
774+
return
775+
}
776+
}
777+
778+
// handleUpdateSystemPrompts handles PUT /config/system-prompts
779+
func (s *ClassificationAPIServer) handleUpdateSystemPrompts(w http.ResponseWriter, r *http.Request) {
780+
var req SystemPromptUpdateRequest
781+
if err := s.parseJSONRequest(r, &req); err != nil {
782+
http.Error(w, err.Error(), http.StatusBadRequest)
783+
return
784+
}
785+
786+
if req.Enabled == nil && req.Mode == "" {
787+
http.Error(w, "either enabled or mode field is required", http.StatusBadRequest)
788+
return
789+
}
790+
791+
// Validate mode if provided
792+
if req.Mode != "" && req.Mode != "replace" && req.Mode != "insert" {
793+
http.Error(w, "mode must be either 'replace' or 'insert'", http.StatusBadRequest)
794+
return
795+
}
796+
797+
cfg := s.config
798+
if cfg == nil {
799+
http.Error(w, "Configuration not available", http.StatusInternalServerError)
800+
return
801+
}
802+
803+
// Create a copy of the config to modify
804+
newCfg := *cfg
805+
newCategories := make([]config.Category, len(cfg.Categories))
806+
copy(newCategories, cfg.Categories)
807+
newCfg.Categories = newCategories
808+
809+
updated := false
810+
if req.Category == "" {
811+
// Update all categories
812+
for i := range newCfg.Categories {
813+
if newCfg.Categories[i].SystemPrompt != "" {
814+
if req.Enabled != nil {
815+
newCfg.Categories[i].SystemPromptEnabled = req.Enabled
816+
}
817+
if req.Mode != "" {
818+
newCfg.Categories[i].SystemPromptMode = req.Mode
819+
}
820+
updated = true
821+
}
822+
}
823+
} else {
824+
// Update specific category
825+
for i := range newCfg.Categories {
826+
if newCfg.Categories[i].Name == req.Category {
827+
if newCfg.Categories[i].SystemPrompt == "" {
828+
http.Error(w, fmt.Sprintf("Category '%s' has no system prompt configured", req.Category), http.StatusBadRequest)
829+
return
830+
}
831+
if req.Enabled != nil {
832+
newCfg.Categories[i].SystemPromptEnabled = req.Enabled
833+
}
834+
if req.Mode != "" {
835+
newCfg.Categories[i].SystemPromptMode = req.Mode
836+
}
837+
updated = true
838+
break
839+
}
840+
}
841+
if !updated {
842+
http.Error(w, fmt.Sprintf("Category '%s' not found", req.Category), http.StatusNotFound)
843+
return
844+
}
845+
}
846+
847+
if !updated {
848+
http.Error(w, "No categories with system prompts found to update", http.StatusBadRequest)
849+
return
850+
}
851+
852+
// Update the configuration
853+
s.config = &newCfg
854+
s.classificationSvc.UpdateConfig(&newCfg)
855+
856+
// Return the updated system prompts
857+
var systemPrompts []SystemPromptInfo
858+
for _, category := range newCfg.Categories {
859+
systemPrompts = append(systemPrompts, SystemPromptInfo{
860+
Category: category.Name,
861+
Prompt: category.SystemPrompt,
862+
Enabled: category.IsSystemPromptEnabled(),
863+
Mode: category.GetSystemPromptMode(),
864+
})
865+
}
866+
867+
response := SystemPromptsResponse{
868+
SystemPrompts: systemPrompts,
869+
}
870+
871+
w.Header().Set("Content-Type", "application/json")
872+
if err := json.NewEncoder(w).Encode(response); err != nil {
873+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
874+
return
875+
}
876+
}

0 commit comments

Comments
 (0)