Skip to content

Commit 88c04f9

Browse files
authored
Added trivial tool filtering logic. (#1136)
This change adds support for `thv run --tools <name>` to only enable a subset of the tools available from an MCP server. The main focus of the implementation was to be self-contained so that we can swap it with something more appropriate at a later point in time. Note that the original idea was to use Cedar to implement a policy at startup, but this added complexity for a piece of code that should eventually be swapped out. Fixes #1003
1 parent b3bc4ed commit 88c04f9

File tree

9 files changed

+1171
-0
lines changed

9 files changed

+1171
-0
lines changed

cmd/thv/app/run_flags.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ type RunFlags struct {
6767

6868
// Execution mode
6969
Foreground bool
70+
71+
// Tools filter
72+
ToolsFilter []string
7073
}
7174

7275
// AddRunFlags adds all the run flags to a command
@@ -146,6 +149,12 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
146149
"Isolate the container network from the host (default: false)")
147150
cmd.Flags().StringArrayVarP(&config.Labels, "label", "l", []string{}, "Set labels on the container (format: key=value)")
148151
cmd.Flags().BoolVarP(&config.Foreground, "foreground", "f", false, "Run in foreground mode (block until container exits)")
152+
cmd.Flags().StringArrayVar(
153+
&config.ToolsFilter,
154+
"tools",
155+
nil,
156+
"Filter MCP server tools (comma-separated list of tool names)",
157+
)
149158
}
150159

151160
// BuildRunnerConfig creates a runner.RunConfig from the configuration
@@ -252,6 +261,7 @@ func BuildRunnerConfig(
252261
envVarValidator,
253262
types.ProxyMode(runConfig.ProxyMode),
254263
runConfig.Group,
264+
runConfig.ToolsFilter,
255265
)
256266
}
257267

