3
3
package transparent
4
4
5
5
import (
6
+ "bufio"
7
+ "bytes"
6
8
"context"
9
+ "encoding/json"
7
10
"fmt"
11
+ "io"
12
+ "mime"
8
13
"net/http"
9
14
"net/http/httputil"
10
15
"net/url"
16
+ "regexp"
17
+ "strings"
11
18
"sync"
12
19
"time"
13
20
14
21
"golang.org/x/exp/jsonrpc2"
15
22
16
23
"github.com/stacklok/toolhive/pkg/healthcheck"
17
24
"github.com/stacklok/toolhive/pkg/logger"
25
+ "github.com/stacklok/toolhive/pkg/transport/session"
18
26
"github.com/stacklok/toolhive/pkg/transport/types"
19
27
)
20
28
@@ -47,6 +55,12 @@ type TransparentProxy struct {
47
55
48
56
// Optional Prometheus metrics handler
49
57
prometheusHandler http.Handler
58
+
59
+ // Sessions for tracking state
60
+ sessionManager * session.Manager
61
+
62
+ // If mcp server has been initialized
63
+ IsServerInitialized bool
50
64
}
51
65
52
66
// NewTransparentProxy creates a new transparent proxy with optional middlewares.
@@ -66,6 +80,7 @@ func NewTransparentProxy(
66
80
middlewares : middlewares ,
67
81
shutdownCh : make (chan struct {}),
68
82
prometheusHandler : prometheusHandler ,
83
+ sessionManager : session .NewManager (30 * time .Minute , session .NewProxySession ),
69
84
}
70
85
71
86
// Create MCP pinger and health checker
@@ -75,6 +90,144 @@ func NewTransparentProxy(
75
90
return proxy
76
91
}
77
92
93
+ type tracingTransport struct {
94
+ base http.RoundTripper
95
+ p * TransparentProxy
96
+ }
97
+
98
+ func (p * TransparentProxy ) setServerInitialized () {
99
+ if ! p .IsServerInitialized {
100
+ p .mutex .Lock ()
101
+ p .IsServerInitialized = true
102
+ p .mutex .Unlock ()
103
+ logger .Infof ("Server was initialized successfully for %s" , p .containerName )
104
+ }
105
+ }
106
+
107
+ func (t * tracingTransport ) forward (req * http.Request ) (* http.Response , error ) {
108
+ tr := t .base
109
+ if tr == nil {
110
+ tr = http .DefaultTransport
111
+ }
112
+ return tr .RoundTrip (req )
113
+ }
114
+
115
+ func (t * tracingTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
116
+ reqBody := readRequestBody (req )
117
+
118
+ path := req .URL .Path
119
+ isMCP := strings .HasPrefix (path , "/mcp" )
120
+ isJSON := strings .Contains (req .Header .Get ("Content-Type" ), "application/json" )
121
+ sawInitialize := false
122
+
123
+ if isMCP && isJSON && len (reqBody ) > 0 {
124
+ sawInitialize = t .detectInitialize (reqBody )
125
+ }
126
+
127
+ resp , err := t .forward (req )
128
+ if err != nil {
129
+ logger .Errorf ("Failed to forward request: %v" , err )
130
+ return nil , err
131
+ }
132
+
133
+ if resp .StatusCode == http .StatusOK {
134
+ // check if we saw a valid mcp header
135
+ ct := resp .Header .Get ("Mcp-Session-Id" )
136
+ if ct != "" {
137
+ logger .Infof ("Detected Mcp-Session-Id header: %s" , ct )
138
+ if _ , ok := t .p .sessionManager .Get (ct ); ! ok {
139
+ if err := t .p .sessionManager .AddWithID (ct ); err != nil {
140
+ logger .Errorf ("Failed to create session from header %s: %v" , ct , err )
141
+ }
142
+ }
143
+ t .p .setServerInitialized ()
144
+ return resp , nil
145
+ }
146
+ // status was ok and we saw an initialize call
147
+ if sawInitialize && ! t .p .IsServerInitialized {
148
+ t .p .setServerInitialized ()
149
+ return resp , nil
150
+ }
151
+ }
152
+
153
+ return resp , nil
154
+ }
155
+
156
+ func readRequestBody (req * http.Request ) []byte {
157
+ reqBody := []byte {}
158
+ if req .Body != nil {
159
+ buf , err := io .ReadAll (req .Body )
160
+ if err != nil {
161
+ logger .Errorf ("Failed to read request body: %v" , err )
162
+ } else {
163
+ reqBody = buf
164
+ }
165
+ req .Body = io .NopCloser (bytes .NewReader (reqBody ))
166
+ }
167
+ return reqBody
168
+ }
169
+
170
+ func (t * tracingTransport ) detectInitialize (body []byte ) bool {
171
+ var rpc struct {
172
+ Method string `json:"method"`
173
+ }
174
+ if err := json .Unmarshal (body , & rpc ); err != nil {
175
+ logger .Errorf ("Failed to parse JSON-RPC body: %v" , err )
176
+ return false
177
+ }
178
+ if rpc .Method == "initialize" {
179
+ logger .Infof ("Detected initialize method call for %s" , t .p .containerName )
180
+ return true
181
+ }
182
+ return false
183
+ }
184
+
185
+ var sessionRe = regexp .MustCompile (`sessionId=([0-9A-Fa-f-]+)|"sessionId"\s*:\s*"([^"]+)"` )
186
+
187
+ func (p * TransparentProxy ) modifyForSessionID (resp * http.Response ) error {
188
+ mediaType , _ , _ := mime .ParseMediaType (resp .Header .Get ("Content-Type" ))
189
+ if mediaType != "text/event-stream" {
190
+ return nil
191
+ }
192
+
193
+ pr , pw := io .Pipe ()
194
+ originalBody := resp .Body
195
+ resp .Body = pr
196
+
197
+ go func () {
198
+ defer pw .Close ()
199
+ scanner := bufio .NewScanner (originalBody )
200
+ found := false
201
+
202
+ for scanner .Scan () {
203
+ line := scanner .Bytes ()
204
+ if ! found {
205
+ if m := sessionRe .FindSubmatch (line ); m != nil {
206
+ sid := string (m [1 ])
207
+ if sid == "" {
208
+ sid = string (m [2 ])
209
+ }
210
+ p .setServerInitialized ()
211
+ err := p .sessionManager .AddWithID (sid )
212
+ if err != nil {
213
+ logger .Errorf ("Failed to create session from SSE line: %v" , err )
214
+ }
215
+ found = true
216
+ }
217
+ }
218
+ if _ , err := pw .Write (append (line , '\n' )); err != nil {
219
+ return
220
+ }
221
+ }
222
+ _ , err := io .Copy (pw , originalBody )
223
+ if err != nil && err != io .EOF {
224
+ logger .Errorf ("Failed to copy response body: %v" , err )
225
+ }
226
+ }()
227
+
228
+ return nil
229
+ }
230
+
78
231
// Start starts the transparent proxy.
79
232
func (p * TransparentProxy ) Start (ctx context.Context ) error {
80
233
p .mutex .Lock ()
@@ -88,6 +241,11 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
88
241
89
242
// Create a reverse proxy
90
243
proxy := httputil .NewSingleHostReverseProxy (targetURL )
244
+ proxy .FlushInterval = - 1
245
+ proxy .Transport = & tracingTransport {base : http .DefaultTransport , p : p }
246
+ proxy .ModifyResponse = func (resp * http.Response ) error {
247
+ return p .modifyForSessionID (resp )
248
+ }
91
249
92
250
// Create a handler that logs requests
93
251
handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -160,13 +318,18 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
160
318
logger .Infof ("Shutdown initiated, stopping health monitor for %s" , p .containerName )
161
319
return
162
320
case <- ticker .C :
163
- alive := p .healthChecker .CheckHealth (parentCtx )
164
- if alive .Status != healthcheck .StatusHealthy {
165
- logger .Infof ("Health check failed for %s; initiating proxy shutdown" , p .containerName )
166
- if err := p .Stop (parentCtx ); err != nil {
167
- logger .Errorf ("Failed to stop proxy for %s: %v" , p .containerName , err )
321
+ // Perform health check only if mcp server has been initialized
322
+ if p .IsServerInitialized {
323
+ alive := p .healthChecker .CheckHealth (parentCtx )
324
+ if alive .Status != healthcheck .StatusHealthy {
325
+ logger .Infof ("Health check failed for %s; initiating proxy shutdown" , p .containerName )
326
+ if err := p .Stop (parentCtx ); err != nil {
327
+ logger .Errorf ("Failed to stop proxy for %s: %v" , p .containerName , err )
328
+ }
329
+ return
168
330
}
169
- return
331
+ } else {
332
+ logger .Infof ("MCP server not initialized yet, skipping health check for %s" , p .containerName )
170
333
}
171
334
}
172
335
}
0 commit comments