Skip to content

Commit 957c3a6

Browse files
authored
fix: resolve dual React instance in JSX renderer and propagate provenance session IDs (#194)
* fix: resolve dual React instance in JSX renderer and propagate provenance session IDs The JSX sandbox iframe used ?bundle on all esm.sh imports, causing each package (react, react-dom, react/jsx-runtime) to bundle its own copy of React internals. When createRoot mounted on one instance but components used hooks from another, useState saw a null dispatcher — blank page. Fix: remove ?bundle from react/react-dom (let esm.sh deduplicate via shared module URLs), use ?bundle&external=react,react-dom on leaf packages (recharts, lucide-react). Add unhandledrejection handler so module load errors surface instead of silent blank pages. Also adds WithReplacedSessionID/ReplacedSessionID to the session handler for provenance tracking, and harvestProvenance helper to the provenance middleware for merging both original and replacement session IDs. * fix: escape </script> in JSX renderer to prevent HTML breakout Transformed JSX injected into the iframe's <script> block could contain literal </script> strings, breaking the HTML structure. Add escapeScriptClose() to replace </script with <\/script (valid JS, does not terminate the HTML tag). Also switch the error path to use textContent via JSON.stringify instead of raw interpolation.
1 parent f8f6254 commit 957c3a6

File tree

11 files changed

+1102
-83
lines changed

11 files changed

+1102
-83
lines changed

pkg/middleware/mcp_provenance.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"time"
99

1010
"github.com/modelcontextprotocol/go-sdk/mcp"
11+
12+
pkgsession "github.com/txn2/mcp-data-platform/pkg/session"
1113
)
1214

1315
// provenanceContextKey is the context key for provenance tool calls.
@@ -122,6 +124,17 @@ func summarizeParams(params map[string]any) string {
122124
return s
123125
}
124126

127+
// harvestProvenance collects tool calls from the current session and, if a
128+
// session replacement occurred, also from the old (replaced) session.
129+
func harvestProvenance(ctx context.Context, tracker *ProvenanceTracker, sessionID string) []ProvenanceToolCall {
130+
calls := tracker.Harvest(sessionID)
131+
if replacedID := pkgsession.ReplacedSessionID(ctx); replacedID != "" {
132+
oldCalls := tracker.Harvest(replacedID)
133+
calls = append(oldCalls, calls...)
134+
}
135+
return calls
136+
}
137+
125138
// MCPProvenanceMiddleware tracks tool calls per session and injects
126139
// accumulated provenance into the context when save_artifact is called.
127140
func MCPProvenanceMiddleware(tracker *ProvenanceTracker, saveToolName string) mcp.Middleware {
@@ -145,7 +158,7 @@ func MCPProvenanceMiddleware(tracker *ProvenanceTracker, saveToolName string) mc
145158
}
146159

147160
if toolName == saveToolName {
148-
calls := tracker.Harvest(sessionID)
161+
calls := harvestProvenance(ctx, tracker, sessionID)
149162
ctx = WithProvenanceToolCalls(ctx, calls)
150163
return next(ctx, method, req)
151164
}

pkg/middleware/mcp_provenance_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"github.com/modelcontextprotocol/go-sdk/mcp"
1010
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"
12+
13+
pkgsession "github.com/txn2/mcp-data-platform/pkg/session"
1214
)
1315

1416
// newTestServerRequest creates a test server request for tools/call.
@@ -238,6 +240,73 @@ func TestProvenanceTracker_CleanupBefore(t *testing.T) {
238240
assert.Len(t, calls, 1)
239241
}
240242

