Skip to content

Commit 7766fc2

Browse files
Copilot and rootfs committed
Fix code review issues: add header forwarding and improve safety
Co-authored-by: rootfs <[email protected]>
1 parent 6c64997 commit 7766fc2

File tree

4 files changed

+25
-9
lines changed

4 files changed

+25
-9
lines changed

e2e-tests/testcases/go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzM
5151
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
5252
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
5353
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
54-
github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw=
55-
github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
5654
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
5755
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
5856
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=

src/semantic-router/pkg/ensemble/factory.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"strings"
1011
"sync"
1112
"time"
1213

@@ -117,11 +118,11 @@ func (f *Factory) Execute(req *Request) *Response {
117118
}
118119
}
119120

120-
// Build metadata
121+
// Build metadata (only include successful responses)
121122
metadata.TotalLatencyMs = totalLatency
122123
metadata.ModelLatenciesMs = make(map[string]int64)
123124
metadata.ConfidenceScores = make(map[string]float64)
124-
for _, resp := range responses {
125+
for _, resp := range successfulResponses {
125126
metadata.ModelLatenciesMs[resp.ModelName] = resp.Latency.Milliseconds()
126127
if resp.Confidence > 0 {
127128
metadata.ConfidenceScores[resp.ModelName] = resp.Confidence
@@ -145,8 +146,12 @@ func (f *Factory) queryModels(req *Request) []ModelResponse {
145146
responses := make([]ModelResponse, len(req.Models))
146147
var wg sync.WaitGroup
147148

148-
// Limit concurrent requests
149-
semaphore := make(chan struct{}, f.config.MaxConcurrentRequests)
149+
// Limit concurrent requests (ensure at least 1)
150+
maxConcurrent := f.config.MaxConcurrentRequests
151+
if maxConcurrent <= 0 {
152+
maxConcurrent = 10 // Default to 10 if not set or invalid
153+
}
154+
semaphore := make(chan struct{}, maxConcurrent)
150155

151156
for i, modelName := range req.Models {
152157
wg.Add(1)
@@ -157,7 +162,7 @@ func (f *Factory) queryModels(req *Request) []ModelResponse {
157162
semaphore <- struct{}{}
158163
defer func() { <-semaphore }()
159164

160-
responses[idx] = f.queryModel(req.Context, model, req.OriginalRequest)
165+
responses[idx] = f.queryModel(req.Context, model, req.OriginalRequest, req.Headers)
161166
}(i, modelName)
162167
}
163168

@@ -166,7 +171,7 @@ func (f *Factory) queryModels(req *Request) []ModelResponse {
166171
}
167172

168173
// queryModel queries a single model endpoint
169-
func (f *Factory) queryModel(ctx context.Context, modelName string, requestBody []byte) ModelResponse {
174+
func (f *Factory) queryModel(ctx context.Context, modelName string, requestBody []byte, headers map[string]string) ModelResponse {
170175
startTime := time.Now()
171176

172177
endpoint, ok := f.endpoints[modelName]
@@ -200,6 +205,15 @@ func (f *Factory) queryModel(ctx context.Context, modelName string, requestBody
200205

201206
httpReq.Header.Set("Content-Type", "application/json")
202207

208+
// Forward authentication and other headers from original request
209+
for key, value := range headers {
210+
// Forward authorization and other important headers
211+
lowerKey := strings.ToLower(key)
212+
if lowerKey == "authorization" || lowerKey == "x-api-key" || strings.HasPrefix(lowerKey, "x-") {
213+
httpReq.Header.Set(key, value)
214+
}
215+
}
216+
203217
// Execute request
204218
resp, err := f.httpClient.Do(httpReq)
205219
if err != nil {

src/semantic-router/pkg/ensemble/types.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ type Request struct {
5757
// OriginalRequest is the original OpenAI API request body
5858
OriginalRequest []byte
5959

60+
// Headers contains HTTP headers to forward to model endpoints (e.g., Authorization)
61+
Headers map[string]string
62+
6063
// Context for cancellation and timeout
6164
Context context.Context
6265
}

src/semantic-router/pkg/extproc/req_filter_ensemble.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ func (r *OpenAIRouter) handleEnsembleRequest(ctx *RequestContext) (*ext_proc.Pro
2424

2525
logging.Infof("Processing ensemble request with %d models: %v", len(ctx.EnsembleModels), ctx.EnsembleModels)
2626

27-
// Build ensemble request
27+
// Build ensemble request with headers for authentication
2828
ensembleReq := &ensemble.Request{
2929
Models: ctx.EnsembleModels,
3030
Strategy: ensemble.Strategy(ctx.EnsembleStrategy),
3131
MinResponses: ctx.EnsembleMinResponses,
3232
OriginalRequest: ctx.OriginalRequestBody,
33+
Headers: ctx.Headers, // Forward original request headers for authentication
3334
Context: ctx.TraceContext,
3435
}
3536

0 commit comments

Comments
 (0)