diff --git a/sast-engine/analytics/usage.go b/sast-engine/analytics/usage.go index 7d8c6955..46605761 100644 --- a/sast-engine/analytics/usage.go +++ b/sast-engine/analytics/usage.go @@ -15,6 +15,15 @@ const ( QueryCommandJSON = "executed_query_command_json_mode" ErrorProcessingQuery = "error_processing_query" QueryCommandStdin = "executed_query_command_stdin_mode" + + // MCP Server events. + MCPServerStarted = "mcp_server_started" + MCPServerStopped = "mcp_server_stopped" + MCPToolCall = "mcp_tool_call" + MCPIndexingStarted = "mcp_indexing_started" + MCPIndexingComplete = "mcp_indexing_complete" + MCPIndexingFailed = "mcp_indexing_failed" + MCPClientConnected = "mcp_client_connected" ) var ( @@ -60,6 +69,12 @@ func LoadEnvFile() { } func ReportEvent(event string) { + ReportEventWithProperties(event, nil) +} + +// ReportEventWithProperties sends an event with additional properties. +// Properties should not contain any PII (no file paths, code, user info). +func ReportEventWithProperties(event string, properties map[string]interface{}) { if enableMetrics && PublicKey != "" { client, err := posthog.NewWithConfig( PublicKey, @@ -71,11 +86,21 @@ func ReportEvent(event string) { fmt.Println(err) return } - err = client.Enqueue(posthog.Capture{ + defer client.Close() + + capture := posthog.Capture{ DistinctId: os.Getenv("uuid"), Event: event, - }) - defer client.Close() + } + + if properties != nil { + capture.Properties = posthog.NewProperties() + for k, v := range properties { + capture.Properties.Set(k, v) + } + } + + err = client.Enqueue(capture) if err != nil { fmt.Println(err) return diff --git a/sast-engine/cmd/serve.go b/sast-engine/cmd/serve.go index 7507efd5..860f8d48 100644 --- a/sast-engine/cmd/serve.go +++ b/sast-engine/cmd/serve.go @@ -36,20 +36,29 @@ Transport modes: func init() { rootCmd.AddCommand(serveCmd) serveCmd.Flags().StringP("project", "p", ".", "Project path to index") - serveCmd.Flags().String("python-version", "3.11", "Python version for stdlib resolution") + serveCmd.Flags().String("python-version", "", "Python version override (auto-detected from .python-version or pyproject.toml)") serveCmd.Flags().Bool("http", false, "Use HTTP transport instead of stdio") serveCmd.Flags().String("address", ":8080", "HTTP server address (only with --http)") } func runServe(cmd *cobra.Command, _ []string) error { projectPath, _ := cmd.Flags().GetString("project") - pythonVersion, _ := cmd.Flags().GetString("python-version") + pythonVersionOverride, _ := cmd.Flags().GetString("python-version") useHTTP, _ := cmd.Flags().GetBool("http") address, _ := cmd.Flags().GetString("address") fmt.Fprintln(os.Stderr, "Building index...") start := time.Now() + // Auto-detect Python version (same logic as BuildCallGraph). + pythonVersion := builder.DetectPythonVersion(projectPath) + if pythonVersionOverride != "" { + pythonVersion = pythonVersionOverride + fmt.Fprintf(os.Stderr, "Using Python version override: %s\n", pythonVersion) + } else { + fmt.Fprintf(os.Stderr, "Detected Python version: %s\n", pythonVersion) + } + // Create logger for build process (verbose to stderr) logger := output.NewLogger(output.VerbosityVerbose) @@ -90,6 +99,9 @@ func runServe(cmd *cobra.Command, _ []string) error { } func runHTTPServer(mcpServer *mcp.Server, address string) error { + // Set transport type for analytics. 
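+	// SetTransport swaps in a fresh Analytics instance so events report transport="http" instead of the stdio default.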
+ mcpServer.SetTransport("http") + config := &mcp.HTTPConfig{ Address: address, ReadTimeout: 30 * time.Second, diff --git a/sast-engine/mcp/analytics.go b/sast-engine/mcp/analytics.go new file mode 100644 index 00000000..7a83e2ea --- /dev/null +++ b/sast-engine/mcp/analytics.go @@ -0,0 +1,114 @@ +package mcp + +import ( + "runtime" + "time" + + "github.com/shivasurya/code-pathfinder/sast-engine/analytics" +) + +// Analytics provides MCP-specific telemetry helpers. +// All events are anonymous and contain no PII. +type Analytics struct { + transport string // "stdio" or "http" + startTime time.Time +} + +// NewAnalytics creates a new analytics instance. +func NewAnalytics(transport string) *Analytics { + return &Analytics{ + transport: transport, + startTime: time.Now(), + } +} + +// ReportServerStarted reports that the MCP server has started. +func (a *Analytics) ReportServerStarted() { + analytics.ReportEventWithProperties(analytics.MCPServerStarted, map[string]interface{}{ + "transport": a.transport, + "os": runtime.GOOS, + "arch": runtime.GOARCH, + }) +} + +// ReportServerStopped reports that the MCP server has stopped. +func (a *Analytics) ReportServerStopped() { + analytics.ReportEventWithProperties(analytics.MCPServerStopped, map[string]interface{}{ + "transport": a.transport, + "uptime_seconds": time.Since(a.startTime).Seconds(), + }) +} + +// ReportToolCall reports a tool invocation with timing and success info. +// No file paths or code content is included. +func (a *Analytics) ReportToolCall(toolName string, durationMs int64, success bool) { + analytics.ReportEventWithProperties(analytics.MCPToolCall, map[string]interface{}{ + "tool": toolName, + "duration_ms": durationMs, + "success": success, + "transport": a.transport, + }) +} + +// ReportIndexingStarted reports that indexing has begun. +func (a *Analytics) ReportIndexingStarted() { + analytics.ReportEventWithProperties(analytics.MCPIndexingStarted, map[string]interface{}{ + "transport": a.transport, + }) +} + +// ReportIndexingComplete reports successful indexing completion. +// Only aggregate counts are reported, no file paths. +func (a *Analytics) ReportIndexingComplete(stats *IndexingStats) { + props := map[string]interface{}{ + "transport": a.transport, + "duration_seconds": stats.BuildDuration.Seconds(), + "function_count": stats.Functions, + "call_edge_count": stats.CallEdges, + "module_count": stats.Modules, + "file_count": stats.Files, + } + analytics.ReportEventWithProperties(analytics.MCPIndexingComplete, props) +} + +// ReportIndexingFailed reports indexing failure. +// Error messages are not included to avoid potential PII. +func (a *Analytics) ReportIndexingFailed(phase string) { + analytics.ReportEventWithProperties(analytics.MCPIndexingFailed, map[string]interface{}{ + "transport": a.transport, + "phase": phase, + }) +} + +// ReportClientConnected reports a client connection with client info. +// Only client name/version (from MCP protocol) is reported. +func (a *Analytics) ReportClientConnected(clientName, clientVersion string) { + analytics.ReportEventWithProperties(analytics.MCPClientConnected, map[string]interface{}{ + "transport": a.transport, + "client_name": clientName, + "client_version": clientVersion, + }) +} + +// ToolCallMetrics holds metrics for a tool call. +type ToolCallMetrics struct { + StartTime time.Time + ToolName string +} + +// StartToolCall begins tracking a tool call. 
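+// Pair it with EndToolCall, which computes the elapsed time and reports the event.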
+func (a *Analytics) StartToolCall(toolName string) *ToolCallMetrics { + return &ToolCallMetrics{ + StartTime: time.Now(), + ToolName: toolName, + } +} + +// EndToolCall completes tracking and reports the metric. +func (a *Analytics) EndToolCall(m *ToolCallMetrics, success bool) { + if m == nil { + return + } + durationMs := time.Since(m.StartTime).Milliseconds() + a.ReportToolCall(m.ToolName, durationMs, success) +} diff --git a/sast-engine/mcp/analytics_test.go b/sast-engine/mcp/analytics_test.go new file mode 100644 index 00000000..963f85c1 --- /dev/null +++ b/sast-engine/mcp/analytics_test.go @@ -0,0 +1,128 @@ +package mcp + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewAnalytics(t *testing.T) { + a := NewAnalytics("stdio") + + assert.NotNil(t, a) + assert.Equal(t, "stdio", a.transport) + assert.False(t, a.startTime.IsZero()) +} + +func TestNewAnalytics_HTTP(t *testing.T) { + a := NewAnalytics("http") + + assert.Equal(t, "http", a.transport) +} + +func TestAnalytics_ReportServerStarted(t *testing.T) { + a := NewAnalytics("stdio") + + // Should not panic even with analytics disabled. + a.ReportServerStarted() +} + +func TestAnalytics_ReportServerStopped(t *testing.T) { + a := NewAnalytics("http") + + // Should not panic. + a.ReportServerStopped() +} + +func TestAnalytics_ReportToolCall(t *testing.T) { + a := NewAnalytics("stdio") + + // Should not panic. + a.ReportToolCall("find_symbol", 150, true) + a.ReportToolCall("get_callers", 50, false) +} + +func TestAnalytics_ReportIndexingStarted(t *testing.T) { + a := NewAnalytics("stdio") + + // Should not panic. + a.ReportIndexingStarted() +} + +func TestAnalytics_ReportIndexingComplete(t *testing.T) { + a := NewAnalytics("http") + + stats := &IndexingStats{ + Functions: 100, + CallEdges: 500, + Modules: 20, + Files: 50, + BuildDuration: 5 * time.Second, + } + + // Should not panic. + a.ReportIndexingComplete(stats) +} + +func TestAnalytics_ReportIndexingFailed(t *testing.T) { + a := NewAnalytics("stdio") + + // Should not panic. + a.ReportIndexingFailed("parsing") + a.ReportIndexingFailed("call_graph") +} + +func TestAnalytics_ReportClientConnected(t *testing.T) { + a := NewAnalytics("http") + + // Should not panic. + a.ReportClientConnected("claude-code", "1.0.0") + a.ReportClientConnected("", "") +} + +func TestAnalytics_StartToolCall(t *testing.T) { + a := NewAnalytics("stdio") + + metrics := a.StartToolCall("find_symbol") + + assert.NotNil(t, metrics) + assert.Equal(t, "find_symbol", metrics.ToolName) + assert.False(t, metrics.StartTime.IsZero()) +} + +func TestAnalytics_EndToolCall(t *testing.T) { + a := NewAnalytics("stdio") + + metrics := a.StartToolCall("get_callers") + time.Sleep(1 * time.Millisecond) // Small delay to ensure duration > 0 + + // Should not panic. + a.EndToolCall(metrics, true) +} + +func TestAnalytics_EndToolCall_Nil(t *testing.T) { + a := NewAnalytics("stdio") + + // Should not panic with nil metrics. + a.EndToolCall(nil, true) +} + +func TestAnalytics_EndToolCall_Failed(t *testing.T) { + a := NewAnalytics("http") + + metrics := a.StartToolCall("resolve_import") + + // Should not panic. 
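+	// success=false exercises the failure-reporting path.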
+ a.EndToolCall(metrics, false) +} + +func TestToolCallMetrics(t *testing.T) { + metrics := &ToolCallMetrics{ + StartTime: time.Now(), + ToolName: "test_tool", + } + + assert.Equal(t, "test_tool", metrics.ToolName) + assert.False(t, metrics.StartTime.IsZero()) +} diff --git a/sast-engine/mcp/errors.go b/sast-engine/mcp/errors.go index d319333c..21d0a113 100644 --- a/sast-engine/mcp/errors.go +++ b/sast-engine/mcp/errors.go @@ -99,9 +99,15 @@ func SymbolNotFoundError(symbol string, suggestions []string) *RPCError { fmt.Sprintf("Symbol not found: %s", symbol), data) } -// IndexNotReadyError creates an index not ready error. -func IndexNotReadyError() *RPCError { - return NewRPCError(ErrCodeIndexNotReady, nil) +// IndexNotReadyError creates an index not ready error with optional progress info. +func IndexNotReadyError(phase string, progress float64) *RPCError { + data := map[string]interface{}{ + "phase": phase, + "progress": progress, + } + return NewRPCErrorWithMessage(ErrCodeIndexNotReady, + fmt.Sprintf("Index not ready: %s (%.0f%% complete)", phase, progress*100), + data) } // QueryTimeoutError creates a query timeout error. diff --git a/sast-engine/mcp/errors_test.go b/sast-engine/mcp/errors_test.go index 18953ba9..d68fcc91 100644 --- a/sast-engine/mcp/errors_test.go +++ b/sast-engine/mcp/errors_test.go @@ -98,10 +98,15 @@ func TestSymbolNotFoundError_NoSuggestions(t *testing.T) { } func TestIndexNotReadyError(t *testing.T) { - err := IndexNotReadyError() + err := IndexNotReadyError("parsing", 0.5) assert.Equal(t, ErrCodeIndexNotReady, err.Code) - assert.Equal(t, "Index not ready", err.Message) + assert.Contains(t, err.Message, "parsing") + assert.Contains(t, err.Message, "50%") + + data := err.Data.(map[string]interface{}) + assert.Equal(t, "parsing", data["phase"]) + assert.Equal(t, 0.5, data["progress"]) } func TestQueryTimeoutError(t *testing.T) { diff --git a/sast-engine/mcp/http_test.go b/sast-engine/mcp/http_test.go index d4187b03..6dd31a97 100644 --- a/sast-engine/mcp/http_test.go +++ b/sast-engine/mcp/http_test.go @@ -465,3 +465,192 @@ func TestHTTPServer_NoOriginHeader(t *testing.T) { // Should use "*" as fallback. assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) } + +func TestHTTPServer_Start_AlreadyRunning(t *testing.T) { + mcpServer := createTestServer() + config := &HTTPConfig{Address: "127.0.0.1:0"} + httpServer := NewHTTPServer(mcpServer, config) + + // Start async first. + err := httpServer.StartAsync() + require.NoError(t, err) + defer func() { + ctx := context.Background() + _ = httpServer.Shutdown(ctx) + }() + + // Try to start synchronously - should fail. + err = httpServer.Start() + assert.Error(t, err) + assert.Contains(t, err.Error(), "already running") +} + +func TestHTTPServer_Start_Blocking(t *testing.T) { + mcpServer := createTestServer() + config := &HTTPConfig{Address: "127.0.0.1:0"} + httpServer := NewHTTPServer(mcpServer, config) + + // Start in goroutine since it blocks. + errCh := make(chan error, 1) + go func() { + errCh <- httpServer.Start() + }() + + // Give server time to start. + time.Sleep(50 * time.Millisecond) + + // Verify it's running. + assert.True(t, httpServer.IsRunning()) + + // Shutdown. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := httpServer.Shutdown(ctx) + require.NoError(t, err) + + // Wait for Start() to return. + select { + case err := <-errCh: + // Server closed error is expected. 
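+		// net/http returns http.ErrServerClosed ("http: Server closed") after Shutdown.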
+		assert.Contains(t, err.Error(), "Server closed")
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for Start to return")
+	}
+}
+
+func TestSSEServer_ServeSSE(t *testing.T) {
+	mcpServer := createTestServer()
+	httpServer := NewHTTPServer(mcpServer, nil)
+	sseServer := NewSSEServer(httpServer)
+
+	// Create a request with a cancellable context.
+	ctx, cancel := context.WithCancel(context.Background())
+	req := httptest.NewRequest(http.MethodGet, "/sse", nil).WithContext(ctx)
+	req.Header.Set("Origin", "http://example.com")
+
+	rec := httptest.NewRecorder()
+
+	// Run ServeSSE in a goroutine since it blocks.
+	done := make(chan struct{})
+	go func() {
+		sseServer.ServeSSE(rec, req)
+		close(done)
+	}()
+
+	// Give it time to set headers and send the connected event.
+	time.Sleep(50 * time.Millisecond)
+
+	// Cancel context to close the connection.
+	cancel()
+
+	// Wait for ServeSSE to return.
+	select {
+	case <-done:
+		// Good, it returned.
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for ServeSSE to return")
+	}
+
+	// Verify SSE headers.
+	assert.Equal(t, "text/event-stream", rec.Header().Get("Content-Type"))
+	assert.Equal(t, "no-cache", rec.Header().Get("Cache-Control"))
+	assert.Equal(t, "keep-alive", rec.Header().Get("Connection"))
+
+	// Verify connected event was sent.
+	assert.Contains(t, rec.Body.String(), "event: connected")
+	assert.Contains(t, rec.Body.String(), "status")
+}
+
+// noFlushResponseWriter doesn't implement http.Flusher.
+type noFlushResponseWriter struct {
+	http.ResponseWriter
+}
+
+func TestSSEServer_ServeSSE_NoFlusher(t *testing.T) {
+	mcpServer := createTestServer()
+	httpServer := NewHTTPServer(mcpServer, nil)
+	sseServer := NewSSEServer(httpServer)
+
+	req := httptest.NewRequest(http.MethodGet, "/sse", nil)
+
+	// Use a writer that doesn't implement Flusher.
+	rec := httptest.NewRecorder()
+	noFlush := &noFlushResponseWriter{rec}
+
+	sseServer.ServeSSE(noFlush, req)
+
+	// Should return error response.
+	assert.Equal(t, http.StatusInternalServerError, rec.Code)
+	assert.Contains(t, rec.Body.String(), "SSE not supported")
+}
+
+func TestHTTPServer_Shutdown_WithContext(t *testing.T) {
+	mcpServer := createTestServer()
+	config := &HTTPConfig{Address: "127.0.0.1:0"}
+	httpServer := NewHTTPServer(mcpServer, config)
+
+	err := httpServer.StartAsync()
+	require.NoError(t, err)
+
+	// Shutdown with a short-timeout context should still complete.
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	err = httpServer.Shutdown(ctx)
+	assert.NoError(t, err)
+	assert.False(t, httpServer.IsRunning())
+}
+
+func TestHTTPServer_Start_InvalidAddress(t *testing.T) {
+	mcpServer := createTestServer()
+	// Use an invalid address that should fail to listen.
+	config := &HTTPConfig{Address: "invalid:address:format:99999"}
+	httpServer := NewHTTPServer(mcpServer, config)
+
+	err := httpServer.Start()
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to listen")
+	assert.False(t, httpServer.IsRunning())
+}
+
+func TestHTTPServer_StartAsync_InvalidAddress(t *testing.T) {
+	mcpServer := createTestServer()
+	// Use an invalid address that should fail to listen.
+	config := &HTTPConfig{Address: "invalid:address:format:99999"}
+	httpServer := NewHTTPServer(mcpServer, config)
+
+	err := httpServer.StartAsync()
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to listen")
+	assert.False(t, httpServer.IsRunning())
+}
+
+// errorWriter always returns an error.
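+// Returning io.ErrClosedPipe simulates a client that disconnected mid-stream.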
+type errorWriter struct{} + +func (e *errorWriter) Write(p []byte) (n int, err error) { + return 0, io.ErrClosedPipe +} + +func TestStreamingHTTPHandler_HandleStream_WriteError(t *testing.T) { + mcpServer := createTestServer() + handler := NewStreamingHTTPHandler(mcpServer) + + input := `{"jsonrpc":"2.0","id":1,"method":"ping"}` + "\n" + reader := strings.NewReader(input) + + err := handler.HandleStream(reader, &errorWriter{}) + assert.Error(t, err) +} + +func TestStreamingHTTPHandler_HandleStream_WriteErrorOnParseError(t *testing.T) { + mcpServer := createTestServer() + handler := NewStreamingHTTPHandler(mcpServer) + + // Invalid JSON to trigger parse error, which then tries to write. + input := "not valid json\n" + reader := strings.NewReader(input) + + err := handler.HandleStream(reader, &errorWriter{}) + assert.Error(t, err) +} diff --git a/sast-engine/mcp/server.go b/sast-engine/mcp/server.go index f0316758..03e6dfef 100644 --- a/sast-engine/mcp/server.go +++ b/sast-engine/mcp/server.go @@ -21,6 +21,9 @@ type Server struct { codeGraph *graph.CodeGraph indexedAt time.Time buildTime time.Duration + statusTracker *StatusTracker + degradation *GracefulDegradation + analytics *Analytics } // NewServer creates a new MCP server with the given index data. @@ -32,6 +35,23 @@ func NewServer( codeGraph *graph.CodeGraph, buildTime time.Duration, ) *Server { + tracker := NewStatusTracker() + + // Mark as ready since we're being created with complete data. + tracker.StartIndexing() + stats := &IndexingStats{ + Functions: len(callGraph.Functions), + CallEdges: len(callGraph.Edges), + Modules: len(moduleRegistry.Modules), + Files: len(moduleRegistry.FileToModule), + BuildDuration: buildTime, + } + tracker.CompleteIndexing(stats) + + // Initialize analytics with stdio transport (default). + mcpAnalytics := NewAnalytics("stdio") + mcpAnalytics.ReportIndexingComplete(stats) + return &Server{ projectPath: projectPath, pythonVersion: pythonVersion, @@ -40,13 +60,25 @@ func NewServer( codeGraph: codeGraph, indexedAt: time.Now(), buildTime: buildTime, + statusTracker: tracker, + degradation: NewGracefulDegradation(tracker), + analytics: mcpAnalytics, } } +// SetTransport updates the analytics transport type (e.g., "http"). +func (s *Server) SetTransport(transport string) { + s.analytics = NewAnalytics(transport) +} + // ServeStdio starts the MCP server on stdin/stdout. func (s *Server) ServeStdio() error { reader := bufio.NewReader(os.Stdin) + // Report server started. + s.analytics.ReportServerStarted() + defer s.analytics.ReportServerStopped() + for { // Read line from stdin. line, err := reader.ReadString('\n') @@ -118,6 +150,8 @@ func (s *Server) handleRequest(req *JSONRPCRequest) *JSONRPCResponse { response = s.handleToolsList(req) case "tools/call": response = s.handleToolsCall(req) + case "status": + response = s.handleStatus(req) case "ping": response = SuccessResponse(req.ID, map[string]string{"status": "ok"}) default: @@ -138,6 +172,9 @@ func (s *Server) handleInitialize(req *JSONRPCRequest) *JSONRPCResponse { if req.Params != nil { _ = json.Unmarshal(req.Params, ¶ms) fmt.Fprintf(os.Stderr, "Client: %s %s\n", params.ClientInfo.Name, params.ClientInfo.Version) + + // Report client connection (only name/version, no PII). 
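+		// ClientInfo fields come from the MCP initialize request and may be empty.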
+ s.analytics.ReportClientConnected(params.ClientInfo.Name, params.ClientInfo.Version) } return SuccessResponse(req.ID, InitializeResult{ @@ -175,7 +212,10 @@ func (s *Server) handleToolsCall(req *JSONRPCRequest) *JSONRPCResponse { fmt.Fprintf(os.Stderr, "Tool call: %s\n", params.Name) + // Track tool call metrics. + metrics := s.analytics.StartToolCall(params.Name) result, isError := s.executeTool(params.Name, params.Arguments) + s.analytics.EndToolCall(metrics, !isError) return SuccessResponse(req.ID, ToolResult{ Content: []ContentBlock{ @@ -188,3 +228,18 @@ func (s *Server) handleToolsCall(req *JSONRPCRequest) *JSONRPCResponse { }) } +// handleStatus returns the current indexing status. +func (s *Server) handleStatus(req *JSONRPCRequest) *JSONRPCResponse { + return SuccessResponse(req.ID, s.degradation.GetStatusJSON()) +} + +// GetStatusTracker returns the status tracker for external use. +func (s *Server) GetStatusTracker() *StatusTracker { + return s.statusTracker +} + +// IsReady returns true if the server is ready to handle tool requests. +func (s *Server) IsReady() bool { + return s.statusTracker.IsReady() +} + diff --git a/sast-engine/mcp/server_test.go b/sast-engine/mcp/server_test.go index 1941f4ae..69a04820 100644 --- a/sast-engine/mcp/server_test.go +++ b/sast-engine/mcp/server_test.go @@ -588,3 +588,83 @@ func TestHandleRequest_MethodNotFound_WithErrorData(t *testing.T) { data := resp.Error.Data.(map[string]string) assert.Equal(t, "unknown/method", data["method"]) } + +// ============================================================================ +// Status and Analytics Tests (PR-05) +// ============================================================================ + +func TestSetTransport(t *testing.T) { + server := createTestServer() + + // Default is stdio. + assert.NotNil(t, server.analytics) + + // Set to HTTP transport. + server.SetTransport("http") + + // Analytics instance should be updated. + assert.NotNil(t, server.analytics) +} + +func TestHandleStatus(t *testing.T) { + server := createTestServer() + + req := &JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "status", + } + + resp := server.handleStatus(req) + + assert.NotNil(t, resp) + assert.Nil(t, resp.Error) + assert.NotNil(t, resp.Result) +} + +func TestHandleRequest_Status(t *testing.T) { + server := createTestServer() + + req := &JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "status", + } + + resp := server.handleRequest(req) + + assert.NotNil(t, resp) + assert.Nil(t, resp.Error) +} + +func TestGetStatusTracker(t *testing.T) { + server := createTestServer() + + tracker := server.GetStatusTracker() + + assert.NotNil(t, tracker) + assert.Equal(t, server.statusTracker, tracker) +} + +func TestIsReady(t *testing.T) { + server := createTestServer() + + // Server should be ready after NewServer is called with valid data. + assert.True(t, server.IsReady()) +} + +func TestIsReady_NotReady(t *testing.T) { + callGraph := core.NewCallGraph() + moduleRegistry := &core.ModuleRegistry{ + Modules: map[string]string{}, + FileToModule: map[string]string{}, + ShortNames: map[string][]string{}, + } + + server := NewServer("/test", "3.11", callGraph, moduleRegistry, nil, time.Second) + + // Manually set status tracker to indexing to simulate not-ready state. 
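+	// NewServer marks the tracker ready, so the test deliberately moves it back to indexing.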
+ server.statusTracker.StartIndexing() + + assert.False(t, server.IsReady()) +} diff --git a/sast-engine/mcp/status.go b/sast-engine/mcp/status.go new file mode 100644 index 00000000..e7a15acf --- /dev/null +++ b/sast-engine/mcp/status.go @@ -0,0 +1,398 @@ +package mcp + +import ( + "sync" + "time" +) + +// IndexingState represents the current state of the indexing process. +type IndexingState int + +const ( + // StateUninitialized means indexing hasn't started yet. + StateUninitialized IndexingState = iota + // StateIndexing means indexing is in progress. + StateIndexing + // StateReady means indexing is complete and server is ready. + StateReady + // StateFailed means indexing failed. + StateFailed +) + +// String returns the string representation of the state. +func (s IndexingState) String() string { + switch s { + case StateUninitialized: + return "uninitialized" + case StateIndexing: + return "indexing" + case StateReady: + return "ready" + case StateFailed: + return "failed" + default: + return "unknown" + } +} + +// IndexingPhase represents the current phase of indexing. +type IndexingPhase int + +const ( + PhaseNone IndexingPhase = iota + PhaseParsing + PhaseModuleRegistry + PhaseCallGraph + PhaseComplete +) + +// String returns the string representation of the phase. +func (p IndexingPhase) String() string { + switch p { + case PhaseNone: + return "none" + case PhaseParsing: + return "parsing" + case PhaseModuleRegistry: + return "module_registry" + case PhaseCallGraph: + return "call_graph" + case PhaseComplete: + return "complete" + default: + return "unknown" + } +} + +// IndexingProgress holds detailed progress information. +type IndexingProgress struct { + Phase IndexingPhase `json:"phase"` + PhaseProgress float64 `json:"phaseProgress"` // 0.0 to 1.0 + OverallProgress float64 `json:"overallProgress"` // 0.0 to 1.0 + FilesProcessed int `json:"filesProcessed"` + TotalFiles int `json:"totalFiles"` + CurrentFile string `json:"currentFile,omitempty"` + Message string `json:"message,omitempty"` +} + +// IndexingStatus holds the complete indexing status. +type IndexingStatus struct { + State IndexingState `json:"state"` + Progress IndexingProgress `json:"progress"` + StartedAt *time.Time `json:"startedAt,omitempty"` + CompletedAt *time.Time `json:"completedAt,omitempty"` + Error string `json:"error,omitempty"` + Stats *IndexingStats `json:"stats,omitempty"` +} + +// IndexingStats holds statistics after indexing completes. +type IndexingStats struct { + Functions int `json:"functions"` + CallEdges int `json:"callEdges"` + Modules int `json:"modules"` + Files int `json:"files"` + BuildDuration time.Duration `json:"buildDuration"` +} + +// StatusTracker tracks and reports indexing status. +type StatusTracker struct { + mu sync.RWMutex + state IndexingState + progress IndexingProgress + startedAt *time.Time + completedAt *time.Time + errorMsg string + stats *IndexingStats + subscribers []chan IndexingStatus +} + +// NewStatusTracker creates a new status tracker. +func NewStatusTracker() *StatusTracker { + return &StatusTracker{ + state: StateUninitialized, + progress: IndexingProgress{ + Phase: PhaseNone, + }, + } +} + +// GetStatus returns the current indexing status. +func (t *StatusTracker) GetStatus() IndexingStatus { + t.mu.RLock() + defer t.mu.RUnlock() + + return IndexingStatus{ + State: t.state, + Progress: t.progress, + StartedAt: t.startedAt, + CompletedAt: t.completedAt, + Error: t.errorMsg, + Stats: t.stats, + } +} + +// GetState returns the current indexing state. 
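+// It is cheaper than GetStatus when only the state enum is needed.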
+func (t *StatusTracker) GetState() IndexingState { + t.mu.RLock() + defer t.mu.RUnlock() + return t.state +} + +// IsReady returns true if the server is ready to handle requests. +func (t *StatusTracker) IsReady() bool { + t.mu.RLock() + defer t.mu.RUnlock() + return t.state == StateReady +} + +// StartIndexing marks the start of indexing. +func (t *StatusTracker) StartIndexing() { + t.mu.Lock() + defer t.mu.Unlock() + + now := time.Now() + t.state = StateIndexing + t.startedAt = &now + t.completedAt = nil + t.errorMsg = "" + t.progress = IndexingProgress{ + Phase: PhaseParsing, + PhaseProgress: 0, + OverallProgress: 0, + } + + t.notifySubscribers() +} + +// SetPhase updates the current indexing phase. +func (t *StatusTracker) SetPhase(phase IndexingPhase, message string) { + t.mu.Lock() + defer t.mu.Unlock() + + t.progress.Phase = phase + t.progress.PhaseProgress = 0 + t.progress.Message = message + + // Calculate overall progress based on phase. + switch phase { + case PhaseParsing: + t.progress.OverallProgress = 0.0 + case PhaseModuleRegistry: + t.progress.OverallProgress = 0.33 + case PhaseCallGraph: + t.progress.OverallProgress = 0.66 + case PhaseComplete: + t.progress.OverallProgress = 1.0 + } + + t.notifySubscribers() +} + +// UpdateProgress updates progress within the current phase. +func (t *StatusTracker) UpdateProgress(processed, total int, currentFile string) { + t.mu.Lock() + defer t.mu.Unlock() + + t.progress.FilesProcessed = processed + t.progress.TotalFiles = total + t.progress.CurrentFile = currentFile + + if total > 0 { + t.progress.PhaseProgress = float64(processed) / float64(total) + } + + // Update overall progress based on phase + phase progress. + baseProgress := 0.0 + phaseWeight := 0.33 + + switch t.progress.Phase { + case PhaseParsing: + baseProgress = 0.0 + case PhaseModuleRegistry: + baseProgress = 0.33 + case PhaseCallGraph: + baseProgress = 0.66 + } + + t.progress.OverallProgress = baseProgress + (t.progress.PhaseProgress * phaseWeight) + + t.notifySubscribers() +} + +// CompleteIndexing marks indexing as complete. +func (t *StatusTracker) CompleteIndexing(stats *IndexingStats) { + t.mu.Lock() + defer t.mu.Unlock() + + now := time.Now() + t.state = StateReady + t.completedAt = &now + t.stats = stats + t.progress = IndexingProgress{ + Phase: PhaseComplete, + PhaseProgress: 1.0, + OverallProgress: 1.0, + Message: "Indexing complete", + } + + t.notifySubscribers() +} + +// FailIndexing marks indexing as failed. +func (t *StatusTracker) FailIndexing(err error) { + t.mu.Lock() + defer t.mu.Unlock() + + now := time.Now() + t.state = StateFailed + t.completedAt = &now + if err != nil { + t.errorMsg = err.Error() + } + t.progress.Message = "Indexing failed" + + t.notifySubscribers() +} + +// Subscribe returns a channel that receives status updates. +func (t *StatusTracker) Subscribe() chan IndexingStatus { + t.mu.Lock() + defer t.mu.Unlock() + + ch := make(chan IndexingStatus, 10) + t.subscribers = append(t.subscribers, ch) + + // Send current status immediately. + select { + case ch <- t.buildStatus(): + default: + } + + return ch +} + +// Unsubscribe removes a subscription channel. +func (t *StatusTracker) Unsubscribe(ch chan IndexingStatus) { + t.mu.Lock() + defer t.mu.Unlock() + + for i, sub := range t.subscribers { + if sub == ch { + t.subscribers = append(t.subscribers[:i], t.subscribers[i+1:]...) + close(ch) + break + } + } +} + +// notifySubscribers sends status to all subscribers (must be called with lock held). 
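+// Sends are non-blocking: a subscriber whose channel buffer is full misses that update rather than stalling the tracker.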
+func (t *StatusTracker) notifySubscribers() {
+	status := t.buildStatus()
+	for _, ch := range t.subscribers {
+		select {
+		case ch <- status:
+		default:
+			// Channel full, skip.
+		}
+	}
+}
+
+// buildStatus creates status struct (must be called with lock held).
+func (t *StatusTracker) buildStatus() IndexingStatus {
+	return IndexingStatus{
+		State:       t.state,
+		Progress:    t.progress,
+		StartedAt:   t.startedAt,
+		CompletedAt: t.completedAt,
+		Error:       t.errorMsg,
+		Stats:       t.stats,
+	}
+}
+
+// GracefulDegradation provides methods for handling requests during indexing.
+type GracefulDegradation struct {
+	tracker *StatusTracker
+}
+
+// NewGracefulDegradation creates a new graceful degradation handler.
+func NewGracefulDegradation(tracker *StatusTracker) *GracefulDegradation {
+	return &GracefulDegradation{tracker: tracker}
+}
+
+// CheckReady checks if the server is ready and returns an appropriate error if not.
+func (g *GracefulDegradation) CheckReady() *RPCError {
+	status := g.tracker.GetStatus()
+
+	switch status.State {
+	case StateReady:
+		return nil
+	case StateIndexing:
+		return IndexNotReadyError(
+			status.Progress.Phase.String(),
+			status.Progress.OverallProgress,
+		)
+	case StateFailed:
+		return InternalError("Indexing failed: " + status.Error)
+	default:
+		return IndexNotReadyError("uninitialized", 0)
+	}
+}
+
+// WrapToolCall wraps a tool call with readiness checking.
+func (g *GracefulDegradation) WrapToolCall(toolName string, fn func() (string, bool)) (string, bool) {
+	if err := g.CheckReady(); err != nil {
+		return NewToolError(err.Message, err.Code, map[string]interface{}{
+			"tool":   toolName,
+			"status": g.tracker.GetStatus(),
+		}), true
+	}
+	return fn()
+}
+
+// GetStatusJSON returns the current status as a JSON-friendly map.
+func (g *GracefulDegradation) GetStatusJSON() map[string]interface{} {
+	status := g.tracker.GetStatus()
+
+	result := map[string]interface{}{
+		"state": status.State.String(),
+		"progress": map[string]interface{}{
+			"phase":           status.Progress.Phase.String(),
+			"phaseProgress":   status.Progress.PhaseProgress,
+			"overallProgress": status.Progress.OverallProgress,
+			"filesProcessed":  status.Progress.FilesProcessed,
+			"totalFiles":      status.Progress.TotalFiles,
+		},
+	}
+
+	if status.Progress.CurrentFile != "" {
+		result["progress"].(map[string]interface{})["currentFile"] = status.Progress.CurrentFile
+	}
+
+	if status.Progress.Message != "" {
+		result["progress"].(map[string]interface{})["message"] = status.Progress.Message
+	}
+
+	if status.StartedAt != nil {
+		result["startedAt"] = status.StartedAt.Format(time.RFC3339)
+	}
+
+	if status.CompletedAt != nil {
+		result["completedAt"] = status.CompletedAt.Format(time.RFC3339)
+	}
+
+	if status.Error != "" {
+		result["error"] = status.Error
+	}
+
+	if status.Stats != nil {
+		result["stats"] = map[string]interface{}{
+			"functions":     status.Stats.Functions,
+			"callEdges":     status.Stats.CallEdges,
+			"modules":       status.Stats.Modules,
+			"files":         status.Stats.Files,
+			"buildDuration": status.Stats.BuildDuration.String(),
+		}
+	}
+
+	return result
+}
diff --git a/sast-engine/mcp/status_test.go b/sast-engine/mcp/status_test.go
new file mode 100644
index 00000000..f335057e
--- /dev/null
+++ b/sast-engine/mcp/status_test.go
@@ -0,0 +1,452 @@
+package mcp
+
+import (
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestIndexingState_String(t *testing.T) {
+	tests := []struct {
+		state    IndexingState
+		expected string
+	}{
+		{StateUninitialized, "uninitialized"},
+		{StateIndexing, "indexing"},
+		{StateReady, "ready"},
+		{StateFailed, "failed"},
+		{IndexingState(99), "unknown"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.expected, func(t *testing.T) {
+			assert.Equal(t, tt.expected, tt.state.String())
+		})
+	}
+}
+
+func TestIndexingPhase_String(t *testing.T) {
+	tests := []struct {
+		phase    IndexingPhase
+		expected string
+	}{
+		{PhaseNone, "none"},
+		{PhaseParsing, "parsing"},
+		{PhaseModuleRegistry, "module_registry"},
+		{PhaseCallGraph, "call_graph"},
+		{PhaseComplete, "complete"},
+		{IndexingPhase(99), "unknown"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.expected, func(t *testing.T) {
+			assert.Equal(t, tt.expected, tt.phase.String())
+		})
+	}
+}
+
+func TestNewStatusTracker(t *testing.T) {
+	tracker := NewStatusTracker()
+
+	assert.NotNil(t, tracker)
+	assert.Equal(t, StateUninitialized, tracker.GetState())
+	assert.False(t, tracker.IsReady())
+}
+
+func TestStatusTracker_StartIndexing(t *testing.T) {
+	tracker := NewStatusTracker()
+
+	tracker.StartIndexing()
+
+	status := tracker.GetStatus()
+	assert.Equal(t, StateIndexing, status.State)
+	assert.NotNil(t, status.StartedAt)
+	assert.Nil(t, status.CompletedAt)
+	assert.Equal(t, PhaseParsing, status.Progress.Phase)
+	assert.Equal(t, 0.0, status.Progress.OverallProgress)
+}
+
+func TestStatusTracker_SetPhase(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+
+	tests := []struct {
+		phase           IndexingPhase
+		expectedOverall float64
+	}{
+		{PhaseParsing, 0.0},
+		{PhaseModuleRegistry, 0.33},
+		{PhaseCallGraph, 0.66},
+		{PhaseComplete, 1.0},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.phase.String(), func(t *testing.T) {
+			tracker.SetPhase(tt.phase, "test message")
+
+			status := tracker.GetStatus()
+			assert.Equal(t, tt.phase, status.Progress.Phase)
+			assert.Equal(t, tt.expectedOverall, status.Progress.OverallProgress)
+			assert.Equal(t, "test message", status.Progress.Message)
+			assert.Equal(t, 0.0, status.Progress.PhaseProgress)
+		})
+	}
+}
+
+func TestStatusTracker_UpdateProgress(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	tracker.SetPhase(PhaseParsing, "Parsing files")
+
+	tracker.UpdateProgress(50, 100, "file.py")
+
+	status := tracker.GetStatus()
+	assert.Equal(t, 50, status.Progress.FilesProcessed)
+	assert.Equal(t, 100, status.Progress.TotalFiles)
+	assert.Equal(t, "file.py", status.Progress.CurrentFile)
+	assert.Equal(t, 0.5, status.Progress.PhaseProgress)
+	// Overall progress: 0.0 (base) + 0.5 * 0.33 (phase weight) = 0.165
+	assert.InDelta(t, 0.165, status.Progress.OverallProgress, 0.01)
+}
+
+func TestStatusTracker_UpdateProgress_ZeroTotal(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+
+	tracker.UpdateProgress(0, 0, "")
+
+	status := tracker.GetStatus()
+	assert.Equal(t, 0.0, status.Progress.PhaseProgress)
+}
+
+func TestStatusTracker_CompleteIndexing(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+
+	stats := &IndexingStats{
+		Functions:     100,
+		CallEdges:     500,
+		Modules:       20,
+		Files:         50,
+		BuildDuration: 5 * time.Second,
+	}
+
+	tracker.CompleteIndexing(stats)
+
+	status := tracker.GetStatus()
+	assert.Equal(t, StateReady, status.State)
+	assert.True(t, tracker.IsReady())
+	assert.NotNil(t, status.CompletedAt)
+	assert.Equal(t, PhaseComplete, status.Progress.Phase)
+	assert.Equal(t, 1.0, status.Progress.OverallProgress)
+	assert.Equal(t, stats, status.Stats)
+}
+
+func TestStatusTracker_FailIndexing(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+
+	err := errors.New("something went wrong")
+	tracker.FailIndexing(err)
+
+	status := tracker.GetStatus()
+	assert.Equal(t, StateFailed, status.State)
+	assert.False(t, tracker.IsReady())
+	assert.NotNil(t, status.CompletedAt)
+	assert.Equal(t, "something went wrong", status.Error)
+	assert.Equal(t, "Indexing failed", status.Progress.Message)
+}
+
+func TestStatusTracker_FailIndexing_NilError(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+
+	tracker.FailIndexing(nil)
+
+	status := tracker.GetStatus()
+	assert.Equal(t, StateFailed, status.State)
+	assert.Empty(t, status.Error)
+}
+
+func TestStatusTracker_Subscribe(t *testing.T) {
+	tracker := NewStatusTracker()
+
+	ch := tracker.Subscribe()
+	require.NotNil(t, ch)
+
+	// Should receive initial status.
+	select {
+	case status := <-ch:
+		assert.Equal(t, StateUninitialized, status.State)
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("expected initial status")
+	}
+
+	// Start indexing and receive update.
+	tracker.StartIndexing()
+
+	select {
+	case status := <-ch:
+		assert.Equal(t, StateIndexing, status.State)
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("expected status update")
+	}
+
+	tracker.Unsubscribe(ch)
+}
+
+func TestStatusTracker_Unsubscribe(t *testing.T) {
+	tracker := NewStatusTracker()
+
+	ch := tracker.Subscribe()
+	// Drain initial status.
+	<-ch
+
+	tracker.Unsubscribe(ch)
+
+	// Channel should be closed.
+	_, ok := <-ch
+	assert.False(t, ok)
+}
+
+func TestStatusTracker_MultipleSubscribers(t *testing.T) {
+	tracker := NewStatusTracker()
+
+	ch1 := tracker.Subscribe()
+	ch2 := tracker.Subscribe()
+
+	// Drain initial statuses.
+	<-ch1
+	<-ch2
+
+	tracker.StartIndexing()
+
+	// Both should receive update.
+	select {
+	case s := <-ch1:
+		assert.Equal(t, StateIndexing, s.State)
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("ch1 should receive update")
+	}
+
+	select {
+	case s := <-ch2:
+		assert.Equal(t, StateIndexing, s.State)
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("ch2 should receive update")
+	}
+
+	tracker.Unsubscribe(ch1)
+	tracker.Unsubscribe(ch2)
+}
+
+func TestNewGracefulDegradation(t *testing.T) {
+	tracker := NewStatusTracker()
+	gd := NewGracefulDegradation(tracker)
+
+	assert.NotNil(t, gd)
+}
+
+func TestGracefulDegradation_CheckReady_Ready(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	tracker.CompleteIndexing(&IndexingStats{})
+	gd := NewGracefulDegradation(tracker)
+
+	err := gd.CheckReady()
+	assert.Nil(t, err)
+}
+
+func TestGracefulDegradation_CheckReady_Indexing(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	tracker.SetPhase(PhaseCallGraph, "Building call graph")
+	tracker.UpdateProgress(50, 100, "")
+	gd := NewGracefulDegradation(tracker)
+
+	err := gd.CheckReady()
+	require.NotNil(t, err)
+	assert.Equal(t, ErrCodeIndexNotReady, err.Code)
+	assert.Contains(t, err.Message, "call_graph")
+}
+
+func TestGracefulDegradation_CheckReady_Failed(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	tracker.FailIndexing(errors.New("oops"))
+	gd := NewGracefulDegradation(tracker)
+
+	err := gd.CheckReady()
+	require.NotNil(t, err)
+	assert.Equal(t, ErrCodeInternalError, err.Code)
+	assert.Contains(t, err.Message, "oops")
+}
+
+func TestGracefulDegradation_CheckReady_Uninitialized(t *testing.T) {
+	tracker := NewStatusTracker()
+	gd := NewGracefulDegradation(tracker)
+
+	err := gd.CheckReady()
+	require.NotNil(t, err)
+	assert.Equal(t, ErrCodeIndexNotReady, err.Code)
+}
+
+func TestGracefulDegradation_WrapToolCall_Ready(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	tracker.CompleteIndexing(&IndexingStats{})
+	gd := NewGracefulDegradation(tracker)
+
+	called := false
+	result, isError := gd.WrapToolCall("test_tool", func() (string, bool) {
+		called = true
+		return `{"result": "ok"}`, false
+	})
+
+	assert.True(t, called)
+	assert.False(t, isError)
+	assert.Contains(t, result, "ok")
+}
+
+func TestGracefulDegradation_WrapToolCall_NotReady(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	gd := NewGracefulDegradation(tracker)
+
+	called := false
+	result, isError := gd.WrapToolCall("test_tool", func() (string, bool) {
+		called = true
+		return `{"result": "ok"}`, false
+	})
+
+	assert.False(t, called)
+	assert.True(t, isError)
+	assert.Contains(t, result, "test_tool")
+	assert.Contains(t, result, "status")
+}
+
+func TestGracefulDegradation_GetStatusJSON(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	tracker.SetPhase(PhaseParsing, "Parsing source files")
+	tracker.UpdateProgress(10, 50, "main.py")
+	gd := NewGracefulDegradation(tracker)
+
+	status := gd.GetStatusJSON()
+
+	assert.Equal(t, "indexing", status["state"])
+	assert.NotNil(t, status["progress"])
+	assert.NotNil(t, status["startedAt"])
+
+	progress := status["progress"].(map[string]interface{})
+	assert.Equal(t, "parsing", progress["phase"])
+	assert.Equal(t, 10, progress["filesProcessed"])
+	assert.Equal(t, 50, progress["totalFiles"])
+	assert.Equal(t, "main.py", progress["currentFile"])
+	assert.Equal(t, "Parsing source files", progress["message"])
+}
+
+func TestGracefulDegradation_GetStatusJSON_Complete(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	tracker.CompleteIndexing(&IndexingStats{
+		Functions:     100,
+		CallEdges:     500,
+		Modules:       20,
+		Files:         50,
+		BuildDuration: 5 * time.Second,
+	})
+	gd := NewGracefulDegradation(tracker)
+
+	status := gd.GetStatusJSON()
+
+	assert.Equal(t, "ready", status["state"])
+	assert.NotNil(t, status["completedAt"])
+	assert.NotNil(t, status["stats"])
+
+	stats := status["stats"].(map[string]interface{})
+	assert.Equal(t, 100, stats["functions"])
+	assert.Equal(t, 500, stats["callEdges"])
+}
+
+func TestGracefulDegradation_GetStatusJSON_Failed(t *testing.T) {
+	tracker := NewStatusTracker()
+	tracker.StartIndexing()
+	tracker.FailIndexing(errors.New("parse error"))
+	gd := NewGracefulDegradation(tracker)
+
+	status := gd.GetStatusJSON()
+
+	assert.Equal(t, "failed", status["state"])
+	assert.Equal(t, "parse error", status["error"])
+}
+
+func TestStatusTracker_FullWorkflow(t *testing.T) {
+	tracker := NewStatusTracker()
+
+	// Initial state.
+	assert.Equal(t, StateUninitialized, tracker.GetState())
+
+	// Start indexing.
+	tracker.StartIndexing()
+	assert.Equal(t, StateIndexing, tracker.GetState())
+
+	// Phase 1: Parsing.
+	tracker.SetPhase(PhaseParsing, "Parsing Python files")
+	for i := 0; i <= 100; i += 10 {
+		tracker.UpdateProgress(i, 100, "file.py")
+	}
+
+	// Phase 2: Module registry.
+	tracker.SetPhase(PhaseModuleRegistry, "Building module registry")
+	for i := 0; i <= 100; i += 20 {
+		tracker.UpdateProgress(i, 100, "")
+	}
+
+	// Phase 3: Call graph.
+	tracker.SetPhase(PhaseCallGraph, "Building call graph")
+	for i := 0; i <= 100; i += 25 {
+		tracker.UpdateProgress(i, 100, "")
+	}
+
+	// Complete.
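+	// Only Functions and CallEdges are set; the remaining stats fields default to zero.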
+ tracker.CompleteIndexing(&IndexingStats{ + Functions: 50, + CallEdges: 100, + }) + + assert.Equal(t, StateReady, tracker.GetState()) + assert.True(t, tracker.IsReady()) +} + +func TestStatusTracker_ConcurrentAccess(t *testing.T) { + tracker := NewStatusTracker() + + done := make(chan bool) + + // Reader goroutine. + go func() { + for i := 0; i < 100; i++ { + _ = tracker.GetStatus() + _ = tracker.GetState() + _ = tracker.IsReady() + } + done <- true + }() + + // Writer goroutine. + go func() { + tracker.StartIndexing() + for i := 0; i < 50; i++ { + tracker.UpdateProgress(i, 50, "file.py") + } + tracker.CompleteIndexing(&IndexingStats{}) + done <- true + }() + + <-done + <-done +}