Skip to content

Commit db4c780

Browse files
authored
feat(mcp): 8245 header forwarding (#2305)
1 parent 2b3e64e commit db4c780

File tree

2 files changed

+157
-24
lines changed

2 files changed

+157
-24
lines changed

router-tests/mcp_test.go

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ import (
55
"fmt"
66
"net/http"
77
"strings"
8+
"sync"
89
"testing"
910

1011
"github.com/mark3labs/mcp-go/mcp"
1112
"github.com/stretchr/testify/assert"
1213
"github.com/stretchr/testify/require"
1314
"github.com/wundergraph/cosmo/router-tests/testenv"
15+
"github.com/wundergraph/cosmo/router/core"
1416
"github.com/wundergraph/cosmo/router/pkg/config"
1517
)
1618

@@ -212,7 +214,6 @@ func TestMCP(t *testing.T) {
212214
t.Run("Execute Query", func(t *testing.T) {
213215
t.Run("Execute operation of type query with valid input", func(t *testing.T) {
214216
testenv.Run(t, &testenv.Config{
215-
EnableNats: true,
216217
MCP: config.MCPConfiguration{
217218
Enabled: true,
218219
},
@@ -553,4 +554,124 @@ func TestMCP(t *testing.T) {
553554
})
554555
})
555556
})
557+
558+
t.Run("Header Forwarding", func(t *testing.T) {
559+
t.Run("All request headers are forwarded from MCP client through to subgraphs", func(t *testing.T) {
560+
// This test validates that ALL headers sent by MCP clients are forwarded
561+
// through the complete chain: MCP Client -> MCP Server -> Router -> Subgraphs
562+
//
563+
// The router's header forwarding rules (configured with wildcard `.*`) determine
564+
// what gets propagated to subgraphs. The MCP server acts as a transparent proxy,
565+
// forwarding all headers without filtering.
566+
//
567+
// Note: We use direct HTTP POST requests instead of the mcp-go client library
568+
// because transport.WithHTTPHeaders() in mcp-go sets headers at the SSE connection
569+
// level, not on individual tool execution requests. Direct HTTP requests allow us
570+
// to test per-request headers, which is what real MCP clients (like Claude Desktop) send.
571+
572+
var capturedSubgraphRequest *http.Request
573+
var subgraphMutex sync.Mutex
574+
575+
testenv.Run(t, &testenv.Config{
576+
MCP: config.MCPConfiguration{
577+
Enabled: true,
578+
Session: config.MCPSessionConfig{
579+
Stateless: true, // Enable stateless mode so we don't need session IDs
580+
},
581+
},
582+
RouterOptions: []core.Option{
583+
// Forward all headers including custom ones
584+
core.WithHeaderRules(config.HeaderRules{
585+
All: &config.GlobalHeaderRule{
586+
Request: []*config.RequestHeaderRule{
587+
{
588+
Operation: config.HeaderRuleOperationPropagate,
589+
Matching: ".*", // Forward all headers
590+
},
591+
},
592+
},
593+
}),
594+
},
595+
Subgraphs: testenv.SubgraphsConfig{
596+
GlobalMiddleware: func(handler http.Handler) http.Handler {
597+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
598+
subgraphMutex.Lock()
599+
capturedSubgraphRequest = r.Clone(r.Context())
600+
subgraphMutex.Unlock()
601+
handler.ServeHTTP(w, r)
602+
})
603+
},
604+
},
605+
}, func(t *testing.T, xEnv *testenv.Environment) {
606+
// With stateless mode enabled, we can make direct HTTP POST requests
607+
// without needing to establish a session first
608+
mcpAddr := xEnv.GetMCPServerAddr()
609+
610+
// Make a direct HTTP POST request with custom headers
611+
// This simulates a real MCP client sending custom headers on tool calls
612+
mcpRequest := map[string]interface{}{
613+
"jsonrpc": "2.0",
614+
"id": 1,
615+
"method": "tools/call",
616+
"params": map[string]interface{}{
617+
"name": "execute_operation_my_employees",
618+
"arguments": map[string]interface{}{
619+
"criteria": map[string]interface{}{},
620+
},
621+
},
622+
}
623+
624+
requestBody, err := json.Marshal(mcpRequest)
625+
require.NoError(t, err)
626+
627+
req, err := http.NewRequest("POST", mcpAddr, strings.NewReader(string(requestBody)))
628+
require.NoError(t, err)
629+
630+
// Add various headers to test forwarding
631+
req.Header.Set("Content-Type", "application/json")
632+
req.Header.Set("foo", "bar") // Non-standard header
633+
req.Header.Set("X-Custom-Header", "custom-value") // Custom X- header
634+
req.Header.Set("X-Trace-Id", "trace-123") // Tracing header
635+
req.Header.Set("Authorization", "Bearer test-token") // Auth header
636+
637+
// Make the request
638+
resp, err := xEnv.RouterClient.Do(req)
639+
require.NoError(t, err)
640+
defer resp.Body.Close()
641+
642+
// With stateless mode, the request should succeed
643+
t.Logf("Response Status: %d", resp.StatusCode)
644+
require.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in stateless mode")
645+
646+
// Verify headers reached subgraph
647+
subgraphMutex.Lock()
648+
defer subgraphMutex.Unlock()
649+
650+
require.NotNil(t, capturedSubgraphRequest, "Subgraph should have received a request")
651+
652+
// Log all headers that the subgraph received
653+
t.Logf("Headers received by subgraph:")
654+
for key, values := range capturedSubgraphRequest.Header {
655+
for _, value := range values {
656+
t.Logf(" %s: %s", key, value)
657+
}
658+
}
659+
660+
// Verify that all headers were forwarded through the entire chain:
661+
// MCP Client -> MCP Server -> Router -> Subgraph
662+
assert.Equal(t, "bar", capturedSubgraphRequest.Header.Get("Foo"),
663+
"'foo' header should be forwarded to subgraph")
664+
assert.Equal(t, "custom-value", capturedSubgraphRequest.Header.Get("X-Custom-Header"),
665+
"X-Custom-Header should be forwarded to subgraph")
666+
assert.Equal(t, "trace-123", capturedSubgraphRequest.Header.Get("X-Trace-Id"),
667+
"X-Trace-Id should be forwarded to subgraph")
668+
assert.Equal(t, "Bearer test-token", capturedSubgraphRequest.Header.Get("Authorization"),
669+
"Authorization header should be forwarded to subgraph")
670+
671+
// This test proves that ALL headers sent by MCP clients are forwarded
672+
// through the complete chain. The router's header rules determine what
673+
// ultimately reaches the subgraphs.
674+
})
675+
})
676+
})
556677
}

