Skip to content

Commit b034e46

Browse files
committed
feat: add system prompt toggle endpoint
Signed-off-by: Huamin Chen <[email protected]>
1 parent 6cef4dd commit b034e46

File tree

4 files changed

+258
-14
lines changed

4 files changed

+258
-14
lines changed

src/semantic-router/pkg/api/server.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux {
203203
mux.HandleFunc("GET /config/classification", s.handleGetConfig)
204204
mux.HandleFunc("PUT /config/classification", s.handleUpdateConfig)
205205

206+
// System prompt configuration endpoints
207+
mux.HandleFunc("GET /config/system-prompts", s.handleGetSystemPrompts)
208+
mux.HandleFunc("PUT /config/system-prompts", s.handleUpdateSystemPrompts)
209+
206210
return mux
207211
}
208212

@@ -705,3 +709,151 @@ func (s *ClassificationAPIServer) calculateUnifiedStatistics(unifiedResults *ser
705709
LowConfidenceCount: lowConfidenceCount,
706710
}
707711
}
712+
713+
// SystemPromptInfo describes the system prompt settings of a single category
// as exposed by the system prompt configuration endpoints.
type SystemPromptInfo struct {
	// Category is the name of the category.
	Category string `json:"category"`
	// Prompt is the system prompt text configured for the category.
	Prompt string `json:"prompt"`
	// Enabled reports whether system prompt injection is enabled.
	Enabled bool `json:"enabled"`
	// Mode is the injection mode: "replace" or "insert".
	Mode string `json:"mode"`
}
720+
721+
// SystemPromptsResponse represents the response for GET /config/system-prompts
722+
type SystemPromptsResponse struct {
723+
SystemPrompts []SystemPromptInfo `json:"system_prompts"`
724+
}
725+
726+
// SystemPromptUpdateRequest is the JSON body accepted by
// PUT /config/system-prompts.
type SystemPromptUpdateRequest struct {
	// Category selects the category to update. If empty, the update applies
	// to all categories.
	Category string `json:"category,omitempty"`
	// Enabled turns injection on (true) or off (false). A nil pointer means
	// the field was not supplied.
	Enabled *bool `json:"enabled,omitempty"`
	// Mode is the injection mode to set: "replace" or "insert".
	Mode string `json:"mode,omitempty"`
}
732+
733+
// handleGetSystemPrompts handles GET /config/system-prompts
734+
func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, r *http.Request) {
735+
cfg := s.classificationSvc.GetConfig()
736+
if cfg == nil {
737+
http.Error(w, "Configuration not available", http.StatusInternalServerError)
738+
return
739+
}
740+
741+
var systemPrompts []SystemPromptInfo
742+
for _, category := range cfg.Categories {
743+
systemPrompts = append(systemPrompts, SystemPromptInfo{
744+
Category: category.Name,
745+
Prompt: category.SystemPrompt,
746+
Enabled: category.IsSystemPromptEnabled(),
747+
Mode: category.GetSystemPromptMode(),
748+
})
749+
}
750+
751+
response := SystemPromptsResponse{
752+
SystemPrompts: systemPrompts,
753+
}
754+
755+
w.Header().Set("Content-Type", "application/json")
756+
if err := json.NewEncoder(w).Encode(response); err != nil {
757+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
758+
return
759+
}
760+
}
761+
762+
// handleUpdateSystemPrompts handles PUT /config/system-prompts
763+
func (s *ClassificationAPIServer) handleUpdateSystemPrompts(w http.ResponseWriter, r *http.Request) {
764+
var req SystemPromptUpdateRequest
765+
if err := s.parseJSONRequest(r, &req); err != nil {
766+
http.Error(w, err.Error(), http.StatusBadRequest)
767+
return
768+
}
769+
770+
if req.Enabled == nil && req.Mode == "" {
771+
http.Error(w, "either enabled or mode field is required", http.StatusBadRequest)
772+
return
773+
}
774+
775+
// Validate mode if provided
776+
if req.Mode != "" && req.Mode != "replace" && req.Mode != "insert" {
777+
http.Error(w, "mode must be either 'replace' or 'insert'", http.StatusBadRequest)
778+
return
779+
}
780+
781+
cfg := s.classificationSvc.GetConfig()
782+
if cfg == nil {
783+
http.Error(w, "Configuration not available", http.StatusInternalServerError)
784+
return
785+
}
786+
787+
// Create a copy of the config to modify
788+
newCfg := *cfg
789+
newCategories := make([]config.Category, len(cfg.Categories))
790+
copy(newCategories, cfg.Categories)
791+
newCfg.Categories = newCategories
792+
793+
updated := false
794+
if req.Category == "" {
795+
// Update all categories
796+
for i := range newCfg.Categories {
797+
if newCfg.Categories[i].SystemPrompt != "" {
798+
if req.Enabled != nil {
799+
newCfg.Categories[i].SystemPromptEnabled = req.Enabled
800+
}
801+
if req.Mode != "" {
802+
newCfg.Categories[i].SystemPromptMode = req.Mode
803+
}
804+
updated = true
805+
}
806+
}
807+
} else {
808+
// Update specific category
809+
for i := range newCfg.Categories {
810+
if newCfg.Categories[i].Name == req.Category {
811+
if newCfg.Categories[i].SystemPrompt == "" {
812+
http.Error(w, fmt.Sprintf("Category '%s' has no system prompt configured", req.Category), http.StatusBadRequest)
813+
return
814+
}
815+
if req.Enabled != nil {
816+
newCfg.Categories[i].SystemPromptEnabled = req.Enabled
817+
}
818+
if req.Mode != "" {
819+
newCfg.Categories[i].SystemPromptMode = req.Mode
820+
}
821+
updated = true
822+
break
823+
}
824+
}
825+
if !updated {
826+
http.Error(w, fmt.Sprintf("Category '%s' not found", req.Category), http.StatusNotFound)
827+
return
828+
}
829+
}
830+
831+
if !updated {
832+
http.Error(w, "No categories with system prompts found to update", http.StatusBadRequest)
833+
return
834+
}
835+
836+
// Update the configuration
837+
s.classificationSvc.UpdateConfig(&newCfg)
838+
839+
// Return the updated system prompts
840+
var systemPrompts []SystemPromptInfo
841+
for _, category := range newCfg.Categories {
842+
systemPrompts = append(systemPrompts, SystemPromptInfo{
843+
Category: category.Name,
844+
Prompt: category.SystemPrompt,
845+
Enabled: category.IsSystemPromptEnabled(),
846+
Mode: category.GetSystemPromptMode(),
847+
})
848+
}
849+
850+
response := SystemPromptsResponse{
851+
SystemPrompts: systemPrompts,
852+
}
853+
854+
w.Header().Set("Content-Type", "application/json")
855+
if err := json.NewEncoder(w).Encode(response); err != nil {
856+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
857+
return
858+
}
859+
}

