Skip to content

Commit 175b74f

Browse files
authored
feat: support injection system prompt response header (#297)
Signed-off-by: bitliu <[email protected]>
1 parent 53bb830 commit 175b74f

File tree

3 files changed

+119
-22
lines changed

3 files changed

+119
-22
lines changed

src/semantic-router/pkg/extproc/request_handler.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,26 +71,27 @@ func serializeOpenAIRequestWithStream(req *openai.ChatCompletionNewParams, hasSt
7171
}
7272

7373
// addSystemPromptToRequestBody adds a system prompt to the beginning of the messages array in the JSON request body
74-
func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]byte, error) {
74+
// Returns the modified body, whether the system prompt was actually injected, and any error
75+
func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]byte, bool, error) {
7576
if systemPrompt == "" {
76-
return requestBody, nil
77+
return requestBody, false, nil
7778
}
7879

7980
// Parse the JSON request body
8081
var requestMap map[string]interface{}
8182
if err := json.Unmarshal(requestBody, &requestMap); err != nil {
82-
return nil, err
83+
return nil, false, err
8384
}
8485

8586
// Get the messages array
8687
messagesInterface, ok := requestMap["messages"]
8788
if !ok {
88-
return requestBody, nil // No messages array, return original
89+
return requestBody, false, nil // No messages array, return original
8990
}
9091

9192
messages, ok := messagesInterface.([]interface{})
9293
if !ok {
93-
return requestBody, nil // Messages is not an array, return original
94+
return requestBody, false, nil // Messages is not an array, return original
9495
}
9596

9697
// Create a new system message
@@ -123,7 +124,8 @@ func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]by
123124
requestMap["messages"] = messages
124125

125126
// Marshal back to JSON
126-
return json.Marshal(requestMap)
127+
modifiedBody, err := json.Marshal(requestMap)
128+
return modifiedBody, true, err
127129
}
128130

129131
// extractUserAndNonUserContent extracts content from request messages
@@ -211,10 +213,11 @@ type RequestContext struct {
211213
TTFTSeconds float64
212214

213215
// VSR decision tracking
214-
VSRSelectedCategory string // The category selected by VSR
215-
VSRReasoningMode string // "on" or "off" - whether reasoning mode was determined to be used
216-
VSRSelectedModel string // The model selected by VSR
217-
VSRCacheHit bool // Whether this request hit the cache
216+
VSRSelectedCategory string // The category selected by VSR
217+
VSRReasoningMode string // "on" or "off" - whether reasoning mode was determined to be used
218+
VSRSelectedModel string // The model selected by VSR
219+
VSRCacheHit bool // Whether this request hit the cache
220+
VSRInjectedSystemPrompt bool // Whether a system prompt was injected into the request
218221
}
219222

220223
// handleRequestHeaders processes the request headers
@@ -563,13 +566,17 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
563566
if categoryName != "" {
564567
category := r.Classifier.GetCategoryByName(categoryName)
565568
if category != nil && category.SystemPrompt != "" {
566-
modifiedBody, err = addSystemPromptToRequestBody(modifiedBody, category.SystemPrompt)
569+
var injected bool
570+
modifiedBody, injected, err = addSystemPromptToRequestBody(modifiedBody, category.SystemPrompt)
567571
if err != nil {
568572
observability.Errorf("Error adding system prompt to request: %v", err)
569573
metrics.RecordRequestError(actualModel, "serialization_error")
570574
return nil, status.Errorf(codes.Internal, "error adding system prompt: %v", err)
571575
}
572-
observability.Infof("Added category-specific system prompt for category: %s", categoryName)
576+
if injected {
577+
ctx.VSRInjectedSystemPrompt = true
578+
observability.Infof("Added category-specific system prompt for category: %s", categoryName)
579+
}
573580
}
574581
}
575582

src/semantic-router/pkg/extproc/response_handler.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,18 @@ func (r *OpenAIRouter) handleResponseHeaders(v *ext_proc.ProcessingRequest_Respo
8686
})
8787
}
8888

