Skip to content

Commit 11483fe

Browse files
authored
fix: revive expired sessions with same ID to preserve provenance and dedup (#195)
* fix: revive expired sessions with same ID to preserve provenance and dedup

  When a client reconnects after session expiry, handleExisting was calling handleInitialize, which generated a new random session ID. But MCP clients (e.g. Claude Desktop) don't update their stored session ID from response headers, so every subsequent request triggered another replacement — breaking provenance tracking (0 tool calls) and enrichment dedup (never fires).

  Replace the handleInitialize call with reviveSession, which recreates the session using the client's original ID. Delete-then-create avoids unique constraint violations from expired-but-not-yet-cleaned rows.

* fix: remove dead ReplacedSessionID code and add upsert for revive race safety

  Remove the ReplacedSessionID/WithReplacedSessionID infrastructure that became dead code after switching to same-ID session revival. The replaced-session context value was never populated in the revive path, making the harvestProvenance fallback and all related tests unreachable.

  Add an ON CONFLICT upsert to the Postgres session store's Create method so concurrent revive attempts on the same expired session ID don't fail with unique constraint violations.
1 parent 957c3a6 commit 11483fe

File tree

5 files changed

+202
-161
lines changed

5 files changed

+202
-161
lines changed

pkg/middleware/mcp_provenance.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ import (
88
"time"
99

1010
"github.com/modelcontextprotocol/go-sdk/mcp"
11-
12-
pkgsession "github.com/txn2/mcp-data-platform/pkg/session"
1311
)
1412

1513
// provenanceContextKey is the context key for provenance tool calls.
@@ -124,15 +122,9 @@ func summarizeParams(params map[string]any) string {
124122
return s
125123
}
126124

127-
// harvestProvenance collects tool calls from the current session and, if a
128-
// session replacement occurred, also from the old (replaced) session.
129-
func harvestProvenance(ctx context.Context, tracker *ProvenanceTracker, sessionID string) []ProvenanceToolCall {
130-
calls := tracker.Harvest(sessionID)
131-
if replacedID := pkgsession.ReplacedSessionID(ctx); replacedID != "" {
132-
oldCalls := tracker.Harvest(replacedID)
133-
calls = append(oldCalls, calls...)
134-
}
135-
return calls
125+
// harvestProvenance collects tool calls from the current session.
126+
func harvestProvenance(_ context.Context, tracker *ProvenanceTracker, sessionID string) []ProvenanceToolCall {
127+
return tracker.Harvest(sessionID)
136128
}
137129

138130
// MCPProvenanceMiddleware tracks tool calls per session and injects

pkg/middleware/mcp_provenance_test.go

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import (
99
"github.com/modelcontextprotocol/go-sdk/mcp"
1010
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"
12-
13-
pkgsession "github.com/txn2/mcp-data-platform/pkg/session"
1412
)
1513

1614
// newTestServerRequest creates a test server request for tools/call.
@@ -240,73 +238,6 @@ func TestProvenanceTracker_CleanupBefore(t *testing.T) {
240238
assert.Len(t, calls, 1)
241239
}
242240

243-
func TestMCPProvenanceMiddleware_SessionReplacement(t *testing.T) {
244-
tracker := NewProvenanceTracker()
245-
246-
// Record tool calls under the "old" session (the one that will expire).
247-
tracker.Record("old-session", "trino_query", map[string]any{"sql": "SELECT 1"})
248-
tracker.Record("old-session", "datahub_search", map[string]any{"query": "sales"})
249-
250-
var capturedCalls []ProvenanceToolCall
251-
base := func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
252-
capturedCalls = GetProvenanceToolCalls(ctx)
253-
return &mcp.CallToolResult{}, nil
254-
}
255-
256-
handler := MCPProvenanceMiddleware(tracker, "save_artifact")(base)
257-
258-
req := newTestServerRequest(&mcp.CallToolParamsRaw{
259-
Name: "save_artifact",
260-
})
261-
262-
// The new session has no tool calls, but the replaced session ID carries
263-
// the old session's provenance.
264-
ctx := WithPlatformContext(context.Background(), &PlatformContext{SessionID: "new-session"})
265-
ctx = pkgsession.WithReplacedSessionID(ctx, "old-session")
266-
267-
_, err := handler(ctx, methodToolsCall, req)
268-
require.NoError(t, err)
269-
270-
// Both old-session tool calls should be recovered.
271-
require.Len(t, capturedCalls, 2, "should recover provenance from replaced session")
272-
assert.Equal(t, "trino_query", capturedCalls[0].ToolName)
273-
assert.Equal(t, "datahub_search", capturedCalls[1].ToolName)
274-
275-
// Both sessions should be cleared after harvest.
276-
assert.Nil(t, tracker.Harvest("old-session"))
277-
assert.Nil(t, tracker.Harvest("new-session"))
278-
}
279-
280-
func TestMCPProvenanceMiddleware_SessionReplacementMergesBoth(t *testing.T) {
281-
tracker := NewProvenanceTracker()
282-
283-
// Old session has 2 calls, new session has 1 call.
284-
tracker.Record("old-sess", "tool_a", nil)
285-
tracker.Record("old-sess", "tool_b", nil)
286-
tracker.Record("new-sess", "tool_c", nil)
287-
288-
var capturedCalls []ProvenanceToolCall
289-
base := func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
290-
capturedCalls = GetProvenanceToolCalls(ctx)
291-
return &mcp.CallToolResult{}, nil
292-
}
293-
294-
handler := MCPProvenanceMiddleware(tracker, "save_artifact")(base)
295-
296-
req := newTestServerRequest(&mcp.CallToolParamsRaw{Name: "save_artifact"})
297-
ctx := WithPlatformContext(context.Background(), &PlatformContext{SessionID: "new-sess"})
298-
ctx = pkgsession.WithReplacedSessionID(ctx, "old-sess")
299-
300-
_, err := handler(ctx, methodToolsCall, req)
301-
require.NoError(t, err)
302-
303-
// Old calls come first, then new session calls.
304-
require.Len(t, capturedCalls, 3)
305-
assert.Equal(t, "tool_a", capturedCalls[0].ToolName)
306-
assert.Equal(t, "tool_b", capturedCalls[1].ToolName)
307-
assert.Equal(t, "tool_c", capturedCalls[2].ToolName)
308-
}
309-
310241
func TestExtractToolParams_NilCases(t *testing.T) {
311242
// Request with nil arguments.
312243
req := newTestServerRequest(&mcp.CallToolParamsRaw{Name: "test"})

pkg/session/handler.go

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ import (
1515
// awareSessionKey is the context key for the AwareHandler session ID.
1616
type awareSessionKey struct{}
1717

18-
// replacedSessionKey is the context key for the old session ID that was replaced
19-
// during session recovery (expired session → new session with auth credentials).
20-
type replacedSessionKey struct{}
21-
2218
// AwareSessionID returns the session ID set by AwareHandler, or "".
2319
func AwareSessionID(ctx context.Context) string {
2420
if id, ok := ctx.Value(awareSessionKey{}).(string); ok {
@@ -34,22 +30,6 @@ func WithAwareSessionID(ctx context.Context, sessionID string) context.Context {
3430
return context.WithValue(ctx, awareSessionKey{}, sessionID)
3531
}
3632

37-
// ReplacedSessionID returns the old session ID that was replaced during session
38-
// recovery, or "" if no replacement occurred.
39-
func ReplacedSessionID(ctx context.Context) string {
40-
if id, ok := ctx.Value(replacedSessionKey{}).(string); ok {
41-
return id
42-
}
43-
return ""
44-
}
45-
46-
// WithReplacedSessionID returns a context carrying the old session ID that was
47-
// replaced. This allows downstream middleware (e.g. provenance) to recover data
48-
// recorded under the old session.
49-
func WithReplacedSessionID(ctx context.Context, oldSessionID string) context.Context {
50-
return context.WithValue(ctx, replacedSessionKey{}, oldSessionID)
51-
}
52-
5333
const (
5434
// sessionIDHeader is the MCP session header name.
5535
sessionIDHeader = "Mcp-Session-Id"
@@ -63,6 +43,9 @@ const (
6343
// slogKeyError is the slog attribute key for error values.
6444
slogKeyError = "error"
6545

46+
// httpErrInternal is the response body for HTTP 500 errors.
47+
httpErrInternal = "internal server error"
48+
6649
// touchTimeout is the maximum time for async session touch operations.
6750
touchTimeout = 5 * time.Second
6851
)
@@ -112,7 +95,7 @@ func (h *AwareHandler) handleInitialize(w http.ResponseWriter, r *http.Request)
11295
sessionID, err := generateSessionID()
11396
if err != nil {
11497
slog.Error("session: failed to generate ID", slogKeyError, err)
115-
http.Error(w, "internal server error", http.StatusInternalServerError)
98+
http.Error(w, httpErrInternal, http.StatusInternalServerError)
11699
return
117100
}
118101

@@ -128,7 +111,7 @@ func (h *AwareHandler) handleInitialize(w http.ResponseWriter, r *http.Request)
128111

129112
if err := h.store.Create(r.Context(), sess); err != nil {
130113
slog.Error("session: failed to create", slogKeyError, err)
131-
http.Error(w, "internal server error", http.StatusInternalServerError)
114+
http.Error(w, httpErrInternal, http.StatusInternalServerError)
132115
return
133116
}
134117

@@ -151,19 +134,26 @@ func (h *AwareHandler) handleExisting(w http.ResponseWriter, r *http.Request, se
151134
sess, err := h.store.Get(r.Context(), sessionID)
152135
if err != nil {
153136
slog.Error("session: store error", slogKeyError, err)
154-
http.Error(w, "internal server error", http.StatusInternalServerError)
137+
http.Error(w, httpErrInternal, http.StatusInternalServerError)
155138
return
156139
}
157140
if sess == nil {
158141
// Session expired or was cleaned up. If the request carries auth
159-
// credentials, create a replacement session transparently instead
160-
// of forcing the client to re-initialize (which the Go SDK client
161-
// does not do automatically).
142+
// credentials, revive the session using the SAME ID. Clients like
143+
// Claude Desktop do not update their stored session ID from
144+
// response headers, so generating a new ID would cause every
145+
// subsequent request to trigger another revive — breaking
146+
// provenance tracking and enrichment dedup.
162147
if extractToken(r) != "" {
163-
slog.Info("session: expired, creating replacement",
164-
"old_session_id", sanitizeLogValue(sessionID)) // #nosec G706 -- sessionID sanitized via sanitizeLogValue
165-
r = r.WithContext(WithReplacedSessionID(r.Context(), sessionID))
166-
h.handleInitialize(w, r)
148+
slog.Info("session: reviving expired session",
149+
"session_id", sanitizeLogValue(sessionID)) // #nosec G706 -- sessionID sanitized via sanitizeLogValue
150+
if err := h.reviveSession(r.Context(), sessionID, r); err != nil {
151+
slog.Error("session: failed to revive", slogKeyError, err)
152+
http.Error(w, httpErrInternal, http.StatusInternalServerError)
153+
return
154+
}
155+
r = r.WithContext(WithAwareSessionID(r.Context(), sessionID))
156+
h.inner.ServeHTTP(w, r)
167157
return
168158
}
169159
http.Error(w, "session not found or expired", http.StatusNotFound)
@@ -200,6 +190,28 @@ func (h *AwareHandler) handleDelete(w http.ResponseWriter, r *http.Request) {
200190
h.inner.ServeHTTP(w, r)
201191
}
202192

193+
// reviveSession recreates an expired or missing session using the same ID.
194+
// This ensures session stability when clients don't update their Mcp-Session-Id
195+
// from response headers (e.g. Claude Desktop). The expired row (if any) is
196+
// deleted first to avoid unique constraint violations in the database store.
197+
func (h *AwareHandler) reviveSession(ctx context.Context, sessionID string, r *http.Request) error {
198+
// Remove expired-but-not-yet-cleaned row (if any) to avoid INSERT conflict.
199+
_ = h.store.Delete(ctx, sessionID)
200+
201+
now := time.Now()
202+
if err := h.store.Create(ctx, &Session{
203+
ID: sessionID,
204+
UserID: hashToken(extractToken(r)),
205+
CreatedAt: now,
206+
LastActiveAt: now,
207+
ExpiresAt: now.Add(h.ttl),
208+
State: make(map[string]any),
209+
}); err != nil {
210+
return fmt.Errorf("reviving session: %w", err)
211+
}
212+
return nil
213+
}
214+
203215
// validateOwnership checks that the request token matches the session owner.
204216
// Anonymous sessions (empty UserID) skip this check.
205217
func validateOwnership(sess *Session, r *http.Request) bool {

0 commit comments

Comments
 (0)