src/semantic-router/pkg/config/config.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,13 @@ type Category struct {
276276
MMLUCategories []string `yaml:"mmlu_categories,omitempty"`
277277
// SystemPrompt is an optional category-specific system prompt automatically injected into requests
278278
SystemPrompt string `yaml:"system_prompt,omitempty"`
279+
// SystemPromptEnabled controls whether the system prompt should be injected for this category
280+
// Defaults to true when SystemPrompt is not empty
281+
SystemPromptEnabled *bool `yaml:"system_prompt_enabled,omitempty"`
282+
// SystemPromptMode controls how the system prompt is injected: "replace" (default) or "insert"
283+
// "replace": Replace any existing system message with the category-specific prompt
284+
// "insert": Prepend the category-specific prompt to the existing system message content
285+
SystemPromptMode string `yaml:"system_prompt_mode,omitempty"`
279286
}
280287

281288
// Legacy types - can be removed once migration is complete
@@ -671,3 +678,31 @@ func (c *RouterConfig) ValidateEndpoints() error {
671678

672679
return nil
673680
}
681+
682+
// IsSystemPromptEnabled returns whether system prompt injection is enabled for a category
683+
func (c *Category) IsSystemPromptEnabled() bool {
684+
// If SystemPromptEnabled is explicitly set, use that value
685+
if c.SystemPromptEnabled != nil {
686+
return *c.SystemPromptEnabled
687+
}
688+
// Default to true if SystemPrompt is not empty
689+
return c.SystemPrompt != ""
690+
}
691+
692+
// GetSystemPromptMode returns the system prompt injection mode, defaulting to "replace"
693+
func (c *Category) GetSystemPromptMode() string {
694+
if c.SystemPromptMode == "" {
695+
return "replace" // Default mode
696+
}
697+
return c.SystemPromptMode
698+
}
699+
700+
// GetCategoryByName returns a category by name
701+
func (c *RouterConfig) GetCategoryByName(name string) *Category {
702+
for i := range c.Categories {
703+
if c.Categories[i].Name == name {
704+
return &c.Categories[i]
705+
}
706+
}
707+
return nil
708+
}

src/semantic-router/pkg/extproc/request_handler.go

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"google.golang.org/grpc/status"
1414

1515
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/cache"
16+
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
1617
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics"
1718
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability"
1819
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/http"
@@ -72,7 +73,7 @@ func serializeOpenAIRequestWithStream(req *openai.ChatCompletionNewParams, hasSt
7273

7374
// addSystemPromptToRequestBody adds a system prompt to the beginning of the messages array in the JSON request body
7475
// Returns the modified body, whether the system prompt was actually injected, and any error
75-
func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]byte, bool, error) {
76+
func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string, mode string) ([]byte, bool, error) {
7677
if systemPrompt == "" {
7778
return requestBody, false, nil
7879
}
@@ -94,32 +95,63 @@ func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]by
9495
return requestBody, false, nil // Messages is not an array, return original
9596
}
9697

