Skip to content

Commit 78345e9

Browse files
yroblataskbotCopilot
authored
feat: add session management for proxy (#1081)
Co-authored-by: taskbot <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 423e203 commit 78345e9

File tree

5 files changed

+578
-6
lines changed

5 files changed

+578
-6
lines changed

pkg/transport/proxy/transparent/transparent_proxy.go

Lines changed: 169 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,26 @@
33
package transparent
44

55
import (
6+
"bufio"
7+
"bytes"
68
"context"
9+
"encoding/json"
710
"fmt"
11+
"io"
12+
"mime"
813
"net/http"
914
"net/http/httputil"
1015
"net/url"
16+
"regexp"
17+
"strings"
1118
"sync"
1219
"time"
1320

1421
"golang.org/x/exp/jsonrpc2"
1522

1623
"github.com/stacklok/toolhive/pkg/healthcheck"
1724
"github.com/stacklok/toolhive/pkg/logger"
25+
"github.com/stacklok/toolhive/pkg/transport/session"
1826
"github.com/stacklok/toolhive/pkg/transport/types"
1927
)
2028

@@ -47,6 +55,12 @@ type TransparentProxy struct {
4755

4856
// Optional Prometheus metrics handler
4957
prometheusHandler http.Handler
58+
59+
// Sessions for tracking state
60+
sessionManager *session.Manager
61+
62+
// If mcp server has been initialized
63+
IsServerInitialized bool
5064
}
5165

5266
// NewTransparentProxy creates a new transparent proxy with optional middlewares.
@@ -66,6 +80,7 @@ func NewTransparentProxy(
6680
middlewares: middlewares,
6781
shutdownCh: make(chan struct{}),
6882
prometheusHandler: prometheusHandler,
83+
sessionManager: session.NewManager(30*time.Minute, session.NewProxySession),
6984
}
7085

7186
// Create MCP pinger and health checker
@@ -75,6 +90,144 @@ func NewTransparentProxy(
7590
return proxy
7691
}
7792

93+
type tracingTransport struct {
94+
base http.RoundTripper
95+
p *TransparentProxy
96+
}
97+
98+
func (p *TransparentProxy) setServerInitialized() {
99+
if !p.IsServerInitialized {
100+
p.mutex.Lock()
101+
p.IsServerInitialized = true
102+
p.mutex.Unlock()
103+
logger.Infof("Server was initialized successfully for %s", p.containerName)
104+
}
105+
}
106+
107+
func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) {
108+
tr := t.base
109+
if tr == nil {
110+
tr = http.DefaultTransport
111+
}
112+
return tr.RoundTrip(req)
113+
}
114+
115+
func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
116+
reqBody := readRequestBody(req)
117+
118+
path := req.URL.Path
119+
isMCP := strings.HasPrefix(path, "/mcp")
120+
isJSON := strings.Contains(req.Header.Get("Content-Type"), "application/json")
121+
sawInitialize := false
122+
123+
if isMCP && isJSON && len(reqBody) > 0 {
124+
sawInitialize = t.detectInitialize(reqBody)
125+
}
126+
127+
resp, err := t.forward(req)
128+
if err != nil {
129+
logger.Errorf("Failed to forward request: %v", err)
130+
return nil, err
131+
}
132+
133+
if resp.StatusCode == http.StatusOK {
134+
// check if we saw a valid mcp header
135+
ct := resp.Header.Get("Mcp-Session-Id")
136+
if ct != "" {
137+
logger.Infof("Detected Mcp-Session-Id header: %s", ct)
138+
if _, ok := t.p.sessionManager.Get(ct); !ok {
139+
if err := t.p.sessionManager.AddWithID(ct); err != nil {
140+
logger.Errorf("Failed to create session from header %s: %v", ct, err)
141+
}
142+
}
143+
t.p.setServerInitialized()
144+
return resp, nil
145+
}
146+
// status was ok and we saw an initialize call
147+
if sawInitialize && !t.p.IsServerInitialized {
148+
t.p.setServerInitialized()
149+
return resp, nil
150+
}
151+
}
152+
153+
return resp, nil
154+
}
155+
156+
func readRequestBody(req *http.Request) []byte {
157+
reqBody := []byte{}
158+
if req.Body != nil {
159+
buf, err := io.ReadAll(req.Body)
160+
if err != nil {
161+
logger.Errorf("Failed to read request body: %v", err)
162+
} else {
163+
reqBody = buf
164+
}
165+
req.Body = io.NopCloser(bytes.NewReader(reqBody))
166+
}
167+
return reqBody
168+
}
169+
170+
func (t *tracingTransport) detectInitialize(body []byte) bool {
171+
var rpc struct {
172+
Method string `json:"method"`
173+
}
174+
if err := json.Unmarshal(body, &rpc); err != nil {
175+
logger.Errorf("Failed to parse JSON-RPC body: %v", err)
176+
return false
177+
}
178+
if rpc.Method == "initialize" {
179+
logger.Infof("Detected initialize method call for %s", t.p.containerName)
180+
return true
181+
}
182+
return false
183+
}
184+
185+
var sessionRe = regexp.MustCompile(`sessionId=([0-9A-Fa-f-]+)|"sessionId"\s*:\s*"([^"]+)"`)
186+
187+
func (p *TransparentProxy) modifyForSessionID(resp *http.Response) error {
188+
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
189+
if mediaType != "text/event-stream" {
190+
return nil
191+
}
192+
193+
pr, pw := io.Pipe()
194+
originalBody := resp.Body
195+
resp.Body = pr
196+
197+
go func() {
198+
defer pw.Close()
199+
scanner := bufio.NewScanner(originalBody)
200+
found := false
201+
202+
for scanner.Scan() {
203+
line := scanner.Bytes()
204+
if !found {
205+
if m := sessionRe.FindSubmatch(line); m != nil {
206+
sid := string(m[1])
207+
if sid == "" {
208+
sid = string(m[2])
209+
}
210+
p.setServerInitialized()
211+
err := p.sessionManager.AddWithID(sid)
212+
if err != nil {
213+
logger.Errorf("Failed to create session from SSE line: %v", err)
214+
}
215+
found = true
216+
}
217+
}
218+
if _, err := pw.Write(append(line, '\n')); err != nil {
219+
return
220+
}
221+
}
222+
_, err := io.Copy(pw, originalBody)
223+
if err != nil && err != io.EOF {
224+
logger.Errorf("Failed to copy response body: %v", err)
225+
}
226+
}()
227+
228+
return nil
229+
}
230+
78231
// Start starts the transparent proxy.
79232
func (p *TransparentProxy) Start(ctx context.Context) error {
80233
p.mutex.Lock()
@@ -88,6 +241,11 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
88241

89242
// Create a reverse proxy
90243
proxy := httputil.NewSingleHostReverseProxy(targetURL)
244+
proxy.FlushInterval = -1
245+
proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p}
246+
proxy.ModifyResponse = func(resp *http.Response) error {
247+
return p.modifyForSessionID(resp)
248+
}
91249

