Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/semantic-router/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ import (
func main() {
// Parse command-line flags
var (
configPath = flag.String("config", "config/config.yaml", "Path to the configuration file")
port = flag.Int("port", 50051, "Port to listen on for gRPC ExtProc")
apiPort = flag.Int("api-port", 8080, "Port to listen on for Classification API")
metricsPort = flag.Int("metrics-port", 9190, "Port for Prometheus metrics")
enableAPI = flag.Bool("enable-api", true, "Enable Classification API server")
secure = flag.Bool("secure", false, "Enable secure gRPC server with TLS")
certPath = flag.String("cert-path", "", "Path to TLS certificate directory (containing tls.crt and tls.key)")
configPath = flag.String("config", "config/config.yaml", "Path to the configuration file")
port = flag.Int("port", 50051, "Port to listen on for gRPC ExtProc")
apiPort = flag.Int("api-port", 8080, "Port to listen on for Classification API")
metricsPort = flag.Int("metrics-port", 9190, "Port for Prometheus metrics")
enableAPI = flag.Bool("enable-api", true, "Enable Classification API server")
enableSystemPromptAPI = flag.Bool("enable-system-prompt-api", false, "Enable system prompt configuration endpoints (SECURITY: only enable in trusted environments)")
secure = flag.Bool("secure", false, "Enable secure gRPC server with TLS")
certPath = flag.String("cert-path", "", "Path to TLS certificate directory (containing tls.crt and tls.key)")
)
flag.Parse()

Expand Down Expand Up @@ -58,7 +59,7 @@ func main() {
if *enableAPI {
go func() {
observability.Infof("Starting Classification API server on port %d", *apiPort)
if err := api.StartClassificationAPI(*configPath, *apiPort); err != nil {
if err := api.StartClassificationAPI(*configPath, *apiPort, *enableSystemPromptAPI); err != nil {
observability.Errorf("Classification API server error: %v", err)
}
}()
Expand Down
170 changes: 165 additions & 5 deletions src/semantic-router/pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import (

// ClassificationAPIServer holds the server state and dependencies
type ClassificationAPIServer struct {
classificationSvc *services.ClassificationService
config *config.RouterConfig
classificationSvc *services.ClassificationService
config *config.RouterConfig
enableSystemPromptAPI bool
}

// ModelsInfoResponse represents the response for models info endpoint
Expand Down Expand Up @@ -101,7 +102,7 @@ type ClassificationOptions struct {
}

// StartClassificationAPI starts the Classification API server
func StartClassificationAPI(configPath string, port int) error {
func StartClassificationAPI(configPath string, port int, enableSystemPromptAPI bool) error {
// Load configuration
cfg, err := config.LoadConfig(configPath)
if err != nil {
Expand Down Expand Up @@ -139,8 +140,9 @@ func StartClassificationAPI(configPath string, port int) error {

// Create server instance
apiServer := &ClassificationAPIServer{
classificationSvc: classificationSvc,
config: cfg,
classificationSvc: classificationSvc,
config: cfg,
enableSystemPromptAPI: enableSystemPromptAPI,
}

// Create HTTP server with routes
Expand Down Expand Up @@ -203,6 +205,15 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux {
mux.HandleFunc("GET /config/classification", s.handleGetConfig)
mux.HandleFunc("PUT /config/classification", s.handleUpdateConfig)

// System prompt configuration endpoints (only if explicitly enabled)
if s.enableSystemPromptAPI {
observability.Infof("System prompt configuration endpoints enabled")
mux.HandleFunc("GET /config/system-prompts", s.handleGetSystemPrompts)
mux.HandleFunc("PUT /config/system-prompts", s.handleUpdateSystemPrompts)
} else {
observability.Infof("System prompt configuration endpoints disabled for security")
}

return mux
}

Expand Down Expand Up @@ -705,3 +716,152 @@ func (s *ClassificationAPIServer) calculateUnifiedStatistics(unifiedResults *ser
LowConfidenceCount: lowConfidenceCount,
}
}

// SystemPromptInfo represents system prompt information for a category
type SystemPromptInfo struct {
Category string `json:"category"`
Prompt string `json:"prompt"`
Enabled bool `json:"enabled"`
Mode string `json:"mode"` // "replace" or "insert"
}

// SystemPromptsResponse represents the response for GET /config/system-prompts
type SystemPromptsResponse struct {
SystemPrompts []SystemPromptInfo `json:"system_prompts"`
}

// SystemPromptUpdateRequest represents a request to update system prompt settings
type SystemPromptUpdateRequest struct {
Category string `json:"category,omitempty"` // If empty, applies to all categories
Enabled *bool `json:"enabled,omitempty"` // true to enable, false to disable
Mode string `json:"mode,omitempty"` // "replace" or "insert"
}

// handleGetSystemPrompts handles GET /config/system-prompts
func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, r *http.Request) {
cfg := s.config
if cfg == nil {
http.Error(w, "Configuration not available", http.StatusInternalServerError)
return
}

var systemPrompts []SystemPromptInfo
for _, category := range cfg.Categories {
systemPrompts = append(systemPrompts, SystemPromptInfo{
Category: category.Name,
Prompt: category.SystemPrompt,
Enabled: category.IsSystemPromptEnabled(),
Mode: category.GetSystemPromptMode(),
})
}

response := SystemPromptsResponse{
SystemPrompts: systemPrompts,
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
}

// handleUpdateSystemPrompts handles PUT /config/system-prompts
func (s *ClassificationAPIServer) handleUpdateSystemPrompts(w http.ResponseWriter, r *http.Request) {
var req SystemPromptUpdateRequest
if err := s.parseJSONRequest(r, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

if req.Enabled == nil && req.Mode == "" {
http.Error(w, "either enabled or mode field is required", http.StatusBadRequest)
return
}

// Validate mode if provided
if req.Mode != "" && req.Mode != "replace" && req.Mode != "insert" {
http.Error(w, "mode must be either 'replace' or 'insert'", http.StatusBadRequest)
return
}

cfg := s.config
if cfg == nil {
http.Error(w, "Configuration not available", http.StatusInternalServerError)
return
}

// Create a copy of the config to modify
newCfg := *cfg
newCategories := make([]config.Category, len(cfg.Categories))
copy(newCategories, cfg.Categories)
newCfg.Categories = newCategories

updated := false
if req.Category == "" {
// Update all categories
for i := range newCfg.Categories {
if newCfg.Categories[i].SystemPrompt != "" {
if req.Enabled != nil {
newCfg.Categories[i].SystemPromptEnabled = req.Enabled
}
if req.Mode != "" {
newCfg.Categories[i].SystemPromptMode = req.Mode
}
updated = true
}
}
} else {
// Update specific category
for i := range newCfg.Categories {
if newCfg.Categories[i].Name == req.Category {
if newCfg.Categories[i].SystemPrompt == "" {
http.Error(w, fmt.Sprintf("Category '%s' has no system prompt configured", req.Category), http.StatusBadRequest)
return
}
if req.Enabled != nil {
newCfg.Categories[i].SystemPromptEnabled = req.Enabled
}
if req.Mode != "" {
newCfg.Categories[i].SystemPromptMode = req.Mode
}
updated = true
break
}
}
if !updated {
http.Error(w, fmt.Sprintf("Category '%s' not found", req.Category), http.StatusNotFound)
return
}
}

if !updated {
http.Error(w, "No categories with system prompts found to update", http.StatusBadRequest)
return
}

// Update the configuration
s.config = &newCfg
s.classificationSvc.UpdateConfig(&newCfg)

// Return the updated system prompts
var systemPrompts []SystemPromptInfo
for _, category := range newCfg.Categories {
systemPrompts = append(systemPrompts, SystemPromptInfo{
Category: category.Name,
Prompt: category.SystemPrompt,
Enabled: category.IsSystemPromptEnabled(),
Mode: category.GetSystemPromptMode(),
})
}

response := SystemPromptsResponse{
SystemPrompts: systemPrompts,
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
}
Loading
Loading