router/pkg/mcpserver/server.go

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,28 @@ import (
2222
"go.uber.org/zap"
2323
)
2424

25-
// authKey is a custom context key for storing the auth token.
26-
type authKey struct{}
25+
// requestHeadersKey is a custom context key for storing request headers.
26+
type requestHeadersKey struct{}
2727

28-
// withAuthKey adds an auth key to the context.
29-
func withAuthKey(ctx context.Context, auth string) context.Context {
30-
return context.WithValue(ctx, authKey{}, auth)
28+
// withRequestHeaders adds request headers to the context.
29+
func withRequestHeaders(ctx context.Context, headers http.Header) context.Context {
30+
return context.WithValue(ctx, requestHeadersKey{}, headers)
3131
}
3232

33-
// authFromRequest extracts the auth token from the request headers.
34-
func authFromRequest(ctx context.Context, r *http.Request) context.Context {
35-
return withAuthKey(ctx, r.Header.Get("Authorization"))
33+
// requestHeadersFromRequest extracts all headers from the request and stores them in context.
34+
func requestHeadersFromRequest(ctx context.Context, r *http.Request) context.Context {
35+
// Clone the headers to avoid any mutation issues
36+
headers := r.Header.Clone()
37+
return withRequestHeaders(ctx, headers)
3638
}
3739

38-
// tokenFromContext extracts the auth token from the context.
39-
// This can be used by clients to pass the auth token to the server.
40-
func tokenFromContext(ctx context.Context) (string, error) {
41-
auth, ok := ctx.Value(authKey{}).(string)
40+
// headersFromContext extracts the request headers from the context.
41+
func headersFromContext(ctx context.Context) (http.Header, error) {
42+
headers, ok := ctx.Value(requestHeadersKey{}).(http.Header)
4243
if !ok {
43-
return "", fmt.Errorf("missing auth")
44+
return nil, fmt.Errorf("missing request headers")
4445
}
45-
return auth, nil
46+
return headers, nil
4647
}
4748

4849
// Options represents configuration options for the GraphQLSchemaServer
@@ -223,6 +224,11 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options)
223224
return gs, nil
224225
}
225226

227+
// SetHTTPClient allows setting a custom HTTP client (useful for testing)
228+
func (s *GraphQLSchemaServer) SetHTTPClient(client *http.Client) {
229+
s.httpClient = client
230+
}
231+
226232
// WithGraphName sets the graph name
227233
func WithGraphName(graphName string) func(*Options) {
228234
return func(o *Options) {
@@ -299,7 +305,7 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) {
299305
server.WithStreamableHTTPServer(httpServer),
300306
server.WithLogger(NewZapAdapter(s.logger.With(zap.String("component", "mcp-server")))),
301307
server.WithStateLess(s.stateless),
302-
server.WithHTTPContextFunc(authFromRequest),
308+
server.WithHTTPContextFunc(requestHeadersFromRequest),
303309
server.WithHeartbeatInterval(10*time.Second),
304310
)
305311

@@ -672,17 +678,23 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str
672678
return nil, fmt.Errorf("failed to create request: %w", err)
673679
}
674680

675-
req.Header.Set("Accept", "application/json")
676-
req.Header.Set("Content-Type", "application/json; charset=utf-8")
677-
678-
token, err := tokenFromContext(ctx)
681+
// Forward all headers from the original MCP request to the GraphQL server
682+
// The router's header forwarding rules will then determine what gets sent to subgraphs
683+
headers, err := headersFromContext(ctx)
679684
if err != nil {
680-
s.logger.Debug("failed to get token from context", zap.Error(err))
681-
} else if token != "" {
682-
req.Header.Set("Authorization", token)
685+
s.logger.Debug("failed to get headers from context", zap.Error(err))
686+
} else {
687+
// Copy all headers from the MCP request
688+
for key, values := range headers {
689+
for _, value := range values {
690+
req.Header.Add(key, value)
691+
}
692+
}
683693
}
684694

685-
// Forward Authorization header if provided
695+
// Override specific headers that must be set for GraphQL requests
696+
req.Header.Set("Accept", "application/json")
697+
req.Header.Set("Content-Type", "application/json; charset=utf-8")
686698

687699
resp, err := s.httpClient.Do(req)
688700
if err != nil {

0 commit comments

Comments
 (0)