243+
func TestMCPProvenanceMiddleware_SessionReplacement(t *testing.T) {
244+
tracker := NewProvenanceTracker()
245+
246+
// Record tool calls under the "old" session (the one that will expire).
247+
tracker.Record("old-session", "trino_query", map[string]any{"sql": "SELECT 1"})
248+
tracker.Record("old-session", "datahub_search", map[string]any{"query": "sales"})
249+
250+
var capturedCalls []ProvenanceToolCall
251+
base := func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
252+
capturedCalls = GetProvenanceToolCalls(ctx)
253+
return &mcp.CallToolResult{}, nil
254+
}
255+
256+
handler := MCPProvenanceMiddleware(tracker, "save_artifact")(base)
257+
258+
req := newTestServerRequest(&mcp.CallToolParamsRaw{
259+
Name: "save_artifact",
260+
})
261+
262+
// The new session has no tool calls, but the replaced session ID carries
263+
// the old session's provenance.
264+
ctx := WithPlatformContext(context.Background(), &PlatformContext{SessionID: "new-session"})
265+
ctx = pkgsession.WithReplacedSessionID(ctx, "old-session")
266+
267+
_, err := handler(ctx, methodToolsCall, req)
268+
require.NoError(t, err)
269+
270+
// Both old-session tool calls should be recovered.
271+
require.Len(t, capturedCalls, 2, "should recover provenance from replaced session")
272+
assert.Equal(t, "trino_query", capturedCalls[0].ToolName)
273+
assert.Equal(t, "datahub_search", capturedCalls[1].ToolName)
274+
275+
// Both sessions should be cleared after harvest.
276+
assert.Nil(t, tracker.Harvest("old-session"))
277+
assert.Nil(t, tracker.Harvest("new-session"))
278+
}
279+
280+
func TestMCPProvenanceMiddleware_SessionReplacementMergesBoth(t *testing.T) {
281+
tracker := NewProvenanceTracker()
282+
283+
// Old session has 2 calls, new session has 1 call.
284+
tracker.Record("old-sess", "tool_a", nil)
285+
tracker.Record("old-sess", "tool_b", nil)
286+
tracker.Record("new-sess", "tool_c", nil)
287+
288+
var capturedCalls []ProvenanceToolCall
289+
base := func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
290+
capturedCalls = GetProvenanceToolCalls(ctx)
291+
return &mcp.CallToolResult{}, nil
292+
}
293+
294+
handler := MCPProvenanceMiddleware(tracker, "save_artifact")(base)
295+
296+
req := newTestServerRequest(&mcp.CallToolParamsRaw{Name: "save_artifact"})
297+
ctx := WithPlatformContext(context.Background(), &PlatformContext{SessionID: "new-sess"})
298+
ctx = pkgsession.WithReplacedSessionID(ctx, "old-sess")
299+
300+
_, err := handler(ctx, methodToolsCall, req)
301+
require.NoError(t, err)
302+
303+
// Old calls come first, then new session calls.
304+
require.Len(t, capturedCalls, 3)
305+
assert.Equal(t, "tool_a", capturedCalls[0].ToolName)
306+
assert.Equal(t, "tool_b", capturedCalls[1].ToolName)
307+
assert.Equal(t, "tool_c", capturedCalls[2].ToolName)
308+
}
309+
241310
func TestExtractToolParams_NilCases(t *testing.T) {
242311
// Request with nil arguments.
243312
req := newTestServerRequest(&mcp.CallToolParamsRaw{Name: "test"})

pkg/session/handler.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ import (
1515
// awareSessionKey is the context key for the AwareHandler session ID.
1616
type awareSessionKey struct{}
1717

18+
// replacedSessionKey is the context key for the old session ID that was replaced
19+
// during session recovery (expired session → new session with auth credentials).
20+
type replacedSessionKey struct{}
21+
1822
// AwareSessionID returns the session ID set by AwareHandler, or "".
1923
func AwareSessionID(ctx context.Context) string {
2024
if id, ok := ctx.Value(awareSessionKey{}).(string); ok {
@@ -30,6 +34,22 @@ func WithAwareSessionID(ctx context.Context, sessionID string) context.Context {
3034
return context.WithValue(ctx, awareSessionKey{}, sessionID)
3135
}
3236

37+
// ReplacedSessionID returns the old session ID that was replaced during session
38+
// recovery, or "" if no replacement occurred.
39+
func ReplacedSessionID(ctx context.Context) string {
40+
if id, ok := ctx.Value(replacedSessionKey{}).(string); ok {
41+
return id
42+
}
43+
return ""
44+
}
45+
46+
// WithReplacedSessionID returns a context carrying the old session ID that was
47+
// replaced. This allows downstream middleware (e.g. provenance) to recover data
48+
// recorded under the old session.
49+
func WithReplacedSessionID(ctx context.Context, oldSessionID string) context.Context {
50+
return context.WithValue(ctx, replacedSessionKey{}, oldSessionID)
51+
}
52+
3353
const (
3454
// sessionIDHeader is the MCP session header name.
3555
sessionIDHeader = "Mcp-Session-Id"
@@ -142,6 +162,7 @@ func (h *AwareHandler) handleExisting(w http.ResponseWriter, r *http.Request, se
142162
if extractToken(r) != "" {
143163
slog.Info("session: expired, creating replacement",
144164
"old_session_id", sanitizeLogValue(sessionID)) // #nosec G706 -- sessionID sanitized via sanitizeLogValue
165+
r = r.WithContext(WithReplacedSessionID(r.Context(), sessionID))
145166
h.handleInitialize(w, r)
146167
return
147168
}

pkg/session/handler_test.go

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,17 +442,19 @@ func TestAwareSessionID_Roundtrip(t *testing.T) {
442442
assert.Equal(t, "test-session-123", got)
443443
}
444444

445-
// contextCapturingHandler captures the AwareSessionID from the request context.
445+
// contextCapturingHandler captures the AwareSessionID and ReplacedSessionID from the request context.
446446
type contextCapturingHandler struct {
447-
mu sync.Mutex
448-
awareSessionID string
449-
capturedCalled bool
447+
mu sync.Mutex
448+
awareSessionID string
449+
replacedSessionID string
450+
capturedCalled bool
450451
}
451452

452453
func (h *contextCapturingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
453454
h.mu.Lock()
454455
h.capturedCalled = true
455456
h.awareSessionID = AwareSessionID(r.Context())
457+
h.replacedSessionID = ReplacedSessionID(r.Context())
456458
h.mu.Unlock()
457459
w.WriteHeader(http.StatusOK)
458460
}
@@ -531,3 +533,69 @@ func TestHandler_ConcurrentAccess(t *testing.T) {
531533
}
532534
wg.Wait()
533535
}
536+
537+
func TestReplacedSessionID_EmptyContext(t *testing.T) {
538+
got := ReplacedSessionID(context.Background())
539+
assert.Empty(t, got, "plain context should return empty string")
540+
}
541+
542+
func TestReplacedSessionID_Roundtrip(t *testing.T) {
543+
ctx := WithReplacedSessionID(context.Background(), "old-session-abc")
544+
got := ReplacedSessionID(ctx)
545+
assert.Equal(t, "old-session-abc", got)
546+
}
547+
548+
func TestHandler_SessionReplacement_SetsReplacedSessionID(t *testing.T) {
549+
store := NewMemoryStore(handlerTestTTL)
550+
capture := &contextCapturingHandler{}
551+
handler := NewAwareHandler(capture, HandlerConfig{
552+
Store: store,
553+
TTL: handlerTestTTL,
554+
})
555+
556+
// Create an expired session
557+
sess := newTestSession("old-session-for-replace", -time.Second)
558+
sess.UserID = hashToken("replace-token")
559+
require.NoError(t, store.Create(context.Background(), sess))
560+
561+
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, handlerTestPath, http.NoBody)
562+
req.Header.Set(sessionIDHeader, "old-session-for-replace")
563+
req.Header.Set(handlerTestAuthHeader, "Bearer replace-token")
564+
w := httptest.NewRecorder()
565+
566+
handler.ServeHTTP(w, req)
567+
568+
capture.mu.Lock()
569+
defer capture.mu.Unlock()
570+
assert.True(t, capture.capturedCalled, "inner handler should be called")
571+
assert.NotEmpty(t, capture.awareSessionID, "new session ID should be set")
572+
assert.NotEqual(t, "old-session-for-replace", capture.awareSessionID)
573+
assert.Equal(t, "old-session-for-replace", capture.replacedSessionID,
574+
"replaced session ID should carry the old session ID")
575+
}
576+
577+
func TestHandler_NormalSession_NoReplacedSessionID(t *testing.T) {
578+
store := NewMemoryStore(handlerTestTTL)
579+
capture := &contextCapturingHandler{}
580+
handler := NewAwareHandler(capture, HandlerConfig{
581+
Store: store,
582+
TTL: handlerTestTTL,
583+
})
584+
585+
// Create a valid (non-expired) session
586+
sess := newTestSession("normal-session", handlerTestTTL)
587+
sess.UserID = ""
588+
require.NoError(t, store.Create(context.Background(), sess))
589+
590+
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, handlerTestPath, http.NoBody)
591+
req.Header.Set(sessionIDHeader, "normal-session")
592+
w := httptest.NewRecorder()
593+
594+
handler.ServeHTTP(w, req)
595+
596+
capture.mu.Lock()
597+
defer capture.mu.Unlock()
598+
assert.True(t, capture.capturedCalled)
599+
assert.Empty(t, capture.replacedSessionID,
600+
"normal sessions should not have a replaced session ID")
601+
}

0 commit comments

Comments
 (0)