docs/cli/thv_run.md

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/mcp/tool_filter.go

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
package mcp
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"strings"
11+
12+
"github.com/stacklok/toolhive/pkg/logger"
13+
"github.com/stacklok/toolhive/pkg/transport/types"
14+
)
15+
16+
var errToolNameNotFound = errors.New("tool name not found")
17+
var errToolNotInFilter = errors.New("tool not in filter")
18+
var errBug = errors.New("there's a bug")
19+
20+
// NewToolFilterMiddleware creates an HTTP middleware that parses SSE responses
21+
// and plain JSON objects to extract tool names from JSON-RPC messages containing
22+
// tool lists or tool calls.
23+
//
24+
// The middleware looks for SSE events with:
25+
// - event: message
26+
// - data: {"jsonrpc":"2.0","id":X,"result":{"tools":[...]}}
27+
//
28+
// When it finds such messages, it prints the name of each tool in the list.
29+
// If filterTools is provided, only tools in that list will be logged.
30+
// If filterTools is nil or empty, all tools will be logged.
31+
//
32+
// This middleware is designed to be used ONLY when tool filtering is enabled,
33+
// and expects the list of tools to be "correct" (i.e. not empty and not
34+
// containing nonexisting tools).
35+
func NewToolFilterMiddleware(filterTools []string) (types.Middleware, error) {
36+
if len(filterTools) == 0 {
37+
return nil, fmt.Errorf("tools list for filtering is empty")
38+
}
39+
40+
toolsMap := make(map[string]struct{})
41+
for _, tool := range filterTools {
42+
toolsMap[tool] = struct{}{}
43+
}
44+
45+
return func(next http.Handler) http.Handler {
46+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
47+
// NOTE: this middleware only checks the response body, whose
48+
// format at this point is not yet known and might be either a
49+
// JSON payload or an SSE stream.
50+
//
51+
// The way this is implemented is that we wrap the response writer
52+
// in order to buffer the response body. Once Flush() is called, we
53+
// process the buffer according to its content type and possibly
54+
// modify it before returning it to the client.
55+
rw := &toolFilterWriter{
56+
ResponseWriter: w,
57+
filterTools: toolsMap,
58+
}
59+
60+
// Call the next handler
61+
next.ServeHTTP(rw, r)
62+
})
63+
}, nil
64+
}
65+
66+
// NewToolCallFilterMiddleware creates an HTTP middleware that parses tool call
67+
// requests and filters out tools that are not in the filter list.
68+
//
69+
// The middleware looks for JSON-RPC messages with:
70+
// - method: tool/call
71+
// - params: {"name": "tool_name"}
72+
//
73+
// This middleware is designed to be used ONLY when tool filtering is enabled,
74+
// and expects the list of tools to be "correct" (i.e. not empty and not
75+
// containing nonexisting tools).
76+
func NewToolCallFilterMiddleware(filterTools []string) (types.Middleware, error) {
77+
if len(filterTools) == 0 {
78+
return nil, fmt.Errorf("tools list for filtering is empty")
79+
}
80+
81+
toolsMap := make(map[string]struct{})
82+
for _, tool := range filterTools {
83+
toolsMap[tool] = struct{}{}
84+
}
85+
86+
return func(next http.Handler) http.Handler {
87+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
88+
// Read the request body
89+
bodyBytes, err := io.ReadAll(r.Body)
90+
if err != nil {
91+
// If we can't read the body, let the next handler deal with it
92+
next.ServeHTTP(w, r)
93+
return
94+
}
95+
96+
// Restore the request body for downstream handlers
97+
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
98+
99+
// Try to parse the request as a tool call request. If it succeeds,
100+
// check if the tool is in the filter. If it is not a tool call request,
101+
// just pass it through.
102+
var toolCallRequest toolCallRequest
103+
err = json.Unmarshal(bodyBytes, &toolCallRequest)
104+
if err == nil && toolCallRequest.Params != nil && toolCallRequest.Method == "tools/call" {
105+
err = processToolCallRequest(toolsMap, toolCallRequest)
106+
107+
// NOTE: ideally, trying to call that was filtered out by config should be
108+
// equivalent to calling a nonexisting tool; in such cases and when the SSE
109+
// transport is used, the behaviour of the official Python SDK is to return
110+
// a 202 Accepted to THIS call and return an success message in the SSE
111+
// stream saying that the tool does not exist.
112+
//
113+
// It basically fails successfully.
114+
//
115+
// Unfortunately, implementing this behaviour is not trivial and requires
116+
// session management, as the SSE stream is managed by the proxy in an entirely
117+
// different thread of execution. As a consequence, the best thing we can
118+
// do that is still compliant with the spec is to return a 400 Bad Request
119+
// to the client.
120+
if errors.Is(err, errToolNotInFilter) {
121+
w.WriteHeader(http.StatusBadRequest)
122+
return
123+
}
124+
if err != nil {
125+
logger.Errorf("Error processing tool call of a filtered tool: %v", err)
126+
next.ServeHTTP(w, r)
127+
return
128+
}
129+
}
130+
131+
next.ServeHTTP(w, r)
132+
})
133+
}, nil
134+
}
135+
136+
// toolFilterWriter wraps http.ResponseWriter to capture and process SSE responses
137+
type toolFilterWriter struct {
138+
http.ResponseWriter
139+
buffer []byte
140+
filterTools map[string]struct{}
141+
}
142+
143+
// WriteHeader captures the status code
144+
func (rw *toolFilterWriter) WriteHeader(statusCode int) {
145+
rw.ResponseWriter.WriteHeader(statusCode)
146+
}
147+
148+
// Write captures the response body and processes SSE events
149+
func (rw *toolFilterWriter) Write(data []byte) (int, error) {
150+
rw.buffer = append(rw.buffer, data...)
151+
return len(data), nil
152+
}
153+
154+
// Flush processes any remaining buffered data and writes it to the underlying ResponseWriter
155+
func (rw *toolFilterWriter) Flush() {
156+
if len(rw.buffer) > 0 {
157+
mimeType := strings.Split(rw.ResponseWriter.Header().Get("Content-Type"), ";")[0]
158+
159+
if mimeType == "" {
160+
_, err := rw.ResponseWriter.Write(rw.buffer)
161+
if err != nil {
162+
logger.Errorf("Error writing buffer: %v", err)
163+
}
164+
return
165+
}
166+
167+
var b bytes.Buffer
168+
if err := processBuffer(rw.filterTools, rw.buffer, mimeType, &b); err != nil {
169+
logger.Errorf("Error flushing response: %v", err)
170+
}
171+
172+
_, err := rw.ResponseWriter.Write(b.Bytes())
173+
if err != nil {
174+
logger.Errorf("Error writing buffer: %v", err)
175+
}
176+
rw.buffer = rw.buffer[:0] // Reset buffer
177+
}
178+
179+
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
180+
flusher.Flush()
181+
}
182+
}
183+
184+
type toolsListResponse struct {
185+
JSONRPC string `json:"jsonrpc"`
186+
ID any `json:"id"`
187+
Result struct {
188+
Tools *[]map[string]any `json:"tools"`
189+
} `json:"result,omitempty"`
190+
}
191+
192+
type toolCallRequest struct {
193+
JSONRPC string `json:"jsonrpc"`
194+
ID any `json:"id"`
195+
Method string `json:"method"`
196+
Params *map[string]any `json:"params,omitempty"`
197+
}
198+
199+
// processSSEBuffer processes any complete SSE events in the buffer
200+
func processBuffer(filterTools map[string]struct{}, buffer []byte, mimeType string, w io.Writer) error {
201+
if len(buffer) == 0 {
202+
return nil
203+
}
204+
205+
switch mimeType {
206+
case "application/json":
207+
var toolsListResponse toolsListResponse
208+
err := json.Unmarshal(buffer, &toolsListResponse)
209+
if err == nil && toolsListResponse.Result.Tools != nil {
210+
return processToolsListResponse(filterTools, toolsListResponse, w)
211+
}
212+
case "text/event-stream":
213+
return processSSEEvents(filterTools, buffer, w)
214+
default:
215+
// NOTE: Content-Type header is mandatory in the spec, and as of the
216+
// time of this writing, the only allowed content types are
217+
// * application/json, and
218+
// * text/event-stream
219+
//
220+
// As a result, we should never get here and it is safe to return an
221+
// error.
222+
return fmt.Errorf("unsupported mime type: %s", mimeType)
223+
}
224+
225+
return fmt.Errorf("%w: tool filtering middleware", errBug)
226+
}
227+
228+
//nolint:gocyclo
229+
func processSSEEvents(filterTools map[string]struct{}, buffer []byte, w io.Writer) error {
230+
var linesep []byte
231+
if bytes.Contains(buffer, []byte("\r\n")) {
232+
linesep = []byte("\r\n")
233+
} else if bytes.Contains(buffer, []byte("\n")) {
234+
linesep = []byte("\n")
235+
} else if bytes.Contains(buffer, []byte("\r")) {
236+
linesep = []byte("\r")
237+
} else {
238+
return fmt.Errorf("unsupported separator: %s", string(buffer))
239+
}
240+
241+
var linesepTotal, linesepCount int
242+
linesepTotal = bytes.Count(buffer, linesep)
243+
lines := bytes.Split(buffer, linesep)
244+
for _, line := range lines {
245+
if len(line) == 0 {
246+
continue
247+
}
248+
249+
var written bool
250+
if data, ok := bytes.CutPrefix(line, []byte("data:")); ok {
251+
var toolsListResponse toolsListResponse
252+
if err := json.Unmarshal(data, &toolsListResponse); err == nil && toolsListResponse.Result.Tools != nil {
253+
if err := processToolsListResponse(filterTools, toolsListResponse, w); err != nil {
254+
return err
255+
}
256+
written = true
257+
}
258+
}
259+
260+
if !written {
261+
_, err := w.Write(line)
262+
if err != nil {
263+
return fmt.Errorf("%w: %v", errBug, err)
264+
}
265+
}
266+
267+
_, err := w.Write(linesep)
268+
if err != nil {
269+
return fmt.Errorf("%w: %v", errBug, err)
270+
}
271+
linesepCount++
272+
}
273+
274+
// This ensures we don't send too few line separators, which might break
275+
// SSE parsing.
276+
if linesepCount < linesepTotal {
277+
_, err := w.Write(linesep)
278+
if err != nil {
279+
return fmt.Errorf("%w: %v", errBug, err)
280+
}
281+
}
282+
283+
return nil
284+
}
285+
286+
// processToolsListResponse processes a tools list response filtering out
287+
// tools that are not in the filter list.
288+
func processToolsListResponse(filterTools map[string]struct{}, toolsListResponse toolsListResponse, w io.Writer) error {
289+
filteredTools := []map[string]any{}
290+
for _, tool := range *toolsListResponse.Result.Tools {
291+
toolName, ok := tool["name"].(string)
292+
if !ok {
293+
return errToolNameNotFound
294+
}
295+
296+
if isToolInFilter(filterTools, toolName) {
297+
filteredTools = append(filteredTools, tool)
298+
}
299+
}
300+
301+
toolsListResponse.Result.Tools = &filteredTools
302+
if err := json.NewEncoder(w).Encode(toolsListResponse); err != nil {
303+
return fmt.Errorf("%w: %v", errBug, err)
304+
}
305+
306+
return nil
307+
}
308+
309+
// processToolCallRequest processes a tool call request checking if the tool
310+
// is in the filter list.
311+
func processToolCallRequest(filterTools map[string]struct{}, toolCallRequest toolCallRequest) error {
312+
toolName, ok := (*toolCallRequest.Params)["name"].(string)
313+
if !ok {
314+
return errToolNameNotFound
315+
}
316+
317+
if isToolInFilter(filterTools, toolName) {
318+
return nil
319+
}
320+
321+
return errToolNotInFilter
322+
}
323+
324+
// isToolInFilter checks if a tool name is in the filter
325+
func isToolInFilter(filterTools map[string]struct{}, toolName string) bool {
326+
_, ok := filterTools[toolName]
327+
return ok
328+
}

0 commit comments

Comments
 (0)