89+
// Add x-vsr-injected-system-prompt header
90+
injectedValue := "false"
91+
if ctx.VSRInjectedSystemPrompt {
92+
injectedValue = "true"
93+
}
94+
setHeaders = append(setHeaders, &core.HeaderValueOption{
95+
Header: &core.HeaderValue{
96+
Key: "x-vsr-injected-system-prompt",
97+
RawValue: []byte(injectedValue),
98+
},
99+
})
100+
89101
// Create header mutation if we have headers to add
90102
if len(setHeaders) > 0 {
91103
headerMutation = &ext_proc.HeaderMutation{

src/semantic-router/pkg/extproc/vsr_headers_test.go

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ func TestVSRHeadersAddedOnSuccessfulNonCachedResponse(t *testing.T) {
1414

1515
// Create request context with VSR decision information
1616
ctx := &RequestContext{
17-
VSRSelectedCategory: "math",
18-
VSRReasoningMode: "on",
19-
VSRSelectedModel: "deepseek-v31",
20-
VSRCacheHit: false, // Not a cache hit
17+
VSRSelectedCategory: "math",
18+
VSRReasoningMode: "on",
19+
VSRSelectedModel: "deepseek-v31",
20+
VSRCacheHit: false, // Not a cache hit
21+
VSRInjectedSystemPrompt: true, // System prompt was injected
2122
}
2223

2324
// Create response headers with successful status (200)
@@ -48,7 +49,7 @@ func TestVSRHeadersAddedOnSuccessfulNonCachedResponse(t *testing.T) {
4849
assert.NotNil(t, headerMutation, "HeaderMutation should not be nil for successful non-cached response")
4950

5051
setHeaders := headerMutation.GetSetHeaders()
51-
assert.Len(t, setHeaders, 3, "Should have 3 VSR headers")
52+
assert.Len(t, setHeaders, 4, "Should have 4 VSR headers")
5253

5354
// Verify each header
5455
headerMap := make(map[string]string)
@@ -59,6 +60,7 @@ func TestVSRHeadersAddedOnSuccessfulNonCachedResponse(t *testing.T) {
5960
assert.Equal(t, "math", headerMap["x-vsr-selected-category"])
6061
assert.Equal(t, "on", headerMap["x-vsr-selected-reasoning"])
6162
assert.Equal(t, "deepseek-v31", headerMap["x-vsr-selected-model"])
63+
assert.Equal(t, "true", headerMap["x-vsr-injected-system-prompt"])
6264
}
6365

6466
func TestVSRHeadersNotAddedOnCacheHit(t *testing.T) {
@@ -139,10 +141,11 @@ func TestVSRHeadersPartialInformation(t *testing.T) {
139141

140142
// Create request context with partial VSR information
141143
ctx := &RequestContext{
142-
VSRSelectedCategory: "math",
143-
VSRReasoningMode: "", // Empty reasoning mode
144-
VSRSelectedModel: "deepseek-v31",
145-
VSRCacheHit: false,
144+
VSRSelectedCategory: "math",
145+
VSRReasoningMode: "", // Empty reasoning mode
146+
VSRSelectedModel: "deepseek-v31",
147+
VSRCacheHit: false,
148+
VSRInjectedSystemPrompt: false, // No system prompt injected
146149
}
147150

148151
// Create response headers with successful status (200)
@@ -169,7 +172,7 @@ func TestVSRHeadersPartialInformation(t *testing.T) {
169172
assert.NotNil(t, headerMutation)
170173

171174
setHeaders := headerMutation.GetSetHeaders()
172-
assert.Len(t, setHeaders, 2, "Should have 2 VSR headers (excluding empty reasoning mode)")
175+
assert.Len(t, setHeaders, 3, "Should have 3 VSR headers (excluding empty reasoning mode, but including injected-system-prompt)")
173176

174177
// Verify each header
175178
headerMap := make(map[string]string)
@@ -179,5 +182,80 @@ func TestVSRHeadersPartialInformation(t *testing.T) {
179182

180183
assert.Equal(t, "math", headerMap["x-vsr-selected-category"])
181184
assert.Equal(t, "deepseek-v31", headerMap["x-vsr-selected-model"])
185+
assert.Equal(t, "false", headerMap["x-vsr-injected-system-prompt"])
182186
assert.NotContains(t, headerMap, "x-vsr-selected-reasoning", "Empty reasoning mode should not be added")
183187
}
188+
189+
func TestVSRInjectedSystemPromptHeader(t *testing.T) {
190+
router := &OpenAIRouter{}
191+
192+
// Test case 1: System prompt was injected
193+
t.Run("SystemPromptInjected", func(t *testing.T) {
194+
ctx := &RequestContext{
195+
VSRSelectedCategory: "coding",
196+
VSRReasoningMode: "on",
197+
VSRSelectedModel: "gpt-4",
198+
VSRCacheHit: false,
199+
VSRInjectedSystemPrompt: true,
200+
}
201+
202+
responseHeaders := &ext_proc.ProcessingRequest_ResponseHeaders{
203+
ResponseHeaders: &ext_proc.HttpHeaders{
204+
Headers: &core.HeaderMap{
205+
Headers: []*core.HeaderValue{
206+
{Key: ":status", Value: "200"},
207+
},
208+
},
209+
},
210+
}
211+
212+
response, err := router.handleResponseHeaders(responseHeaders, ctx)
213+
assert.NoError(t, err)
214+
assert.NotNil(t, response)
215+
216+
headerMutation := response.GetResponseHeaders().GetResponse().GetHeaderMutation()
217+
assert.NotNil(t, headerMutation)
218+
219+
headerMap := make(map[string]string)
220+
for _, header := range headerMutation.GetSetHeaders() {
221+
headerMap[header.Header.Key] = string(header.Header.RawValue)
222+
}
223+
224+
assert.Equal(t, "true", headerMap["x-vsr-injected-system-prompt"])
225+
})
226+
227+
// Test case 2: System prompt was not injected
228+
t.Run("SystemPromptNotInjected", func(t *testing.T) {
229+
ctx := &RequestContext{
230+
VSRSelectedCategory: "coding",
231+
VSRReasoningMode: "on",
232+
VSRSelectedModel: "gpt-4",
233+
VSRCacheHit: false,
234+
VSRInjectedSystemPrompt: false,
235+
}
236+
237+
responseHeaders := &ext_proc.ProcessingRequest_ResponseHeaders{
238+
ResponseHeaders: &ext_proc.HttpHeaders{
239+
Headers: &core.HeaderMap{
240+
Headers: []*core.HeaderValue{
241+
{Key: ":status", Value: "200"},
242+
},
243+
},
244+
},
245+
}
246+
247+
response, err := router.handleResponseHeaders(responseHeaders, ctx)
248+
assert.NoError(t, err)
249+
assert.NotNil(t, response)
250+
251+
headerMutation := response.GetResponseHeaders().GetResponse().GetHeaderMutation()
252+
assert.NotNil(t, headerMutation)
253+
254+
headerMap := make(map[string]string)
255+
for _, header := range headerMutation.GetSetHeaders() {
256+
headerMap[header.Header.Key] = string(header.Header.RawValue)
257+
}
258+
259+
assert.Equal(t, "false", headerMap["x-vsr-injected-system-prompt"])
260+
})
261+
}

0 commit comments

Comments
 (0)