@@ -5,12 +5,14 @@ import (
55 "fmt"
66 "net/http"
77 "strings"
8+ "sync"
89 "testing"
910
1011 "github.com/mark3labs/mcp-go/mcp"
1112 "github.com/stretchr/testify/assert"
1213 "github.com/stretchr/testify/require"
1314 "github.com/wundergraph/cosmo/router-tests/testenv"
15+ "github.com/wundergraph/cosmo/router/core"
1416 "github.com/wundergraph/cosmo/router/pkg/config"
1517)
1618
@@ -212,7 +214,6 @@ func TestMCP(t *testing.T) {
212214 t .Run ("Execute Query" , func (t * testing.T ) {
213215 t .Run ("Execute operation of type query with valid input" , func (t * testing.T ) {
214216 testenv .Run (t , & testenv.Config {
215- EnableNats : true ,
216217 MCP : config.MCPConfiguration {
217218 Enabled : true ,
218219 },
@@ -553,4 +554,124 @@ func TestMCP(t *testing.T) {
553554 })
554555 })
555556 })
557+
558+ t .Run ("Header Forwarding" , func (t * testing.T ) {
559+ t .Run ("All request headers are forwarded from MCP client through to subgraphs" , func (t * testing.T ) {
560+ // This test validates that ALL headers sent by MCP clients are forwarded
561+ // through the complete chain: MCP Client -> MCP Server -> Router -> Subgraphs
562+ //
563+ // The router's header forwarding rules (configured with wildcard `.*`) determine
564+ // what gets propagated to subgraphs. The MCP server acts as a transparent proxy,
565+ // forwarding all headers without filtering.
566+ //
567+ // Note: We use direct HTTP POST requests instead of the mcp-go client library
568+ // because transport.WithHTTPHeaders() in mcp-go sets headers at the SSE connection
569+ // level, not on individual tool execution requests. Direct HTTP requests allow us
570+ // to test per-request headers, which is what real MCP clients (like Claude Desktop) send.
571+
572+ var capturedSubgraphRequest * http.Request
573+ var subgraphMutex sync.Mutex
574+
575+ testenv .Run (t , & testenv.Config {
576+ MCP : config.MCPConfiguration {
577+ Enabled : true ,
578+ Session : config.MCPSessionConfig {
579+ Stateless : true , // Enable stateless mode so we don't need session IDs
580+ },
581+ },
582+ RouterOptions : []core.Option {
583+ // Forward all headers including custom ones
584+ core .WithHeaderRules (config.HeaderRules {
585+ All : & config.GlobalHeaderRule {
586+ Request : []* config.RequestHeaderRule {
587+ {
588+ Operation : config .HeaderRuleOperationPropagate ,
589+ Matching : ".*" , // Forward all headers
590+ },
591+ },
592+ },
593+ }),
594+ },
595+ Subgraphs : testenv.SubgraphsConfig {
596+ GlobalMiddleware : func (handler http.Handler ) http.Handler {
597+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
598+ subgraphMutex .Lock ()
599+ capturedSubgraphRequest = r .Clone (r .Context ())
600+ subgraphMutex .Unlock ()
601+ handler .ServeHTTP (w , r )
602+ })
603+ },
604+ },
605+ }, func (t * testing.T , xEnv * testenv.Environment ) {
606+ // With stateless mode enabled, we can make direct HTTP POST requests
607+ // without needing to establish a session first
608+ mcpAddr := xEnv .GetMCPServerAddr ()
609+
610+ // Make a direct HTTP POST request with custom headers
611+ // This simulates a real MCP client sending custom headers on tool calls
612+ mcpRequest := map [string ]interface {}{
613+ "jsonrpc" : "2.0" ,
614+ "id" : 1 ,
615+ "method" : "tools/call" ,
616+ "params" : map [string ]interface {}{
617+ "name" : "execute_operation_my_employees" ,
618+ "arguments" : map [string ]interface {}{
619+ "criteria" : map [string ]interface {}{},
620+ },
621+ },
622+ }
623+
624+ requestBody , err := json .Marshal (mcpRequest )
625+ require .NoError (t , err )
626+
627+ req , err := http .NewRequest ("POST" , mcpAddr , strings .NewReader (string (requestBody )))
628+ require .NoError (t , err )
629+
630+ // Add various headers to test forwarding
631+ req .Header .Set ("Content-Type" , "application/json" )
632+ req .Header .Set ("foo" , "bar" ) // Non-standard header
633+ req .Header .Set ("X-Custom-Header" , "custom-value" ) // Custom X- header
634+ req .Header .Set ("X-Trace-Id" , "trace-123" ) // Tracing header
635+ req .Header .Set ("Authorization" , "Bearer test-token" ) // Auth header
636+
637+ // Make the request
638+ resp , err := xEnv .RouterClient .Do (req )
639+ require .NoError (t , err )
640+ defer resp .Body .Close ()
641+
642+ // With stateless mode, the request should succeed
643+ t .Logf ("Response Status: %d" , resp .StatusCode )
644+ require .Equal (t , http .StatusOK , resp .StatusCode , "Request should succeed in stateless mode" )
645+
646+ // Verify headers reached subgraph
647+ subgraphMutex .Lock ()
648+ defer subgraphMutex .Unlock ()
649+
650+ require .NotNil (t , capturedSubgraphRequest , "Subgraph should have received a request" )
651+
652+ // Log all headers that the subgraph received
653+ t .Logf ("Headers received by subgraph:" )
654+ for key , values := range capturedSubgraphRequest .Header {
655+ for _ , value := range values {
656+ t .Logf (" %s: %s" , key , value )
657+ }
658+ }
659+
660+ // Verify that all headers were forwarded through the entire chain:
661+ // MCP Client -> MCP Server -> Router -> Subgraph
662+ assert .Equal (t , "bar" , capturedSubgraphRequest .Header .Get ("Foo" ),
663+ "'foo' header should be forwarded to subgraph" )
664+ assert .Equal (t , "custom-value" , capturedSubgraphRequest .Header .Get ("X-Custom-Header" ),
665+ "X-Custom-Header should be forwarded to subgraph" )
666+ assert .Equal (t , "trace-123" , capturedSubgraphRequest .Header .Get ("X-Trace-Id" ),
667+ "X-Trace-Id should be forwarded to subgraph" )
668+ assert .Equal (t , "Bearer test-token" , capturedSubgraphRequest .Header .Get ("Authorization" ),
669+ "Authorization header should be forwarded to subgraph" )
670+
671+ // This test proves that ALL headers sent by MCP clients are forwarded
672+ // through the complete chain. The router's header rules determine what
673+ // ultimately reaches the subgraphs.
674+ })
675+ })
676+ })
556677}
0 commit comments