4
4
"context"
5
5
"encoding/json"
6
6
"fmt"
7
+ "net/url"
7
8
"os"
9
+ "strings"
8
10
"text/tabwriter"
9
11
"time"
10
12
@@ -13,13 +15,17 @@ import (
13
15
"github.com/spf13/cobra"
14
16
15
17
"github.com/stacklok/toolhive/pkg/logger"
18
+ "github.com/stacklok/toolhive/pkg/transport/ssecommon"
19
+ "github.com/stacklok/toolhive/pkg/transport/streamable"
20
+ "github.com/stacklok/toolhive/pkg/transport/types"
16
21
"github.com/stacklok/toolhive/pkg/versions"
17
22
)
18
23
19
24
var (
20
25
mcpServerURL string
21
26
mcpFormat string
22
27
mcpTimeout time.Duration
28
+ mcpTransport string
23
29
)
24
30
25
31
func newMCPCommand () * cobra.Command {
@@ -80,6 +86,7 @@ func addMCPFlags(cmd *cobra.Command) {
80
86
cmd .Flags ().StringVar (& mcpServerURL , "server" , "" , "MCP server URL (required)" )
81
87
cmd .Flags ().StringVar (& mcpFormat , "format" , FormatText , "Output format (json or text)" )
82
88
cmd .Flags ().DurationVar (& mcpTimeout , "timeout" , 30 * time .Second , "Connection timeout" )
89
+ cmd .Flags ().StringVar (& mcpTransport , "transport" , "auto" , "Transport type (auto, sse, streamable-http)" )
83
90
_ = cmd .MarkFlagRequired ("server" )
84
91
}
85
92
@@ -197,15 +204,67 @@ func mcpListPromptsCmdFunc(cmd *cobra.Command, _ []string) error {
197
204
return outputMCPData (map [string ]interface {}{"prompts" : result .Prompts }, mcpFormat )
198
205
}
199
206
200
- // createMCPClient creates an MCP client based on the server URL
207
+ // createMCPClient creates an MCP client based on the server URL and transport type
201
208
func createMCPClient () (* client.Client , error ) {
202
- // For now, we'll use SSE client as the default
203
- // In the future, we could auto-detect or allow specifying the transport type
204
- mcpClient , err := client .NewSSEMCPClient (mcpServerURL )
209
+ transportType := determineTransportType (mcpServerURL , mcpTransport )
210
+
211
+ switch transportType {
212
+ case types .TransportTypeSSE :
213
+ mcpClient , err := client .NewSSEMCPClient (mcpServerURL )
214
+ if err != nil {
215
+ return nil , fmt .Errorf ("failed to create SSE MCP client: %w" , err )
216
+ }
217
+ return mcpClient , nil
218
+ case types .TransportTypeStreamableHTTP :
219
+ mcpClient , err := client .NewStreamableHttpClient (mcpServerURL )
220
+ if err != nil {
221
+ return nil , fmt .Errorf ("failed to create Streamable HTTP MCP client: %w" , err )
222
+ }
223
+ return mcpClient , nil
224
+ case types .TransportTypeStdio :
225
+ return nil , fmt .Errorf ("stdio transport is not supported for MCP client connections" )
226
+ case types .TransportTypeInspector :
227
+ return nil , fmt .Errorf ("inspector transport is not supported for MCP client connections" )
228
+ default :
229
+ return nil , fmt .Errorf ("unsupported transport type: %s" , transportType )
230
+ }
231
+ }
232
+
233
+ // determineTransportType determines the transport type based on URL path and user preference
234
+ func determineTransportType (serverURL , transportFlag string ) types.TransportType {
235
+ // If user explicitly specified a transport type, use it (unless it's "auto")
236
+ if transportFlag != "auto" {
237
+ switch transportFlag {
238
+ case string (types .TransportTypeSSE ):
239
+ return types .TransportTypeSSE
240
+ case string (types .TransportTypeStreamableHTTP ):
241
+ return types .TransportTypeStreamableHTTP
242
+ }
243
+ }
244
+
245
+ // Auto-detect based on URL path
246
+ parsedURL , err := url .Parse (serverURL )
205
247
if err != nil {
206
- return nil , fmt .Errorf ("failed to create MCP client: %w" , err )
248
+ // If we can't parse the URL, default to SSE for backward compatibility
249
+ logger .Warnf ("Failed to parse server URL %s, defaulting to SSE transport: %v" , serverURL , err )
250
+ return types .TransportTypeSSE
207
251
}
208
- return mcpClient , nil
252
+
253
+ path := parsedURL .Path
254
+
255
+ // Check for streamable HTTP endpoint (/mcp)
256
+ if strings .HasSuffix (path , "/" + streamable .HTTPStreamableHTTPEndpoint ) ||
257
+ strings .HasSuffix (path , streamable .HTTPStreamableHTTPEndpoint ) {
258
+ return types .TransportTypeStreamableHTTP
259
+ }
260
+
261
+ // Check for SSE endpoint (/sse)
262
+ if strings .HasSuffix (path , ssecommon .HTTPSSEEndpoint ) {
263
+ return types .TransportTypeSSE
264
+ }
265
+
266
+ // Default to SSE for backward compatibility
267
+ return types .TransportTypeSSE
209
268
}
210
269
211
270
// initializeMCPClient initializes the MCP client connection
0 commit comments