diff --git a/sast-engine/cmd/serve.go b/sast-engine/cmd/serve.go index 146749c9..7507efd5 100644 --- a/sast-engine/cmd/serve.go +++ b/sast-engine/cmd/serve.go @@ -1,8 +1,11 @@ package cmd import ( + "context" "fmt" "os" + "os/signal" + "syscall" "time" "github.com/shivasurya/code-pathfinder/sast-engine/graph" @@ -16,13 +19,17 @@ import ( var serveCmd = &cobra.Command{ Use: "serve", Short: "Start MCP server for AI coding assistants", - Long: `Builds code index and starts MCP server on stdio. + Long: `Builds code index and starts MCP server. Designed for integration with Claude Code, Codex CLI, and other AI assistants that support the Model Context Protocol (MCP). The server indexes the codebase once at startup, then responds to queries -about symbols, call graphs, and code relationships.`, +about symbols, call graphs, and code relationships. + +Transport modes: + - stdio (default): Standard input/output for direct integration + - http: HTTP server for network access`, RunE: runServe, } @@ -30,11 +37,15 @@ func init() { rootCmd.AddCommand(serveCmd) serveCmd.Flags().StringP("project", "p", ".", "Project path to index") serveCmd.Flags().String("python-version", "3.11", "Python version for stdlib resolution") + serveCmd.Flags().Bool("http", false, "Use HTTP transport instead of stdio") + serveCmd.Flags().String("address", ":8080", "HTTP server address (only with --http)") } func runServe(cmd *cobra.Command, _ []string) error { projectPath, _ := cmd.Flags().GetString("project") pythonVersion, _ := cmd.Flags().GetString("python-version") + useHTTP, _ := cmd.Flags().GetBool("http") + address, _ := cmd.Flags().GetString("address") fmt.Fprintln(os.Stderr, "Building index...") start := time.Now() @@ -66,9 +77,45 @@ func runServe(cmd *cobra.Command, _ []string) error { fmt.Fprintf(os.Stderr, " Call edges: %d\n", len(callGraph.Edges)) fmt.Fprintf(os.Stderr, " Modules: %d\n", len(moduleRegistry.Modules)) - // 4. Create and run MCP server + // 4. Create MCP server server := mcp.NewServer(projectPath, pythonVersion, callGraph, moduleRegistry, codeGraph, buildTime) + // 5. Start appropriate transport + if useHTTP { + return runHTTPServer(server, address) + } + fmt.Fprintln(os.Stderr, "Starting MCP server on stdio...") return server.ServeStdio() } + +func runHTTPServer(mcpServer *mcp.Server, address string) error { + config := &mcp.HTTPConfig{ + Address: address, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + ShutdownTimeout: 5 * time.Second, + AllowedOrigins: []string{"*"}, + } + + httpServer := mcp.NewHTTPServer(mcpServer, config) + + // Handle graceful shutdown. + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + errChan := make(chan error, 1) + go func() { + errChan <- httpServer.Start() + }() + + select { + case err := <-errChan: + return err + case sig := <-sigChan: + fmt.Fprintf(os.Stderr, "\nReceived %v, shutting down...\n", sig) + ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout) + defer cancel() + return httpServer.Shutdown(ctx) + } +} diff --git a/sast-engine/mcp/http.go b/sast-engine/mcp/http.go new file mode 100644 index 00000000..701c2dea --- /dev/null +++ b/sast-engine/mcp/http.go @@ -0,0 +1,331 @@ +package mcp + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "time" +) + +// HTTPConfig holds configuration for the HTTP server. +type HTTPConfig struct { + Address string + ReadTimeout time.Duration + WriteTimeout time.Duration + ShutdownTimeout time.Duration + AllowedOrigins []string +} + +// DefaultHTTPConfig returns sensible defaults. +func DefaultHTTPConfig() *HTTPConfig { + return &HTTPConfig{ + Address: ":8080", + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + ShutdownTimeout: 5 * time.Second, + AllowedOrigins: []string{"*"}, + } +} + +// HTTPServer wraps the MCP server with HTTP transport. +type HTTPServer struct { + server *Server + httpServer *http.Server + config *HTTPConfig + mu sync.RWMutex + running bool +} + +// NewHTTPServer creates a new HTTP server wrapping the MCP server. +func NewHTTPServer(mcpServer *Server, config *HTTPConfig) *HTTPServer { + if config == nil { + config = DefaultHTTPConfig() + } + + return &HTTPServer{ + server: mcpServer, + config: config, + } +} + +// ServeHTTP implements http.Handler for JSON-RPC requests. +func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Set CORS headers. + h.setCORSHeaders(w, r) + + // Handle preflight requests. + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + // Only accept POST for JSON-RPC. + if r.Method != http.MethodPost { + h.writeError(w, http.StatusMethodNotAllowed, "Only POST method is allowed") + return + } + + // Verify content type. + contentType := r.Header.Get("Content-Type") + if !strings.HasPrefix(contentType, "application/json") { + h.writeError(w, http.StatusUnsupportedMediaType, "Content-Type must be application/json") + return + } + + // Read request body. + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB limit + if err != nil { + h.writeError(w, http.StatusBadRequest, "Failed to read request body") + return + } + defer r.Body.Close() + + // Parse JSON-RPC request. + var request JSONRPCRequest + if err := json.Unmarshal(body, &request); err != nil { + response := MakeErrorResponse(nil, ParseError(err.Error())) + h.writeJSON(w, http.StatusOK, response) + return + } + + // Handle the request. + response := h.server.handleRequest(&request) + + // Write response. + h.writeJSON(w, http.StatusOK, response) +} + +// Start starts the HTTP server. +func (h *HTTPServer) Start() error { + h.mu.Lock() + if h.running { + h.mu.Unlock() + return fmt.Errorf("server already running") + } + + mux := http.NewServeMux() + mux.Handle("/", h) + mux.HandleFunc("/health", h.healthHandler) + + h.httpServer = &http.Server{ + Addr: h.config.Address, + Handler: mux, + ReadTimeout: h.config.ReadTimeout, + WriteTimeout: h.config.WriteTimeout, + } + + h.running = true + h.mu.Unlock() + + // Start listening. + lc := net.ListenConfig{} + listener, err := lc.Listen(context.Background(), "tcp", h.config.Address) + if err != nil { + h.mu.Lock() + h.running = false + h.mu.Unlock() + return fmt.Errorf("failed to listen on %s: %w", h.config.Address, err) + } + + fmt.Printf("MCP HTTP server listening on %s\n", h.config.Address) + return h.httpServer.Serve(listener) +} + +// StartAsync starts the HTTP server in a goroutine and returns immediately. +func (h *HTTPServer) StartAsync() error { + h.mu.Lock() + if h.running { + h.mu.Unlock() + return fmt.Errorf("server already running") + } + + mux := http.NewServeMux() + mux.Handle("/", h) + mux.HandleFunc("/health", h.healthHandler) + + h.httpServer = &http.Server{ + Addr: h.config.Address, + Handler: mux, + ReadTimeout: h.config.ReadTimeout, + WriteTimeout: h.config.WriteTimeout, + } + + h.running = true + h.mu.Unlock() + + // Start listening. + lc := net.ListenConfig{} + listener, err := lc.Listen(context.Background(), "tcp", h.config.Address) + if err != nil { + h.mu.Lock() + h.running = false + h.mu.Unlock() + return fmt.Errorf("failed to listen on %s: %w", h.config.Address, err) + } + + go func() { + _ = h.httpServer.Serve(listener) + }() + + return nil +} + +// Shutdown gracefully shuts down the HTTP server. +func (h *HTTPServer) Shutdown(ctx context.Context) error { + h.mu.Lock() + if !h.running { + h.mu.Unlock() + return nil + } + h.running = false + h.mu.Unlock() + + if h.httpServer == nil { + return nil + } + + return h.httpServer.Shutdown(ctx) +} + +// IsRunning returns whether the server is running. +func (h *HTTPServer) IsRunning() bool { + h.mu.RLock() + defer h.mu.RUnlock() + return h.running +} + +// Address returns the configured address. +func (h *HTTPServer) Address() string { + return h.config.Address +} + +// healthHandler returns server health status. +func (h *HTTPServer) healthHandler(w http.ResponseWriter, r *http.Request) { + h.setCORSHeaders(w, r) + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + status := map[string]interface{}{ + "status": "healthy", + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + + h.writeJSON(w, http.StatusOK, status) +} + +// setCORSHeaders sets CORS headers based on configuration. +func (h *HTTPServer) setCORSHeaders(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin == "" { + origin = "*" + } + + // Check if origin is allowed. + allowed := false + for _, o := range h.config.AllowedOrigins { + if o == "*" || o == origin { + allowed = true + break + } + } + + if allowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + } + + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Access-Control-Max-Age", "86400") +} + +// writeJSON writes a JSON response. +func (h *HTTPServer) writeJSON(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(data) +} + +// writeError writes an error response. +func (h *HTTPServer) writeError(w http.ResponseWriter, status int, message string) { + h.writeJSON(w, status, map[string]string{"error": message}) +} + +// SSEServer provides Server-Sent Events transport for streaming. +type SSEServer struct { + httpServer *HTTPServer +} + +// NewSSEServer creates a new SSE server. +func NewSSEServer(httpServer *HTTPServer) *SSEServer { + return &SSEServer{httpServer: httpServer} +} + +// ServeSSE handles SSE connections for streaming responses. +func (s *SSEServer) ServeSSE(w http.ResponseWriter, r *http.Request) { + // Set SSE headers. + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + s.httpServer.setCORSHeaders(w, r) + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "SSE not supported", http.StatusInternalServerError) + return + } + + // Send initial connection event. + fmt.Fprintf(w, "event: connected\ndata: {\"status\": \"connected\"}\n\n") + flusher.Flush() + + // Keep connection open until client disconnects. + <-r.Context().Done() +} + +// StreamingHTTPHandler provides a handler for streaming JSON-RPC over HTTP. +type StreamingHTTPHandler struct { + server *Server +} + +// NewStreamingHTTPHandler creates a new streaming handler. +func NewStreamingHTTPHandler(server *Server) *StreamingHTTPHandler { + return &StreamingHTTPHandler{server: server} +} + +// HandleStream processes a stream of JSON-RPC requests. +func (s *StreamingHTTPHandler) HandleStream(r io.Reader, w io.Writer) error { + scanner := bufio.NewScanner(r) + encoder := json.NewEncoder(w) + + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + + var request JSONRPCRequest + if err := json.Unmarshal([]byte(line), &request); err != nil { + response := MakeErrorResponse(nil, ParseError(err.Error())) + if err := encoder.Encode(response); err != nil { + return err + } + continue + } + + response := s.server.handleRequest(&request) + if err := encoder.Encode(response); err != nil { + return err + } + } + + return scanner.Err() +} diff --git a/sast-engine/mcp/http_test.go b/sast-engine/mcp/http_test.go new file mode 100644 index 00000000..d4187b03 --- /dev/null +++ b/sast-engine/mcp/http_test.go @@ -0,0 +1,467 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultHTTPConfig(t *testing.T) { + config := DefaultHTTPConfig() + + assert.Equal(t, ":8080", config.Address) + assert.Equal(t, 30*time.Second, config.ReadTimeout) + assert.Equal(t, 30*time.Second, config.WriteTimeout) + assert.Equal(t, 5*time.Second, config.ShutdownTimeout) + assert.Equal(t, []string{"*"}, config.AllowedOrigins) +} + +func TestNewHTTPServer(t *testing.T) { + mcpServer := createTestServer() + + t.Run("with config", func(t *testing.T) { + config := &HTTPConfig{Address: ":9090"} + httpServer := NewHTTPServer(mcpServer, config) + + assert.NotNil(t, httpServer) + assert.Equal(t, ":9090", httpServer.Address()) + }) + + t.Run("with nil config uses defaults", func(t *testing.T) { + httpServer := NewHTTPServer(mcpServer, nil) + + assert.NotNil(t, httpServer) + assert.Equal(t, ":8080", httpServer.Address()) + }) +} + +func TestHTTPServer_ServeHTTP_Initialize(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + params := map[string]interface{}{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]interface{}{ + "name": "test-client", + "version": "1.0", + }, + } + paramsBytes, _ := json.Marshal(params) + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + Params: paramsBytes, + } + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response JSONRPCResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, "2.0", response.JSONRPC) + assert.NotNil(t, response.Result) + assert.Nil(t, response.Error) +} + +func TestHTTPServer_ServeHTTP_ToolsCall(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + params := map[string]interface{}{ + "name": "get_index_info", + "arguments": map[string]interface{}{}, + } + paramsBytes, _ := json.Marshal(params) + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/call", + Params: paramsBytes, + } + + body, _ := json.Marshal(request) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Body.String(), "project_path") +} + +func TestHTTPServer_ServeHTTP_MethodNotAllowed(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + methods := []string{http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodPatch} + + for _, method := range methods { + t.Run(method, func(t *testing.T) { + req := httptest.NewRequest(method, "/", nil) + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) + assert.Contains(t, rec.Body.String(), "Only POST method is allowed") + }) + } +} + +func TestHTTPServer_ServeHTTP_WrongContentType(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) + req.Header.Set("Content-Type", "text/plain") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnsupportedMediaType, rec.Code) + assert.Contains(t, rec.Body.String(), "Content-Type must be application/json") +} + +func TestHTTPServer_ServeHTTP_InvalidJSON(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("not json")) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response JSONRPCResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + assert.NotNil(t, response.Error) + assert.Equal(t, ErrCodeParseError, response.Error.Code) +} + +func TestHTTPServer_ServeHTTP_CORS(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + t.Run("preflight request", func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + req.Header.Set("Origin", "http://example.com") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, "http://example.com", rec.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "POST, OPTIONS", rec.Header().Get("Access-Control-Allow-Methods")) + }) + + t.Run("regular request has CORS headers", func(t *testing.T) { + request := JSONRPCRequest{JSONRPC: "2.0", ID: 1, Method: "ping"} + body, _ := json.Marshal(request) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", "http://example.com") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, "http://example.com", rec.Header().Get("Access-Control-Allow-Origin")) + }) +} + +func TestHTTPServer_ServeHTTP_CORSAllowedOrigins(t *testing.T) { + mcpServer := createTestServer() + config := &HTTPConfig{ + Address: ":8080", + AllowedOrigins: []string{"http://allowed.com"}, + } + httpServer := NewHTTPServer(mcpServer, config) + + t.Run("allowed origin", func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + req.Header.Set("Origin", "http://allowed.com") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, "http://allowed.com", rec.Header().Get("Access-Control-Allow-Origin")) + }) + + t.Run("disallowed origin", func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + req.Header.Set("Origin", "http://notallowed.com") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin")) + }) +} + +func TestHTTPServer_HealthHandler(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rec := httptest.NewRecorder() + + httpServer.healthHandler(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, "healthy", response["status"]) + assert.Contains(t, response, "timestamp") +} + +func TestHTTPServer_HealthHandler_Preflight(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + req := httptest.NewRequest(http.MethodOptions, "/health", nil) + rec := httptest.NewRecorder() + + httpServer.healthHandler(rec, req) + + assert.Equal(t, http.StatusNoContent, rec.Code) +} + +func TestHTTPServer_IsRunning(t *testing.T) { + mcpServer := createTestServer() + config := &HTTPConfig{Address: "127.0.0.1:0"} // Use port 0 for random available port + httpServer := NewHTTPServer(mcpServer, config) + + assert.False(t, httpServer.IsRunning()) + + err := httpServer.StartAsync() + require.NoError(t, err) + assert.True(t, httpServer.IsRunning()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = httpServer.Shutdown(ctx) + require.NoError(t, err) + assert.False(t, httpServer.IsRunning()) +} + +func TestHTTPServer_StartAsync_AlreadyRunning(t *testing.T) { + mcpServer := createTestServer() + config := &HTTPConfig{Address: "127.0.0.1:0"} + httpServer := NewHTTPServer(mcpServer, config) + + err := httpServer.StartAsync() + require.NoError(t, err) + defer func() { + ctx := context.Background() + _ = httpServer.Shutdown(ctx) + }() + + err = httpServer.StartAsync() + assert.Error(t, err) + assert.Contains(t, err.Error(), "already running") +} + +func TestHTTPServer_Shutdown_NotRunning(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + ctx := context.Background() + err := httpServer.Shutdown(ctx) + assert.NoError(t, err) +} + +func TestNewSSEServer(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + sseServer := NewSSEServer(httpServer) + + assert.NotNil(t, sseServer) +} + +func TestNewStreamingHTTPHandler(t *testing.T) { + mcpServer := createTestServer() + handler := NewStreamingHTTPHandler(mcpServer) + + assert.NotNil(t, handler) +} + +func TestStreamingHTTPHandler_HandleStream(t *testing.T) { + mcpServer := createTestServer() + handler := NewStreamingHTTPHandler(mcpServer) + + t.Run("single request", func(t *testing.T) { + input := `{"jsonrpc":"2.0","id":1,"method":"ping"}` + "\n" + reader := strings.NewReader(input) + var output bytes.Buffer + + err := handler.HandleStream(reader, &output) + require.NoError(t, err) + + var response JSONRPCResponse + err = json.Unmarshal(output.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, "2.0", response.JSONRPC) + }) + + t.Run("multiple requests", func(t *testing.T) { + input := `{"jsonrpc":"2.0","id":1,"method":"ping"} +{"jsonrpc":"2.0","id":2,"method":"tools/list"} +` + reader := strings.NewReader(input) + var output bytes.Buffer + + err := handler.HandleStream(reader, &output) + require.NoError(t, err) + + // Should have two responses. + lines := strings.Split(strings.TrimSpace(output.String()), "\n") + assert.Len(t, lines, 2) + }) + + t.Run("empty lines skipped", func(t *testing.T) { + input := `{"jsonrpc":"2.0","id":1,"method":"ping"} + +{"jsonrpc":"2.0","id":2,"method":"ping"} +` + reader := strings.NewReader(input) + var output bytes.Buffer + + err := handler.HandleStream(reader, &output) + require.NoError(t, err) + + lines := strings.Split(strings.TrimSpace(output.String()), "\n") + assert.Len(t, lines, 2) + }) + + t.Run("invalid JSON returns error response", func(t *testing.T) { + input := "not valid json\n" + reader := strings.NewReader(input) + var output bytes.Buffer + + err := handler.HandleStream(reader, &output) + require.NoError(t, err) + + var response JSONRPCResponse + err = json.Unmarshal(output.Bytes(), &response) + require.NoError(t, err) + assert.NotNil(t, response.Error) + assert.Equal(t, ErrCodeParseError, response.Error.Code) + }) +} + +func TestHTTPServer_Integration(t *testing.T) { + mcpServer := createTestServer() + config := &HTTPConfig{Address: "127.0.0.1:0"} + httpServer := NewHTTPServer(mcpServer, config) + + err := httpServer.StartAsync() + require.NoError(t, err) + defer func() { + ctx := context.Background() + _ = httpServer.Shutdown(ctx) + }() + + // Give server time to start. + time.Sleep(50 * time.Millisecond) + + // The server is running, we can test via httptest. + // Since we don't have the actual address easily, use ServeHTTP directly. + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/list", + } + body, _ := json.Marshal(request) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Body.String(), "tools") +} + +func TestHTTPServer_LargeRequest(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + // Create a request within the 1MB limit. + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "ping", + } + body, _ := json.Marshal(request) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +type errorReader struct{} + +func (e *errorReader) Read(p []byte) (n int, err error) { + return 0, io.ErrUnexpectedEOF +} + +func TestHTTPServer_ReadBodyError(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + req := httptest.NewRequest(http.MethodPost, "/", &errorReader{}) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Contains(t, rec.Body.String(), "Failed to read request body") +} + +func TestHTTPServer_NoOriginHeader(t *testing.T) { + mcpServer := createTestServer() + httpServer := NewHTTPServer(mcpServer, nil) + + request := JSONRPCRequest{JSONRPC: "2.0", ID: 1, Method: "ping"} + body, _ := json.Marshal(request) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // No Origin header set. + + rec := httptest.NewRecorder() + httpServer.ServeHTTP(rec, req) + + // Should use "*" as fallback. + assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) +}