Skip to content

Commit b14636f

Browse files
authored
feat: enforce platform_info as mandatory first call via session gating middleware (#164) (#165)
Strengthen the platform_info tool description with MANDATORY language and add server-side session gating middleware that blocks all non-exempt tool calls until platform_info has been invoked in the current session. - Update buildInfoToolDescription to lead with "MANDATORY first call" and name consequences of skipping (incorrect routing, rule violations) - Add MCPSessionGateMiddleware with per-session init tracking, configurable exempt tools list, TTL-based expiration, and violation logging - Add SessionGateConfig to platform config (session_gate.enabled, init_tool, exempt_tools) with defaults - Wire middleware between Auth/Authz (outer) and Audit (inner) so gated calls get a SETUP_REQUIRED error without producing audit events - Extract stopBackgroundTrackers to keep Close() under complexity limit
1 parent c073604 commit b14636f

File tree

7 files changed

+833
-14
lines changed

7 files changed

+833
-14
lines changed

pkg/middleware/mcp_session_gate.go

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
"sync"
8+
"time"
9+
10+
"github.com/modelcontextprotocol/go-sdk/mcp"
11+
)
12+
13+
// ErrCategorySetupRequired is the error category for session gate violations.
14+
const ErrCategorySetupRequired = "setup_required"
15+
16+
// defaultGateSessionTTL is the default TTL for session gate entries.
17+
const defaultGateSessionTTL = 30
18+
19+
// SessionGateConfig configures the session initialization gate middleware.
20+
type SessionGateConfig struct {
21+
// InitTool is the tool that initializes the session (e.g., "platform_info").
22+
InitTool string
23+
24+
// ExemptTools lists tool names that bypass the gate.
25+
ExemptTools []string
26+
27+
// SessionTTL is how long an initialized session is remembered.
28+
// Defaults to 30 minutes.
29+
SessionTTL time.Duration
30+
31+
// CleanupInterval is how often the cleanup routine runs.
32+
// Defaults to 1 minute.
33+
CleanupInterval time.Duration
34+
}
35+
36+
// SessionGate tracks which sessions have called the init tool.
37+
// It is safe for concurrent use.
38+
type SessionGate struct {
39+
mu sync.RWMutex
40+
sessions map[string]time.Time // session ID → initialization time
41+
initTool string
42+
exemptSet map[string]bool
43+
sessionTTL time.Duration
44+
done chan struct{}
45+
gateCount int64 // total gating violations
46+
retryCount int64 // total successful retries (init after gate)
47+
}
48+
49+
// NewSessionGate creates a new session gate tracker.
50+
func NewSessionGate(cfg SessionGateConfig) *SessionGate {
51+
ttl := cfg.SessionTTL
52+
if ttl == 0 {
53+
ttl = defaultGateSessionTTL * time.Minute
54+
}
55+
56+
exemptSet := make(map[string]bool, len(cfg.ExemptTools)+1)
57+
// The init tool itself is always exempt.
58+
exemptSet[cfg.InitTool] = true
59+
for _, t := range cfg.ExemptTools {
60+
exemptSet[t] = true
61+
}
62+
63+
return &SessionGate{
64+
sessions: make(map[string]time.Time),
65+
initTool: cfg.InitTool,
66+
exemptSet: exemptSet,
67+
sessionTTL: ttl,
68+
done: make(chan struct{}),
69+
}
70+
}
71+
72+
// RecordInit marks a session as initialized.
73+
func (g *SessionGate) RecordInit(sessionID string) {
74+
g.mu.Lock()
75+
defer g.mu.Unlock()
76+
77+
_, existed := g.sessions[sessionID]
78+
g.sessions[sessionID] = time.Now()
79+
80+
if existed {
81+
g.retryCount++
82+
}
83+
}
84+
85+
// IsInitialized returns true if the session has called the init tool.
86+
func (g *SessionGate) IsInitialized(sessionID string) bool {
87+
g.mu.RLock()
88+
defer g.mu.RUnlock()
89+
90+
initTime, ok := g.sessions[sessionID]
91+
if !ok {
92+
return false
93+
}
94+
// Check TTL expiration.
95+
return time.Since(initTime) < g.sessionTTL
96+
}
97+
98+
// IsExempt returns true if the tool bypasses the gate.
99+
func (g *SessionGate) IsExempt(toolName string) bool {
100+
return g.exemptSet[toolName]
101+
}
102+
103+
// IncrementGateCount increments the gating violation counter and returns the new count.
104+
func (g *SessionGate) IncrementGateCount() int64 {
105+
g.mu.Lock()
106+
defer g.mu.Unlock()
107+
g.gateCount++
108+
return g.gateCount
109+
}
110+
111+
// Stats returns current gate statistics.
112+
func (g *SessionGate) Stats() (gateViolations, retries, activeSessions int64) {
113+
g.mu.RLock()
114+
defer g.mu.RUnlock()
115+
return g.gateCount, g.retryCount, int64(len(g.sessions))
116+
}
117+
118+
// StartCleanup starts a background goroutine that evicts expired sessions.
119+
func (g *SessionGate) StartCleanup(interval time.Duration) {
120+
if interval == 0 {
121+
interval = 1 * time.Minute
122+
}
123+
go func() {
124+
ticker := time.NewTicker(interval)
125+
defer ticker.Stop()
126+
127+
for {
128+
select {
129+
case <-g.done:
130+
return
131+
case <-ticker.C:
132+
g.cleanup()
133+
}
134+
}
135+
}()
136+
}
137+
138+
// Stop stops the background cleanup goroutine.
139+
func (g *SessionGate) Stop() {
140+
close(g.done)
141+
}
142+
143+
// cleanup evicts sessions that have expired.
144+
func (g *SessionGate) cleanup() {
145+
g.mu.Lock()
146+
defer g.mu.Unlock()
147+
148+
now := time.Now()
149+
for id, initTime := range g.sessions {
150+
if now.Sub(initTime) > g.sessionTTL {
151+
delete(g.sessions, id)
152+
}
153+
}
154+
}
155+
156+
// MCPSessionGateMiddleware creates MCP protocol-level middleware that gates
157+
// all tool calls until the init tool (e.g., platform_info) has been called
158+
// in the current session.
159+
//
160+
// This middleware must be positioned INNER to MCPToolCallMiddleware so that
161+
// PlatformContext (with SessionID and ToolName) is available. It should be
162+
// positioned OUTER to rule enforcement and enrichment so that gated calls
163+
// never reach those layers.
164+
func MCPSessionGateMiddleware(gate *SessionGate) mcp.Middleware {
165+
return func(next mcp.MethodHandler) mcp.MethodHandler {
166+
return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) {
167+
if method != methodToolsCall {
168+
return next(ctx, method, req)
169+
}
170+
pc := GetPlatformContext(ctx)
171+
if pc == nil {
172+
return next(ctx, method, req)
173+
}
174+
if errResult := gate.checkAccess(pc); errResult != nil {
175+
return errResult, nil
176+
}
177+
return next(ctx, method, req)
178+
}
179+
}
180+
}
181+
182+
// checkAccess evaluates whether a tool call should proceed or be gated.
183+
// Returns nil if the call is allowed; returns an error result if gated.
184+
func (g *SessionGate) checkAccess(pc *PlatformContext) mcp.Result {
185+
// The init tool itself records initialization and is always allowed.
186+
if pc.ToolName == g.initTool {
187+
g.RecordInit(pc.SessionID)
188+
return nil
189+
}
190+
191+
// Exempt tools bypass the gate.
192+
if g.IsExempt(pc.ToolName) {
193+
return nil
194+
}
195+
196+
// Initialized sessions proceed normally.
197+
if g.IsInitialized(pc.SessionID) {
198+
return nil
199+
}
200+
201+
// Gate violation: session not initialized.
202+
count := g.IncrementGateCount()
203+
slog.Warn("session gate: tool called before platform_info",
204+
"tool", pc.ToolName,
205+
"session_id", pc.SessionID,
206+
"user_id", pc.UserID,
207+
"total_violations", count,
208+
)
209+
return createSessionGateError(g.initTool, pc.ToolName)
210+
}
211+
212+
// createSessionGateError builds a SETUP_REQUIRED error result.
213+
func createSessionGateError(initTool, blockedTool string) mcp.Result {
214+
msg := fmt.Sprintf(
215+
"SETUP_REQUIRED: You must call %s before using %s (or any other tool). "+
216+
"%s contains critical agent instructions for query routing, operational rules, "+
217+
"and platform capabilities. Call %s first, then retry your request.",
218+
initTool, blockedTool, initTool, initTool,
219+
)
220+
221+
result := &mcp.CallToolResult{}
222+
result.SetError(&PlatformError{
223+
Category: ErrCategorySetupRequired,
224+
Message: msg,
225+
})
226+
return result
227+
}

0 commit comments

Comments
 (0)