92250
// Create a handler that logs requests
93251
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -160,13 +318,18 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
160318
logger.Infof("Shutdown initiated, stopping health monitor for %s", p.containerName)
161319
return
162320
case <-ticker.C:
163-
alive := p.healthChecker.CheckHealth(parentCtx)
164-
if alive.Status != healthcheck.StatusHealthy {
165-
logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName)
166-
if err := p.Stop(parentCtx); err != nil {
167-
logger.Errorf("Failed to stop proxy for %s: %v", p.containerName, err)
321+
// Perform health check only if mcp server has been initialized
322+
if p.IsServerInitialized {
323+
alive := p.healthChecker.CheckHealth(parentCtx)
324+
if alive.Status != healthcheck.StatusHealthy {
325+
logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName)
326+
if err := p.Stop(parentCtx); err != nil {
327+
logger.Errorf("Failed to stop proxy for %s: %v", p.containerName, err)
328+
}
329+
return
168330
}
169-
return
331+
} else {
332+
logger.Infof("MCP server not initialized yet, skipping health check for %s", p.containerName)
170333
}
171334
}
172335
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package transparent
2+
3+
import (
4+
"bufio"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/http/httputil"
8+
"net/url"
9+
"testing"
10+
"time"
11+
12+
"github.com/stretchr/testify/assert"
13+
14+
"github.com/stacklok/toolhive/pkg/logger"
15+
)
16+
17+
func init() {
18+
logger.Initialize() // ensure logging doesn't panic
19+
}
20+
21+
func TestStreamingSessionIDDetection(t *testing.T) {
22+
t.Parallel()
23+
proxy := NewTransparentProxy("127.0.0.1", 0, "test", "http://example.com", nil)
24+
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
25+
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
26+
w.WriteHeader(200)
27+
28+
// Simulate SSE lines
29+
w.Write([]byte("data: hello\n"))
30+
w.Write([]byte("data: sessionId=ABC123\n"))
31+
w.(http.Flusher).Flush()
32+
33+
time.Sleep(10 * time.Millisecond)
34+
w.Write([]byte("data: more\n"))
35+
}))
36+
defer target.Close()
37+
38+
// set up reverse proxy using ModifyResponse
39+
parsedURL, _ := http.NewRequest("GET", target.URL, nil)
40+
proxyURL := httputil.NewSingleHostReverseProxy(parsedURL.URL)
41+
proxyURL.FlushInterval = -1
42+
proxyURL.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy}
43+
proxyURL.ModifyResponse = proxy.modifyForSessionID
44+
45+
// hit the proxy
46+
rec := httptest.NewRecorder()
47+
req := httptest.NewRequest("GET", target.URL, nil)
48+
proxyURL.ServeHTTP(rec, req)
49+
50+
// read all SSE lines
51+
sc := bufio.NewScanner(rec.Body)
52+
var bodyLines []string
53+
for sc.Scan() {
54+
bodyLines = append(bodyLines, sc.Text())
55+
}
56+
assert.Contains(t, bodyLines, "data: sessionId=ABC123")
57+
58+
// side-effect: proxy should have seen session
59+
assert.True(t, proxy.IsServerInitialized, "server should have been initialized")
60+
_, ok := proxy.sessionManager.Get("ABC123")
61+
assert.True(t, ok, "sessionManager should have stored ABC123")
62+
}
63+
64+
func createBasicProxy(p *TransparentProxy, targetURL *url.URL) *httputil.ReverseProxy {
65+
proxy := httputil.NewSingleHostReverseProxy(targetURL)
66+
proxy.Director = func(r *http.Request) {
67+
r.URL.Scheme = targetURL.Scheme
68+
r.URL.Host = targetURL.Host
69+
r.Host = targetURL.Host
70+
}
71+
proxy.FlushInterval = -1
72+
proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p}
73+
proxy.ModifyResponse = p.modifyForSessionID
74+
return proxy
75+
}
76+
77+
func TestNoSessionIDInNonSSE(t *testing.T) {
78+
t.Parallel()
79+
80+
p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil)
81+
82+
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
83+
// Set both content-type and also optionally MCP header to test behavior
84+
w.Header().Set("Content-Type", "application/json")
85+
w.WriteHeader(200)
86+
w.Write([]byte(`{"hello": "world"}`))
87+
}))
88+
defer target.Close()
89+
90+
targetURL, _ := url.Parse(target.URL)
91+
proxy := createBasicProxy(p, targetURL)
92+
93+
rec := httptest.NewRecorder()
94+
req := httptest.NewRequest("GET", target.URL, nil)
95+
proxy.ServeHTTP(rec, req)
96+
97+
assert.False(t, p.IsServerInitialized, "server should not be initialized for application/json")
98+
_, ok := p.sessionManager.Get("XYZ789")
99+
assert.False(t, ok, "no session should be added")
100+
}
101+
102+
func TestHeaderBasedSessionInitialization(t *testing.T) {
103+
t.Parallel()
104+
105+
p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil)
106+
107+
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
108+
// Set both content-type and also optionally MCP header to test behavior
109+
w.Header().Set("Content-Type", "application/json")
110+
w.Header().Set("Mcp-Session-Id", "XYZ789")
111+
w.WriteHeader(200)
112+
w.Write([]byte(`{"hello": "world"}`))
113+
}))
114+
defer target.Close()
115+
116+
targetURL, _ := url.Parse(target.URL)
117+
proxy := createBasicProxy(p, targetURL)
118+
119+
rec := httptest.NewRecorder()
120+
req := httptest.NewRequest("GET", target.URL, nil)
121+
proxy.ServeHTTP(rec, req)
122+
123+
assert.True(t, p.IsServerInitialized, "server should not be initialized for application/json")
124+
_, ok := p.sessionManager.Get("XYZ789")
125+
assert.True(t, ok, "no session should be added")
126+
}

0 commit comments

Comments
 (0)