97-
// Create a new system message
98-
systemMessage := map[string]interface{}{
99-
"role": "system",
100-
"content": systemPrompt,
101-
}
102-
10398
// Check if there's already a system message at the beginning
10499
hasSystemMessage := false
100+
var existingSystemContent string
105101
if len(messages) > 0 {
106102
if firstMsg, ok := messages[0].(map[string]interface{}); ok {
107103
if role, ok := firstMsg["role"].(string); ok && role == "system" {
108104
hasSystemMessage = true
105+
if content, ok := firstMsg["content"].(string); ok {
106+
existingSystemContent = content
107+
}
109108
}
110109
}
111110
}
112111

112+
// Handle different injection modes
113+
var finalSystemContent string
114+
var logMessage string
115+
116+
switch mode {
117+
case "insert":
118+
if hasSystemMessage {
119+
// Insert mode: prepend category prompt to existing system message
120+
finalSystemContent = systemPrompt + "\n\n" + existingSystemContent
121+
logMessage = "Inserted category-specific system prompt before existing system message"
122+
} else {
123+
// No existing system message, just use the category prompt
124+
finalSystemContent = systemPrompt
125+
logMessage = "Added category-specific system prompt (insert mode, no existing system message)"
126+
}
127+
case "replace":
128+
fallthrough
129+
default:
130+
// Replace mode: use only the category prompt
131+
finalSystemContent = systemPrompt
132+
if hasSystemMessage {
133+
logMessage = "Replaced existing system message with category-specific system prompt"
134+
} else {
135+
logMessage = "Added category-specific system prompt to the beginning of messages"
136+
}
137+
}
138+
139+
// Create the final system message
140+
systemMessage := map[string]interface{}{
141+
"role": "system",
142+
"content": finalSystemContent,
143+
}
144+
113145
if hasSystemMessage {
114-
// Replace the existing system message
146+
// Update the existing system message
115147
messages[0] = systemMessage
116-
observability.Infof("Replaced existing system message with category-specific system prompt")
117148
} else {
118149
// Prepend the system message to the beginning of the messages array
119150
messages = append([]interface{}{systemMessage}, messages...)
120-
observability.Infof("Added category-specific system prompt to the beginning of messages")
121151
}
122152

153+
observability.Infof("%s (mode: %s)", logMessage, mode)
154+
123155
// Update the messages in the request map
124156
requestMap["messages"] = messages
125157

@@ -564,19 +596,32 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
564596

565597
// Add category-specific system prompt if configured
566598
if categoryName != "" {
567-
category := r.Classifier.GetCategoryByName(categoryName)
568-
if category != nil && category.SystemPrompt != "" {
599+
// Use global config to get the most up-to-date category configuration
600+
// This ensures API updates are reflected immediately
601+
globalConfig := config.GetConfig()
602+
var category *config.Category
603+
if globalConfig != nil {
604+
category = globalConfig.GetCategoryByName(categoryName)
605+
}
606+
607+
if category != nil && category.SystemPrompt != "" && category.IsSystemPromptEnabled() {
608+
mode := category.GetSystemPromptMode()
569609
var injected bool
570-
modifiedBody, injected, err = addSystemPromptToRequestBody(modifiedBody, category.SystemPrompt)
610+
modifiedBody, injected, err = addSystemPromptToRequestBody(modifiedBody, category.SystemPrompt, mode)
571611
if err != nil {
572612
observability.Errorf("Error adding system prompt to request: %v", err)
573613
metrics.RecordRequestError(actualModel, "serialization_error")
574614
return nil, status.Errorf(codes.Internal, "error adding system prompt: %v", err)
575615
}
576616
if injected {
577617
ctx.VSRInjectedSystemPrompt = true
578-
observability.Infof("Added category-specific system prompt for category: %s", categoryName)
618+
observability.Infof("Added category-specific system prompt for category: %s (mode: %s)", categoryName, mode)
579619
}
620+
621+
// Log the complete message structure after system prompt injection
622+
observability.Infof("Complete request body after system prompt injection: %s", string(modifiedBody))
623+
} else if category != nil && category.SystemPrompt != "" && !category.IsSystemPromptEnabled() {
624+
observability.Infof("System prompt disabled for category: %s", categoryName)
580625
}
581626
}
582627

src/semantic-router/pkg/services/classification.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,15 @@ func (s *ClassificationService) GetUnifiedClassifierStats() map[string]interface
485485
stats["available"] = true
486486
return stats
487487
}
488+
489+
// GetConfig returns the current configuration
490+
func (s *ClassificationService) GetConfig() *config.RouterConfig {
491+
return s.config
492+
}
493+
494+
// UpdateConfig updates the configuration
495+
func (s *ClassificationService) UpdateConfig(newConfig *config.RouterConfig) {
496+
s.config = newConfig
497+
// Update the global config as well
498+
config.ReplaceGlobalConfig(newConfig)
499+
}

0 commit comments

Comments
 (0)