Skip to content

Commit e8e6f2e

Browse files
yroblataskbotCopilot
authored
feat: add proxy stdio subcommand to enable Claude Desktop (#1236)
Co-authored-by: taskbot <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent d6b6087 commit e8e6f2e

File tree

7 files changed

+648
-3
lines changed

7 files changed

+648
-3
lines changed

cmd/thv/app/proxy.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,10 @@ func init() {
171171
if err := proxyCmd.MarkFlagRequired("target-uri"); err != nil {
172172
logger.Warnf("Warning: Failed to mark flag as required: %v", err)
173173
}
174-
175-
// Attach the subcommand to the main proxy command
174+
// Attach the subcommands to the main proxy command
176175
proxyCmd.AddCommand(proxyTunnelCmd)
176+
proxyCmd.AddCommand(proxyStdioCmd)
177+
177178
}
178179

179180
func proxyCmdFunc(cmd *cobra.Command, args []string) error {

cmd/thv/app/proxy_stdio.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package app
2+
3+
import (
4+
"fmt"
5+
"os/signal"
6+
"syscall"
7+
8+
"github.com/spf13/cobra"
9+
10+
"github.com/stacklok/toolhive/pkg/logger"
11+
"github.com/stacklok/toolhive/pkg/transport"
12+
"github.com/stacklok/toolhive/pkg/workloads"
13+
)
14+
15+
var proxyStdioCmd = &cobra.Command{
16+
Use: "stdio WORKLOAD-NAME",
17+
Short: "Create a stdio-based proxy for an MCP server",
18+
Long: `Create a stdio-based proxy that connects stdin/stdout to a target MCP server.
19+
20+
Example:
21+
thv proxy stdio my-workload
22+
`,
23+
Args: cobra.ExactArgs(1),
24+
RunE: proxyStdioCmdFunc,
25+
}
26+
27+
func proxyStdioCmdFunc(cmd *cobra.Command, args []string) error {
28+
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
29+
defer cancel()
30+
31+
workloadName := args[0]
32+
workloadManager, err := workloads.NewManager(ctx)
33+
if err != nil {
34+
return fmt.Errorf("failed to create workload manager: %w", err)
35+
}
36+
stdioWorkload, err := workloadManager.GetWorkload(ctx, workloadName)
37+
if err != nil {
38+
return fmt.Errorf("failed to get workload %q: %w", workloadName, err)
39+
}
40+
logger.Infof("Starting stdio proxy for workload=%q", workloadName)
41+
42+
bridge, err := transport.NewStdioBridge(stdioWorkload.URL, stdioWorkload.TransportType)
43+
if err != nil {
44+
return fmt.Errorf("failed to create stdio bridge: %w", err)
45+
}
46+
bridge.Start(ctx)
47+
48+
// Consume until interrupt
49+
<-ctx.Done()
50+
logger.Info("Shutting down bridge")
51+
bridge.Shutdown()
52+
return nil
53+
}

docs/cli/thv_proxy.md

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/cli/thv_proxy_stdio.md

Lines changed: 43 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/container/images/registry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/docker/docker/client"
1010
"github.com/google/go-containerregistry/pkg/authn"
1111
"github.com/google/go-containerregistry/pkg/name"
12-
"github.com/google/go-containerregistry/pkg/v1"
12+
v1 "github.com/google/go-containerregistry/pkg/v1"
1313
"github.com/google/go-containerregistry/pkg/v1/daemon"
1414
"github.com/google/go-containerregistry/pkg/v1/remote"
1515

pkg/transport/bridge.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package transport
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"strings"
8+
"sync"
9+
10+
"github.com/mark3labs/mcp-go/client"
11+
"github.com/mark3labs/mcp-go/client/transport"
12+
"github.com/mark3labs/mcp-go/mcp"
13+
"github.com/mark3labs/mcp-go/server"
14+
15+
"github.com/stacklok/toolhive/pkg/logger"
16+
"github.com/stacklok/toolhive/pkg/transport/types"
17+
)
18+
19+
// StdioBridge connects stdin/stdout to a target MCP server using the specified transport type.
20+
type StdioBridge struct {
21+
mode types.TransportType
22+
rawTarget string // upstream base URL
23+
24+
up *client.Client
25+
srv *server.MCPServer
26+
27+
wg sync.WaitGroup
28+
cancel context.CancelFunc
29+
}
30+
31+
// NewStdioBridge creates a new StdioBridge instance for the given target URL and transport type.
32+
func NewStdioBridge(rawURL string, mode types.TransportType) (*StdioBridge, error) {
33+
return &StdioBridge{mode: mode, rawTarget: rawURL}, nil
34+
}
35+
36+
// Start initializes the bridge and connects to the upstream MCP server.
37+
func (b *StdioBridge) Start(ctx context.Context) {
38+
ctx, b.cancel = context.WithCancel(ctx)
39+
b.wg.Add(1)
40+
go b.run(ctx)
41+
}
42+
43+
// Shutdown gracefully stops the bridge, closing connections and waiting for cleanup.
44+
func (b *StdioBridge) Shutdown() {
45+
if b.cancel != nil {
46+
b.cancel()
47+
}
48+
if b.up != nil {
49+
_ = b.up.Close()
50+
}
51+
b.wg.Wait()
52+
}
53+
54+
func (b *StdioBridge) run(ctx context.Context) {
55+
logger.Infof("Starting StdioBridge for %s in mode %s", b.rawTarget, b.mode)
56+
defer b.wg.Done()
57+
58+
up, err := b.connectUpstream(ctx)
59+
if err != nil {
60+
logger.Errorf("upstream connect failed: %v", err)
61+
return
62+
}
63+
b.up = up
64+
logger.Infof("Connected to upstream %s", b.rawTarget)
65+
66+
if err := b.initializeUpstream(ctx); err != nil {
67+
logger.Errorf("upstream initialize failed: %v", err)
68+
return
69+
}
70+
logger.Infof("Upstream initialized successfully")
71+
72+
// Tiny local stdio server
73+
b.srv = server.NewMCPServer(
74+
"toolhive-stdio-bridge",
75+
"0.1.0",
76+
server.WithToolCapabilities(true),
77+
server.WithResourceCapabilities(true, true),
78+
server.WithPromptCapabilities(true),
79+
)
80+
logger.Infof("Starting local stdio server")
81+
82+
b.up.OnConnectionLost(func(err error) { logger.Warnf("upstream lost: %v", err) })
83+
84+
// Handle upstream notifications
85+
b.up.OnNotification(func(n mcp.JSONRPCNotification) {
86+
logger.Infof("upstream → downstream notify: %s %v", n.Method, n.Params)
87+
// Convert the Params struct to JSON and back to a generic map
88+
var params map[string]any
89+
if buf, err := json.Marshal(n.Params); err != nil {
90+
logger.Warnf("Failed to marshal params: %v", err)
91+
params = map[string]any{}
92+
} else if err := json.Unmarshal(buf, &params); err != nil {
93+
logger.Warnf("Failed to unmarshal to map: %v", err)
94+
params = map[string]any{}
95+
}
96+
97+
b.srv.SendNotificationToAllClients(n.Method, params)
98+
})
99+
100+
// Forwarders (register once; no pagination/refresh to keep it simple)
101+
b.forwardAll(ctx)
102+
103+
// Serve stdio (blocks)
104+
if err := server.ServeStdio(b.srv); err != nil {
105+
logger.Errorf("stdio server error: %v", err)
106+
}
107+
}
108+
109+
func (b *StdioBridge) connectUpstream(_ context.Context) (*client.Client, error) {
110+
logger.Infof("Connecting to upstream %s using mode %s", b.rawTarget, b.mode)
111+
112+
switch b.mode {
113+
case types.TransportTypeStreamableHTTP:
114+
c, err := client.NewStreamableHttpClient(
115+
b.rawTarget,
116+
transport.WithHTTPTimeout(0),
117+
transport.WithContinuousListening(),
118+
)
119+
if err != nil {
120+
return nil, err
121+
}
122+
// use separate, never-ending context for the client
123+
if err := c.Start(context.Background()); err != nil {
124+
return nil, err
125+
}
126+
return c, nil
127+
case types.TransportTypeSSE:
128+
c, err := client.NewSSEMCPClient(
129+
b.rawTarget,
130+
)
131+
if err != nil {
132+
return nil, err
133+
}
134+
if err := c.Start(context.Background()); err != nil {
135+
return nil, err
136+
}
137+
return c, nil
138+
case types.TransportTypeStdio:
139+
// if url contains sse it's sse else streamable-http
140+
var c *client.Client
141+
var err error
142+
if strings.Contains(b.rawTarget, "sse") {
143+
c, err = client.NewSSEMCPClient(
144+
b.rawTarget,
145+
)
146+
if err != nil {
147+
return nil, err
148+
}
149+
} else {
150+
c, err = client.NewStreamableHttpClient(
151+
b.rawTarget,
152+
)
153+
if err != nil {
154+
return nil, err
155+
}
156+
}
157+
if err := c.Start(context.Background()); err != nil {
158+
return nil, err
159+
}
160+
return c, nil
161+
case types.TransportTypeInspector:
162+
fallthrough
163+
default:
164+
return nil, fmt.Errorf("unsupported mode %q", b.mode)
165+
}
166+
}
167+
168+
func (b *StdioBridge) initializeUpstream(ctx context.Context) error {
169+
logger.Infof("Initializing upstream %s", b.rawTarget)
170+
_, err := b.up.Initialize(ctx, mcp.InitializeRequest{
171+
Params: mcp.InitializeParams{
172+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
173+
ClientInfo: mcp.Implementation{Name: "toolhive-bridge", Version: "0.1.0"},
174+
Capabilities: mcp.ClientCapabilities{},
175+
},
176+
})
177+
if err != nil {
178+
return err
179+
}
180+
return nil
181+
}
182+
183+
func (b *StdioBridge) forwardAll(ctx context.Context) {
184+
logger.Infof("Forwarding all upstream data to local stdio server")
185+
// Tools -> straight passthrough
186+
logger.Infof("Forwarding tools from upstream to local stdio server")
187+
if lt, err := b.up.ListTools(ctx, mcp.ListToolsRequest{}); err == nil {
188+
for _, tool := range lt.Tools {
189+
toolCopy := tool
190+
b.srv.AddTool(toolCopy, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
191+
return b.up.CallTool(ctx, req)
192+
})
193+
}
194+
}
195+
196+
// Resources -> return []mcp.ResourceContents
197+
logger.Infof("Forwarding resources from upstream to local stdio server")
198+
if lr, err := b.up.ListResources(ctx, mcp.ListResourcesRequest{}); err == nil {
199+
for _, res := range lr.Resources {
200+
resCopy := res
201+
b.srv.AddResource(resCopy, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
202+
out, err := b.up.ReadResource(ctx, req)
203+
if err != nil {
204+
return nil, err
205+
}
206+
return out.Contents, nil
207+
})
208+
}
209+
}
210+
211+
// Resource templates -> same return type as resources
212+
logger.Infof("Forwarding resource templates from upstream to local stdio server")
213+
if lt, err := b.up.ListResourceTemplates(ctx, mcp.ListResourceTemplatesRequest{}); err == nil {
214+
for _, tpl := range lt.ResourceTemplates {
215+
tplCopy := tpl
216+
b.srv.AddResourceTemplate(tplCopy, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
217+
out, err := b.up.ReadResource(ctx, req)
218+
if err != nil {
219+
return nil, err
220+
}
221+
return out.Contents, nil
222+
})
223+
}
224+
}
225+
226+
// Prompts -> straight passthrough
227+
logger.Infof("Forwarding prompts from upstream to local stdio server")
228+
if lp, err := b.up.ListPrompts(ctx, mcp.ListPromptsRequest{}); err == nil {
229+
for _, p := range lp.Prompts {
230+
pCopy := p
231+
b.srv.AddPrompt(pCopy, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
232+
return b.up.GetPrompt(ctx, req)
233+
})
234+
}
235+
}
236+
}

0 commit comments

Comments
 (0)