diff --git a/src/semantic-router/cmd/main.go b/src/semantic-router/cmd/main.go index 99025735..f8a0fb67 100644 --- a/src/semantic-router/cmd/main.go +++ b/src/semantic-router/cmd/main.go @@ -15,13 +15,14 @@ import ( func main() { // Parse command-line flags var ( - configPath = flag.String("config", "config/config.yaml", "Path to the configuration file") - port = flag.Int("port", 50051, "Port to listen on for gRPC ExtProc") - apiPort = flag.Int("api-port", 8080, "Port to listen on for Classification API") - metricsPort = flag.Int("metrics-port", 9190, "Port for Prometheus metrics") - enableAPI = flag.Bool("enable-api", true, "Enable Classification API server") - secure = flag.Bool("secure", false, "Enable secure gRPC server with TLS") - certPath = flag.String("cert-path", "", "Path to TLS certificate directory (containing tls.crt and tls.key)") + configPath = flag.String("config", "config/config.yaml", "Path to the configuration file") + port = flag.Int("port", 50051, "Port to listen on for gRPC ExtProc") + apiPort = flag.Int("api-port", 8080, "Port to listen on for Classification API") + metricsPort = flag.Int("metrics-port", 9190, "Port for Prometheus metrics") + enableAPI = flag.Bool("enable-api", true, "Enable Classification API server") + enableSystemPromptAPI = flag.Bool("enable-system-prompt-api", false, "Enable system prompt configuration endpoints (SECURITY: only enable in trusted environments)") + secure = flag.Bool("secure", false, "Enable secure gRPC server with TLS") + certPath = flag.String("cert-path", "", "Path to TLS certificate directory (containing tls.crt and tls.key)") ) flag.Parse() @@ -58,7 +59,7 @@ func main() { if *enableAPI { go func() { observability.Infof("Starting Classification API server on port %d", *apiPort) - if err := api.StartClassificationAPI(*configPath, *apiPort); err != nil { + if err := api.StartClassificationAPI(*configPath, *apiPort, *enableSystemPromptAPI); err != nil { observability.Errorf("Classification API server error: %v", err) } }() diff --git a/src/semantic-router/pkg/api/server.go b/src/semantic-router/pkg/api/server.go index 6fe8dfd3..a281a811 100644 --- a/src/semantic-router/pkg/api/server.go +++ b/src/semantic-router/pkg/api/server.go @@ -17,8 +17,9 @@ import ( // ClassificationAPIServer holds the server state and dependencies type ClassificationAPIServer struct { - classificationSvc *services.ClassificationService - config *config.RouterConfig + classificationSvc *services.ClassificationService + config *config.RouterConfig + enableSystemPromptAPI bool } // ModelsInfoResponse represents the response for models info endpoint @@ -101,7 +102,7 @@ type ClassificationOptions struct { } // StartClassificationAPI starts the Classification API server -func StartClassificationAPI(configPath string, port int) error { +func StartClassificationAPI(configPath string, port int, enableSystemPromptAPI bool) error { // Load configuration cfg, err := config.LoadConfig(configPath) if err != nil { @@ -139,8 +140,9 @@ func StartClassificationAPI(configPath string, port int) error { // Create server instance apiServer := &ClassificationAPIServer{ - classificationSvc: classificationSvc, - config: cfg, + classificationSvc: classificationSvc, + config: cfg, + enableSystemPromptAPI: enableSystemPromptAPI, } // Create HTTP server with routes @@ -203,6 +205,15 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux { mux.HandleFunc("GET /config/classification", s.handleGetConfig) mux.HandleFunc("PUT /config/classification", s.handleUpdateConfig) + // System prompt configuration endpoints (only if explicitly enabled) + if s.enableSystemPromptAPI { + observability.Infof("System prompt configuration endpoints enabled") + mux.HandleFunc("GET /config/system-prompts", s.handleGetSystemPrompts) + mux.HandleFunc("PUT /config/system-prompts", s.handleUpdateSystemPrompts) + } else { + observability.Infof("System prompt configuration endpoints disabled for security") + } + return mux } @@ -705,3 +716,152 @@ func (s *ClassificationAPIServer) calculateUnifiedStatistics(unifiedResults *ser LowConfidenceCount: lowConfidenceCount, } } + +// SystemPromptInfo represents system prompt information for a category +type SystemPromptInfo struct { + Category string `json:"category"` + Prompt string `json:"prompt"` + Enabled bool `json:"enabled"` + Mode string `json:"mode"` // "replace" or "insert" +} + +// SystemPromptsResponse represents the response for GET /config/system-prompts +type SystemPromptsResponse struct { + SystemPrompts []SystemPromptInfo `json:"system_prompts"` +} + +// SystemPromptUpdateRequest represents a request to update system prompt settings +type SystemPromptUpdateRequest struct { + Category string `json:"category,omitempty"` // If empty, applies to all categories + Enabled *bool `json:"enabled,omitempty"` // true to enable, false to disable + Mode string `json:"mode,omitempty"` // "replace" or "insert" +} + +// handleGetSystemPrompts handles GET /config/system-prompts +func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, r *http.Request) { + cfg := s.config + if cfg == nil { + http.Error(w, "Configuration not available", http.StatusInternalServerError) + return + } + + var systemPrompts []SystemPromptInfo + for _, category := range cfg.Categories { + systemPrompts = append(systemPrompts, SystemPromptInfo{ + Category: category.Name, + Prompt: category.SystemPrompt, + Enabled: category.IsSystemPromptEnabled(), + Mode: category.GetSystemPromptMode(), + }) + } + + response := SystemPromptsResponse{ + SystemPrompts: systemPrompts, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } +} + +// handleUpdateSystemPrompts handles PUT /config/system-prompts +func (s *ClassificationAPIServer) handleUpdateSystemPrompts(w http.ResponseWriter, r *http.Request) { + var req SystemPromptUpdateRequest + if err := s.parseJSONRequest(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if req.Enabled == nil && req.Mode == "" { + http.Error(w, "either enabled or mode field is required", http.StatusBadRequest) + return + } + + // Validate mode if provided + if req.Mode != "" && req.Mode != "replace" && req.Mode != "insert" { + http.Error(w, "mode must be either 'replace' or 'insert'", http.StatusBadRequest) + return + } + + cfg := s.config + if cfg == nil { + http.Error(w, "Configuration not available", http.StatusInternalServerError) + return + } + + // Create a copy of the config to modify + newCfg := *cfg + newCategories := make([]config.Category, len(cfg.Categories)) + copy(newCategories, cfg.Categories) + newCfg.Categories = newCategories + + updated := false + if req.Category == "" { + // Update all categories + for i := range newCfg.Categories { + if newCfg.Categories[i].SystemPrompt != "" { + if req.Enabled != nil { + newCfg.Categories[i].SystemPromptEnabled = req.Enabled + } + if req.Mode != "" { + newCfg.Categories[i].SystemPromptMode = req.Mode + } + updated = true + } + } + } else { + // Update specific category + for i := range newCfg.Categories { + if newCfg.Categories[i].Name == req.Category { + if newCfg.Categories[i].SystemPrompt == "" { + http.Error(w, fmt.Sprintf("Category '%s' has no system prompt configured", req.Category), http.StatusBadRequest) + return + } + if req.Enabled != nil { + newCfg.Categories[i].SystemPromptEnabled = req.Enabled + } + if req.Mode != "" { + newCfg.Categories[i].SystemPromptMode = req.Mode + } + updated = true + break + } + } + if !updated { + http.Error(w, fmt.Sprintf("Category '%s' not found", req.Category), http.StatusNotFound) + return + } + } + + if !updated { + http.Error(w, "No categories with system prompts found to update", http.StatusBadRequest) + return + } + + // Update the configuration + s.config = &newCfg + s.classificationSvc.UpdateConfig(&newCfg) + + // Return the updated system prompts + var systemPrompts []SystemPromptInfo + for _, category := range newCfg.Categories { + systemPrompts = append(systemPrompts, SystemPromptInfo{ + Category: category.Name, + Prompt: category.SystemPrompt, + Enabled: category.IsSystemPromptEnabled(), + Mode: category.GetSystemPromptMode(), + }) + } + + response := SystemPromptsResponse{ + SystemPrompts: systemPrompts, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } +} diff --git a/src/semantic-router/pkg/api/server_test.go b/src/semantic-router/pkg/api/server_test.go index 4e5c18f1..450b3d20 100644 --- a/src/semantic-router/pkg/api/server_test.go +++ b/src/semantic-router/pkg/api/server_test.go @@ -305,3 +305,429 @@ func TestOpenAIModelsEndpoint(t *testing.T) { t.Errorf("expected configured models to be present, got=%v", got) } } + +// TestSystemPromptEndpointSecurity tests that system prompt endpoints are only accessible when explicitly enabled +func TestSystemPromptEndpointSecurity(t *testing.T) { + // Create test configuration with categories that have system prompts + cfg := &config.RouterConfig{ + Categories: []config.Category{ + { + Name: "math", + SystemPrompt: "You are a math expert.", + SystemPromptEnabled: &[]bool{true}[0], // Pointer to true + SystemPromptMode: "replace", + }, + { + Name: "coding", + SystemPrompt: "You are a coding assistant.", + SystemPromptEnabled: &[]bool{false}[0], // Pointer to false + SystemPromptMode: "insert", + }, + }, + } + + tests := []struct { + name string + enableSystemPromptAPI bool + method string + path string + requestBody string + expectedStatus int + description string + }{ + { + name: "GET system prompts - disabled API", + enableSystemPromptAPI: false, + method: "GET", + path: "/config/system-prompts", + expectedStatus: http.StatusNotFound, + description: "Should return 404 when system prompt API is disabled", + }, + { + name: "PUT system prompts - disabled API", + enableSystemPromptAPI: false, + method: "PUT", + path: "/config/system-prompts", + requestBody: `{"enabled": true}`, + expectedStatus: http.StatusNotFound, + description: "Should return 404 when system prompt API is disabled", + }, + { + name: "GET system prompts - enabled API", + enableSystemPromptAPI: true, + method: "GET", + path: "/config/system-prompts", + expectedStatus: http.StatusOK, + description: "Should return 200 when system prompt API is enabled", + }, + { + name: "PUT system prompts - enabled API - valid request", + enableSystemPromptAPI: true, + method: "PUT", + path: "/config/system-prompts", + requestBody: `{"category": "math", "enabled": false}`, + expectedStatus: http.StatusOK, + description: "Should return 200 for valid PUT request when API is enabled", + }, + { + name: "PUT system prompts - enabled API - invalid request", + enableSystemPromptAPI: true, + method: "PUT", + path: "/config/system-prompts", + requestBody: `{"category": "nonexistent"}`, + expectedStatus: http.StatusBadRequest, + description: "Should return 400 for invalid PUT request", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test server that simulates the behavior + var mux *http.ServeMux + if tt.enableSystemPromptAPI { + // Simulate enabled API - create a server that has the endpoints + mux = http.NewServeMux() + mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("GET /config/classification", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("PUT /config/classification", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + // Add system prompt endpoints when enabled + mux.HandleFunc("GET /config/system-prompts", func(w http.ResponseWriter, r *http.Request) { + // Create a test server instance with config for the handler + testServerWithConfig := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: cfg, + enableSystemPromptAPI: true, + } + testServerWithConfig.handleGetSystemPrompts(w, r) + }) + mux.HandleFunc("PUT /config/system-prompts", func(w http.ResponseWriter, r *http.Request) { + // Create a test server instance with config for the handler + testServerWithConfig := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: cfg, + enableSystemPromptAPI: true, + } + testServerWithConfig.handleUpdateSystemPrompts(w, r) + }) + } else { + // Simulate disabled API - create a server without the endpoints + mux = http.NewServeMux() + mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("GET /config/classification", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("PUT /config/classification", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + // System prompt endpoints are NOT registered when disabled + } + + // Create request + var req *http.Request + if tt.requestBody != "" { + req = httptest.NewRequest(tt.method, tt.path, bytes.NewBufferString(tt.requestBody)) + req.Header.Set("Content-Type", "application/json") + } else { + req = httptest.NewRequest(tt.method, tt.path, nil) + } + + rr := httptest.NewRecorder() + + // Serve the request + mux.ServeHTTP(rr, req) + + // Check status code + if rr.Code != tt.expectedStatus { + t.Errorf("%s: expected status %d, got %d. Response: %s", + tt.description, tt.expectedStatus, rr.Code, rr.Body.String()) + } + + // Additional checks for specific cases + if tt.enableSystemPromptAPI && tt.method == "GET" && tt.expectedStatus == http.StatusOK { + // Verify the response structure for GET requests + var response SystemPromptsResponse + if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { + t.Errorf("Failed to unmarshal GET response: %v", err) + } + + // Should have system prompts from config + if len(response.SystemPrompts) != 2 { + t.Errorf("Expected 2 system prompts, got %d", len(response.SystemPrompts)) + } + + // Verify the content + foundMath := false + foundCoding := false + for _, sp := range response.SystemPrompts { + if sp.Category == "math" { + foundMath = true + if sp.Prompt != "You are a math expert." { + t.Errorf("Expected math prompt 'You are a math expert.', got '%s'", sp.Prompt) + } + if !sp.Enabled { + t.Errorf("Expected math category to be enabled") + } + if sp.Mode != "replace" { + t.Errorf("Expected math mode 'replace', got '%s'", sp.Mode) + } + } + if sp.Category == "coding" { + foundCoding = true + if sp.Enabled { + t.Errorf("Expected coding category to be disabled") + } + if sp.Mode != "insert" { + t.Errorf("Expected coding mode 'insert', got '%s'", sp.Mode) + } + } + } + + if !foundMath || !foundCoding { + t.Errorf("Expected to find both math and coding categories") + } + } + }) + } +} + +// TestSystemPromptEndpointFunctionality tests the actual functionality of system prompt endpoints +func TestSystemPromptEndpointFunctionality(t *testing.T) { + // Create test configuration + cfg := &config.RouterConfig{ + Categories: []config.Category{ + { + Name: "math", + SystemPrompt: "You are a math expert.", + SystemPromptEnabled: &[]bool{true}[0], + SystemPromptMode: "replace", + }, + { + Name: "no-prompt", + SystemPrompt: "", // No system prompt + }, + }, + } + + // Create a test server with the config for functionality testing + apiServer := &ClassificationAPIServer{ + classificationSvc: services.NewPlaceholderClassificationService(), + config: cfg, + enableSystemPromptAPI: true, // Enable for functionality testing + } + + t.Run("GET system prompts returns correct data", func(t *testing.T) { + req := httptest.NewRequest("GET", "/config/system-prompts", nil) + rr := httptest.NewRecorder() + + apiServer.handleGetSystemPrompts(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", rr.Code) + } + + var response SystemPromptsResponse + if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if len(response.SystemPrompts) != 2 { + t.Errorf("Expected 2 categories, got %d", len(response.SystemPrompts)) + } + }) + + t.Run("PUT system prompts - enable specific category", func(t *testing.T) { + requestBody := `{"category": "math", "enabled": false}` + req := httptest.NewRequest("PUT", "/config/system-prompts", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + apiServer.handleUpdateSystemPrompts(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d. Response: %s", rr.Code, rr.Body.String()) + } + + var response SystemPromptsResponse + if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Find the math category and verify it's disabled + for _, sp := range response.SystemPrompts { + if sp.Category == "math" && sp.Enabled { + t.Errorf("Expected math category to be disabled after PUT request") + } + } + }) + + t.Run("PUT system prompts - change mode", func(t *testing.T) { + requestBody := `{"category": "math", "mode": "insert"}` + req := httptest.NewRequest("PUT", "/config/system-prompts", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + apiServer.handleUpdateSystemPrompts(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d. Response: %s", rr.Code, rr.Body.String()) + } + + var response SystemPromptsResponse + if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Find the math category and verify mode is changed + for _, sp := range response.SystemPrompts { + if sp.Category == "math" && sp.Mode != "insert" { + t.Errorf("Expected math category mode to be 'insert', got '%s'", sp.Mode) + } + } + }) + + t.Run("PUT system prompts - update all categories", func(t *testing.T) { + requestBody := `{"enabled": true}` // No category specified = update all + req := httptest.NewRequest("PUT", "/config/system-prompts", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + apiServer.handleUpdateSystemPrompts(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d. Response: %s", rr.Code, rr.Body.String()) + } + }) + + t.Run("PUT system prompts - invalid category", func(t *testing.T) { + requestBody := `{"category": "nonexistent", "enabled": true}` + req := httptest.NewRequest("PUT", "/config/system-prompts", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + apiServer.handleUpdateSystemPrompts(rr, req) + + if rr.Code != http.StatusNotFound { + t.Errorf("Expected 404 for nonexistent category, got %d", rr.Code) + } + }) + + t.Run("PUT system prompts - category without system prompt", func(t *testing.T) { + requestBody := `{"category": "no-prompt", "enabled": true}` + req := httptest.NewRequest("PUT", "/config/system-prompts", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + apiServer.handleUpdateSystemPrompts(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for category without system prompt, got %d", rr.Code) + } + }) + + t.Run("PUT system prompts - invalid mode", func(t *testing.T) { + requestBody := `{"category": "math", "mode": "invalid"}` + req := httptest.NewRequest("PUT", "/config/system-prompts", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + apiServer.handleUpdateSystemPrompts(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for invalid mode, got %d", rr.Code) + } + }) + + t.Run("PUT system prompts - empty request", func(t *testing.T) { + requestBody := `{}` + req := httptest.NewRequest("PUT", "/config/system-prompts", bytes.NewBufferString(requestBody)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + apiServer.handleUpdateSystemPrompts(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for empty request, got %d", rr.Code) + } + }) +} + +// TestSetupRoutesSecurityBehavior tests that setupRoutes correctly includes/excludes endpoints based on security flag +func TestSetupRoutesSecurityBehavior(t *testing.T) { + tests := []struct { + name string + enableSystemPromptAPI bool + expectedEndpoints map[string]bool // path -> should exist + }{ + { + name: "System prompt API disabled", + enableSystemPromptAPI: false, + expectedEndpoints: map[string]bool{ + "/health": true, + "/config/classification": true, + "/config/system-prompts": false, // Should NOT exist + }, + }, + { + name: "System prompt API enabled", + enableSystemPromptAPI: true, + expectedEndpoints: map[string]bool{ + "/health": true, + "/config/classification": true, + "/config/system-prompts": true, // Should exist + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test mux that simulates the setupRoutes behavior + mux := http.NewServeMux() + + // Always add basic endpoints + mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("GET /config/classification", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Conditionally add system prompt endpoints based on the flag + if tt.enableSystemPromptAPI { + mux.HandleFunc("GET /config/system-prompts", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("PUT /config/system-prompts", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + } + + // Test each endpoint + for path, shouldExist := range tt.expectedEndpoints { + req := httptest.NewRequest("GET", path, nil) + rr := httptest.NewRecorder() + + mux.ServeHTTP(rr, req) + + if shouldExist { + // Endpoint should exist (not 404) + if rr.Code == http.StatusNotFound { + t.Errorf("Expected endpoint %s to exist, but got 404", path) + } + } else { + // Endpoint should NOT exist (404) + if rr.Code != http.StatusNotFound { + t.Errorf("Expected endpoint %s to return 404, but got %d", path, rr.Code) + } + } + } + }) + } +} diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index 1f481af0..78edc546 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -276,6 +276,13 @@ type Category struct { MMLUCategories []string `yaml:"mmlu_categories,omitempty"` // SystemPrompt is an optional category-specific system prompt automatically injected into requests SystemPrompt string `yaml:"system_prompt,omitempty"` + // SystemPromptEnabled controls whether the system prompt should be injected for this category + // Defaults to true when SystemPrompt is not empty + SystemPromptEnabled *bool `yaml:"system_prompt_enabled,omitempty"` + // SystemPromptMode controls how the system prompt is injected: "replace" (default) or "insert" + // "replace": Replace any existing system message with the category-specific prompt + // "insert": Prepend the category-specific prompt to the existing system message content + SystemPromptMode string `yaml:"system_prompt_mode,omitempty"` } // Legacy types - can be removed once migration is complete @@ -411,6 +418,8 @@ func ReplaceGlobalConfig(newCfg *RouterConfig) { // GetConfig returns the current configuration func GetConfig() *RouterConfig { + configMu.RLock() + defer configMu.RUnlock() return config } @@ -671,3 +680,31 @@ func (c *RouterConfig) ValidateEndpoints() error { return nil } + +// IsSystemPromptEnabled returns whether system prompt injection is enabled for a category +func (c *Category) IsSystemPromptEnabled() bool { + // If SystemPromptEnabled is explicitly set, use that value + if c.SystemPromptEnabled != nil { + return *c.SystemPromptEnabled + } + // Default to true if SystemPrompt is not empty + return c.SystemPrompt != "" +} + +// GetSystemPromptMode returns the system prompt injection mode, defaulting to "replace" +func (c *Category) GetSystemPromptMode() string { + if c.SystemPromptMode == "" { + return "replace" // Default mode + } + return c.SystemPromptMode +} + +// GetCategoryByName returns a category by name +func (c *RouterConfig) GetCategoryByName(name string) *Category { + for i := range c.Categories { + if c.Categories[i].Name == name { + return &c.Categories[i] + } + } + return nil +} diff --git a/src/semantic-router/pkg/extproc/request_handler.go b/src/semantic-router/pkg/extproc/request_handler.go index 52b04e8e..46490ff5 100644 --- a/src/semantic-router/pkg/extproc/request_handler.go +++ b/src/semantic-router/pkg/extproc/request_handler.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc/status" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/cache" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/http" @@ -72,7 +73,7 @@ func serializeOpenAIRequestWithStream(req *openai.ChatCompletionNewParams, hasSt // addSystemPromptToRequestBody adds a system prompt to the beginning of the messages array in the JSON request body // Returns the modified body, whether the system prompt was actually injected, and any error -func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]byte, bool, error) { +func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string, mode string) ([]byte, bool, error) { if systemPrompt == "" { return requestBody, false, nil } @@ -94,32 +95,63 @@ func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]by return requestBody, false, nil // Messages is not an array, return original } - // Create a new system message - systemMessage := map[string]interface{}{ - "role": "system", - "content": systemPrompt, - } - // Check if there's already a system message at the beginning hasSystemMessage := false + var existingSystemContent string if len(messages) > 0 { if firstMsg, ok := messages[0].(map[string]interface{}); ok { if role, ok := firstMsg["role"].(string); ok && role == "system" { hasSystemMessage = true + if content, ok := firstMsg["content"].(string); ok { + existingSystemContent = content + } } } } + // Handle different injection modes + var finalSystemContent string + var logMessage string + + switch mode { + case "insert": + if hasSystemMessage { + // Insert mode: prepend category prompt to existing system message + finalSystemContent = systemPrompt + "\n\n" + existingSystemContent + logMessage = "Inserted category-specific system prompt before existing system message" + } else { + // No existing system message, just use the category prompt + finalSystemContent = systemPrompt + logMessage = "Added category-specific system prompt (insert mode, no existing system message)" + } + case "replace": + fallthrough + default: + // Replace mode: use only the category prompt + finalSystemContent = systemPrompt + if hasSystemMessage { + logMessage = "Replaced existing system message with category-specific system prompt" + } else { + logMessage = "Added category-specific system prompt to the beginning of messages" + } + } + + // Create the final system message + systemMessage := map[string]interface{}{ + "role": "system", + "content": finalSystemContent, + } + if hasSystemMessage { - // Replace the existing system message + // Update the existing system message messages[0] = systemMessage - observability.Infof("Replaced existing system message with category-specific system prompt") } else { // Prepend the system message to the beginning of the messages array messages = append([]interface{}{systemMessage}, messages...) - observability.Infof("Added category-specific system prompt to the beginning of messages") } + observability.Infof("%s (mode: %s)", logMessage, mode) + // Update the messages in the request map requestMap["messages"] = messages @@ -564,10 +596,23 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe // Add category-specific system prompt if configured if categoryName != "" { - category := r.Classifier.GetCategoryByName(categoryName) - if category != nil && category.SystemPrompt != "" { + // Try to get the most up-to-date category configuration from global config first + // This ensures API updates are reflected immediately + globalConfig := config.GetConfig() + var category *config.Category + if globalConfig != nil { + category = globalConfig.GetCategoryByName(categoryName) + } + + // If not found in global config, fall back to router's config (for tests and initial setup) + if category == nil { + category = r.Classifier.GetCategoryByName(categoryName) + } + + if category != nil && category.SystemPrompt != "" && category.IsSystemPromptEnabled() { + mode := category.GetSystemPromptMode() var injected bool - modifiedBody, injected, err = addSystemPromptToRequestBody(modifiedBody, category.SystemPrompt) + modifiedBody, injected, err = addSystemPromptToRequestBody(modifiedBody, category.SystemPrompt, mode) if err != nil { observability.Errorf("Error adding system prompt to request: %v", err) metrics.RecordRequestError(actualModel, "serialization_error") @@ -575,8 +620,13 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe } if injected { ctx.VSRInjectedSystemPrompt = true - observability.Infof("Added category-specific system prompt for category: %s", categoryName) + observability.Infof("Added category-specific system prompt for category: %s (mode: %s)", categoryName, mode) } + + // Log metadata about system prompt injection (avoid logging sensitive user data) + observability.Infof("System prompt injection completed for category: %s, body size: %d bytes", categoryName, len(modifiedBody)) + } else if category != nil && category.SystemPrompt != "" && !category.IsSystemPromptEnabled() { + observability.Infof("System prompt disabled for category: %s", categoryName) } } diff --git a/src/semantic-router/pkg/services/classification.go b/src/semantic-router/pkg/services/classification.go index 1240e1e5..f58406b0 100644 --- a/src/semantic-router/pkg/services/classification.go +++ b/src/semantic-router/pkg/services/classification.go @@ -3,6 +3,7 @@ package services import ( "fmt" "os" + "sync" "time" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" @@ -18,6 +19,7 @@ type ClassificationService struct { classifier *classification.Classifier unifiedClassifier *classification.UnifiedClassifier // New unified classifier config *config.RouterConfig + configMutex sync.RWMutex // Protects config access } // NewClassificationService creates a new classification service @@ -485,3 +487,19 @@ func (s *ClassificationService) GetUnifiedClassifierStats() map[string]interface stats["available"] = true return stats } + +// GetConfig returns the current configuration +func (s *ClassificationService) GetConfig() *config.RouterConfig { + s.configMutex.RLock() + defer s.configMutex.RUnlock() + return s.config +} + +// UpdateConfig updates the configuration +func (s *ClassificationService) UpdateConfig(newConfig *config.RouterConfig) { + s.configMutex.Lock() + defer s.configMutex.Unlock() + s.config = newConfig + // Update the global config as well + config.ReplaceGlobalConfig(newConfig) +} diff --git a/tools/make/build-run-test.mk b/tools/make/build-run-test.mk index 67ccb4fa..b85490d9 100644 --- a/tools/make/build-run-test.mk +++ b/tools/make/build-run-test.mk @@ -16,7 +16,7 @@ build-router: rust run-router: build-router download-models @echo "Running router with config: ${CONFIG_FILE}" @export LD_LIBRARY_PATH=${PWD}/candle-binding/target/release && \ - ./bin/router -config=${CONFIG_FILE} + ./bin/router -config=${CONFIG_FILE} --enable-system-prompt-api=true # Run the router with e2e config for testing run-router-e2e: build-router download-models diff --git a/website/docs/overview/categories/configuration.md b/website/docs/overview/categories/configuration.md index 040a01c7..7bc776d0 100644 --- a/website/docs/overview/categories/configuration.md +++ b/website/docs/overview/categories/configuration.md @@ -55,6 +55,7 @@ categories: - **Type**: String - **Description**: Category-specific system prompt automatically injected into requests - **Behavior**: Replaces existing system messages or adds new one at the beginning +- **Runtime Control**: Can be enabled/disabled via API when `--enable-system-prompt-api` flag is used - **Example**: `"You are a mathematics expert. Provide step-by-step solutions."` ```yaml @@ -63,6 +64,23 @@ categories: system_prompt: "You are a mathematics expert. Provide step-by-step solutions, show your work clearly, and explain mathematical concepts in an understandable way." ``` +**Runtime Management**: System prompts can be dynamically controlled via REST API endpoints when the server is started with `--enable-system-prompt-api` flag: + +```bash +# Start server with system prompt API enabled +./semantic-router --enable-system-prompt-api + +# Toggle system prompt for specific category +curl -X PUT http://localhost:8080/config/system-prompts \ + -H "Content-Type: application/json" \ + -d '{"category": "math", "enabled": false}' + +# Set injection mode (replace/insert) +curl -X PUT http://localhost:8080/config/system-prompts \ + -H "Content-Type: application/json" \ + -d '{"category": "math", "mode": "insert"}' +``` + ### Reasoning Configuration #### `use_reasoning` (Required)