diff --git a/sast-engine/cmd/serve.go b/sast-engine/cmd/serve.go new file mode 100644 index 00000000..146749c9 --- /dev/null +++ b/sast-engine/cmd/serve.go @@ -0,0 +1,74 @@ +package cmd + +import ( + "fmt" + "os" + "time" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/builder" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/registry" + "github.com/shivasurya/code-pathfinder/sast-engine/mcp" + "github.com/shivasurya/code-pathfinder/sast-engine/output" + "github.com/spf13/cobra" +) + +var serveCmd = &cobra.Command{ + Use: "serve", + Short: "Start MCP server for AI coding assistants", + Long: `Builds code index and starts MCP server on stdio. + +Designed for integration with Claude Code, Codex CLI, and other AI assistants +that support the Model Context Protocol (MCP). + +The server indexes the codebase once at startup, then responds to queries +about symbols, call graphs, and code relationships.`, + RunE: runServe, +} + +func init() { + rootCmd.AddCommand(serveCmd) + serveCmd.Flags().StringP("project", "p", ".", "Project path to index") + serveCmd.Flags().String("python-version", "3.11", "Python version for stdlib resolution") +} + +func runServe(cmd *cobra.Command, _ []string) error { + projectPath, _ := cmd.Flags().GetString("project") + pythonVersion, _ := cmd.Flags().GetString("python-version") + + fmt.Fprintln(os.Stderr, "Building index...") + start := time.Now() + + // Create logger for build process (verbose to stderr) + logger := output.NewLogger(output.VerbosityVerbose) + + // 1. Initialize code graph (AST parsing) + codeGraph := graph.Initialize(projectPath) + if codeGraph == nil { + return fmt.Errorf("failed to initialize code graph") + } + + // 2. Build module registry + moduleRegistry, err := registry.BuildModuleRegistry(projectPath, true) // skip tests + if err != nil { + return fmt.Errorf("failed to build module registry: %w", err) + } + + // 3. Build call graph (5-pass algorithm) + callGraph, err := builder.BuildCallGraph(codeGraph, moduleRegistry, projectPath, logger) + if err != nil { + return fmt.Errorf("failed to build call graph: %w", err) + } + + buildTime := time.Since(start) + fmt.Fprintf(os.Stderr, "Index built in %v\n", buildTime) + fmt.Fprintf(os.Stderr, " Functions: %d\n", len(callGraph.Functions)) + fmt.Fprintf(os.Stderr, " Call edges: %d\n", len(callGraph.Edges)) + fmt.Fprintf(os.Stderr, " Modules: %d\n", len(moduleRegistry.Modules)) + + // 4. Create and run MCP server + server := mcp.NewServer(projectPath, pythonVersion, callGraph, moduleRegistry, codeGraph, buildTime) + + fmt.Fprintln(os.Stderr, "Starting MCP server on stdio...") + return server.ServeStdio() +} diff --git a/sast-engine/main_test.go b/sast-engine/main_test.go index d716ff0f..84c64b2e 100644 --- a/sast-engine/main_test.go +++ b/sast-engine/main_test.go @@ -23,7 +23,7 @@ func TestExecute(t *testing.T) { { name: "Successful execution", mockExecuteErr: nil, - expectedOutput: "Code Pathfinder is designed for identifying vulnerabilities in source code.\n\nUsage:\n pathfinder [command]\n\nAvailable Commands:\n ci CI mode with SARIF, JSON, or CSV output for CI/CD integration\n completion Generate the autocompletion script for the specified shell\n diagnose Validate intra-procedural taint analysis against LLM ground truth\n help Help about any command\n resolution-report Generate a diagnostic report on call resolution statistics\n scan Scan code for security vulnerabilities using Python DSL rules\n version Print the version and commit information\n\nFlags:\n --disable-metrics Disable metrics collection\n -h, --help help for pathfinder\n --verbose Verbose output\n\nUse \"pathfinder [command] --help\" for more information about a command.\n", + expectedOutput: "Code Pathfinder is designed for identifying vulnerabilities in source code.\n\nUsage:\n pathfinder [command]\n\nAvailable Commands:\n ci CI mode with SARIF, JSON, or CSV output for CI/CD integration\n completion Generate the autocompletion script for the specified shell\n diagnose Validate intra-procedural taint analysis against LLM ground truth\n help Help about any command\n resolution-report Generate a diagnostic report on call resolution statistics\n scan Scan code for security vulnerabilities using Python DSL rules\n serve Start MCP server for AI coding assistants\n version Print the version and commit information\n\nFlags:\n --disable-metrics Disable metrics collection\n -h, --help help for pathfinder\n --verbose Verbose output\n\nUse \"pathfinder [command] --help\" for more information about a command.\n", expectedExit: 0, }, } diff --git a/sast-engine/mcp/server.go b/sast-engine/mcp/server.go new file mode 100644 index 00000000..489079b0 --- /dev/null +++ b/sast-engine/mcp/server.go @@ -0,0 +1,176 @@ +package mcp + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "os" + "time" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" +) + +// Server handles MCP protocol communication. +type Server struct { + projectPath string + pythonVersion string + callGraph *core.CallGraph + moduleRegistry *core.ModuleRegistry + codeGraph *graph.CodeGraph + indexedAt time.Time + buildTime time.Duration +} + +// NewServer creates a new MCP server with the given index data. +func NewServer( + projectPath string, + pythonVersion string, + callGraph *core.CallGraph, + moduleRegistry *core.ModuleRegistry, + codeGraph *graph.CodeGraph, + buildTime time.Duration, +) *Server { + return &Server{ + projectPath: projectPath, + pythonVersion: pythonVersion, + callGraph: callGraph, + moduleRegistry: moduleRegistry, + codeGraph: codeGraph, + indexedAt: time.Now(), + buildTime: buildTime, + } +} + +// ServeStdio starts the MCP server on stdin/stdout. +func (s *Server) ServeStdio() error { + reader := bufio.NewReader(os.Stdin) + + for { + // Read line from stdin. + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + fmt.Fprintln(os.Stderr, "Client disconnected") + return nil // Clean shutdown + } + return fmt.Errorf("read error: %w", err) + } + + // Skip empty lines. + if len(line) <= 1 { + continue + } + + // Parse JSON-RPC request. + var request JSONRPCRequest + if err := json.Unmarshal([]byte(line), &request); err != nil { + s.sendResponse(ErrorResponse(nil, -32700, "Parse error: "+err.Error())) + continue + } + + // Handle request and send response. + response := s.handleRequest(&request) + if response != nil { + s.sendResponse(response) + } + } +} + +// sendResponse writes a JSON-RPC response to stdout. +func (s *Server) sendResponse(resp *JSONRPCResponse) { + bytes, err := json.Marshal(resp) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to marshal response: %v\n", err) + return + } + fmt.Println(string(bytes)) +} + +// handleRequest dispatches to the appropriate handler. +func (s *Server) handleRequest(req *JSONRPCRequest) *JSONRPCResponse { + startTime := time.Now() + + var response *JSONRPCResponse + + switch req.Method { + case "initialize": + response = s.handleInitialize(req) + case "initialized": + // Acknowledgment notification - no response needed. + fmt.Fprintln(os.Stderr, "Client initialized") + return nil + case "notifications/initialized": + // Alternative notification format. + return nil + case "tools/list": + response = s.handleToolsList(req) + case "tools/call": + response = s.handleToolsCall(req) + case "ping": + response = SuccessResponse(req.ID, map[string]string{"status": "ok"}) + default: + response = ErrorResponse(req.ID, -32601, fmt.Sprintf("Method not found: %s", req.Method)) + } + + // Log request timing. + elapsed := time.Since(startTime) + fmt.Fprintf(os.Stderr, "[%s] %s (%v)\n", req.Method, "completed", elapsed) + + return response +} + +// handleInitialize responds to the initialize request. +func (s *Server) handleInitialize(req *JSONRPCRequest) *JSONRPCResponse { + // Parse client info if needed. + var params InitializeParams + if req.Params != nil { + _ = json.Unmarshal(req.Params, ¶ms) + fmt.Fprintf(os.Stderr, "Client: %s %s\n", params.ClientInfo.Name, params.ClientInfo.Version) + } + + return SuccessResponse(req.ID, InitializeResult{ + ProtocolVersion: "2024-11-05", + ServerInfo: ServerInfo{ + Name: "pathfinder", + Version: "0.1.0-poc", + }, + Capabilities: Capabilities{ + Tools: &ToolsCapability{ + ListChanged: false, + }, + }, + }) +} + +// handleToolsList returns the list of available tools. +func (s *Server) handleToolsList(req *JSONRPCRequest) *JSONRPCResponse { + tools := s.getToolDefinitions() + return SuccessResponse(req.ID, ToolsListResult{ + Tools: tools, + }) +} + +// handleToolsCall executes a tool. +func (s *Server) handleToolsCall(req *JSONRPCRequest) *JSONRPCResponse { + var params ToolCallParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return ErrorResponse(req.ID, -32602, "Invalid params: "+err.Error()) + } + + fmt.Fprintf(os.Stderr, "Tool call: %s\n", params.Name) + + result, isError := s.executeTool(params.Name, params.Arguments) + + return SuccessResponse(req.ID, ToolResult{ + Content: []ContentBlock{ + { + Type: "text", + Text: result, + }, + }, + IsError: isError, + }) +} + diff --git a/sast-engine/mcp/tools.go b/sast-engine/mcp/tools.go new file mode 100644 index 00000000..806b10d3 --- /dev/null +++ b/sast-engine/mcp/tools.go @@ -0,0 +1,535 @@ +package mcp + +import ( + "encoding/json" + "fmt" + "strings" +) + +// getToolDefinitions returns the complete tool schemas. +func (s *Server) getToolDefinitions() []Tool { + return []Tool{ + { + Name: "get_index_info", + Description: `Get statistics about the indexed Python codebase. Use this FIRST to understand the project scope before making other queries. + +Returns: project_path, python_version, indexed_at timestamp, build_time, and stats (functions count, call_edges count, modules count, files count, taint_summaries count). + +Use when: Starting analysis, understanding project size, or verifying the index is built correctly.`, + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]Property{}, + }, + }, + { + Name: "find_symbol", + Description: `Search for functions, classes, or methods by name. Supports partial matching. + +Returns: List of matches with FQN (fully qualified name like 'myapp.auth.login'), file path, line number, type, and metadata (return_type, parameters, decorators, superclass if available). + +Use when: Looking for a specific function, exploring what functions exist, or finding where something is defined. + +Examples: +- find_symbol("login") - finds all functions containing 'login' +- find_symbol("authenticate_user") - finds exact function +- find_symbol("myapp.auth") - finds all symbols in auth module`, + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]Property{ + "name": {Type: "string", Description: "Symbol name to find. Can be: short name ('login'), partial name ('auth'), or FQN ('myapp.auth.login')"}, + }, + Required: []string{"name"}, + }, + }, + { + Name: "get_callers", + Description: `Find all functions that CALL a given function (reverse call graph / incoming edges). Answer: "Who uses this function?" + +Returns: Target function info and list of callers with their FQN, file, line number, and the specific call site location. + +Use when: Understanding function usage, impact analysis before refactoring, finding entry points, or tracing how data flows INTO a function. + +Examples: +- get_callers("sanitize_input") - who calls the sanitize function? +- get_callers("database.execute") - what code runs database queries?`, + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]Property{ + "function": {Type: "string", Description: "Function to find callers for. Use short name ('login') or FQN ('myapp.auth.login')"}, + }, + Required: []string{"function"}, + }, + }, + { + Name: "get_callees", + Description: `Find all functions CALLED BY a given function (forward call graph / outgoing edges). Answer: "What does this function depend on?" + +Returns: Source function info, list of callees with target name, call line, resolution status (resolved/unresolved), and type inference info if available. + +Use when: Understanding function dependencies, analyzing what a function does, tracing data flow FROM a function, or finding unresolved external calls. + +Examples: +- get_callees("process_payment") - what functions does payment processing call? +- get_callees("handle_request") - what are the dependencies of the request handler?`, + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]Property{ + "function": {Type: "string", Description: "Function to find callees for. Use short name ('process') or FQN ('myapp.payment.process')"}, + }, + Required: []string{"function"}, + }, + }, + { + Name: "get_call_details", + Description: `Get detailed information about a SPECIFIC call from one function to another. Most detailed view of a single call site. + +Returns: Full call site info including caller FQN, target, exact location (file, line, column), arguments passed, and resolution details (resolved status, failure reason if unresolved, type inference info). + +Use when: Investigating a specific function call, understanding how arguments are passed, debugging why a call wasn't resolved, or analyzing type inference. + +Examples: +- get_call_details("handle_request", "authenticate") - how does handle_request call authenticate? +- get_call_details("save_user", "execute") - examine the database call in save_user`, + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]Property{ + "caller": {Type: "string", Description: "The function making the call (short name or FQN)"}, + "callee": {Type: "string", Description: "The function being called (short name, will match partially)"}, + }, + Required: []string{"caller", "callee"}, + }, + }, + { + Name: "resolve_import", + Description: `Resolve a Python import path to its actual file location in the project. + +Returns: Import resolution with file_path, module_fqn, match_type (exact/short_name/partial/ambiguous), and alternatives if multiple matches exist. + +Use when: Finding where a module is defined, understanding import structure, or locating source files for external references. + +Examples: +- resolve_import("myapp.auth.users") - find the users module +- resolve_import("utils") - find modules named utils (may return multiple) +- resolve_import("database") - locate database module`, + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]Property{ + "import": {Type: "string", Description: "Import path to resolve. Can be FQN ('myapp.auth.users') or short name ('users')"}, + }, + Required: []string{"import"}, + }, + }, + } +} + +// executeTool runs a tool and returns the result. +func (s *Server) executeTool(name string, args map[string]interface{}) (string, bool) { + switch name { + case "get_index_info": + return s.toolGetIndexInfo() + case "find_symbol": + symbolName, _ := args["name"].(string) + return s.toolFindSymbol(symbolName) + case "get_callers": + function, _ := args["function"].(string) + return s.toolGetCallers(function) + case "get_callees": + function, _ := args["function"].(string) + return s.toolGetCallees(function) + case "get_call_details": + caller, _ := args["caller"].(string) + callee, _ := args["callee"].(string) + return s.toolGetCallDetails(caller, callee) + case "resolve_import": + importPath, _ := args["import"].(string) + return s.toolResolveImport(importPath) + default: + return fmt.Sprintf(`{"error": "Unknown tool: %s"}`, name), true + } +} + +// ============================================================================ +// Tool Implementations +// ============================================================================ + +// toolGetIndexInfo returns index statistics. +func (s *Server) toolGetIndexInfo() (string, bool) { + result := map[string]interface{}{ + "project_path": s.projectPath, + "python_version": s.pythonVersion, + "indexed_at": s.indexedAt.Format("2006-01-02T15:04:05Z07:00"), + "build_time_seconds": s.buildTime.Seconds(), + "stats": map[string]int{ + "functions": len(s.callGraph.Functions), + "call_edges": len(s.callGraph.Edges), + "modules": len(s.moduleRegistry.Modules), + "files": len(s.moduleRegistry.FileToModule), + "taint_summaries": len(s.callGraph.Summaries), + }, + } + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false +} + +// toolFindSymbol finds symbols by name. +func (s *Server) toolFindSymbol(name string) (string, bool) { + if name == "" { + return `{"error": "name parameter is required"}`, true + } + + var matches []map[string]interface{} + + for fqn, node := range s.callGraph.Functions { + shortName := getShortName(fqn) + if shortName == name || strings.HasSuffix(fqn, "."+name) || fqn == name || strings.Contains(fqn, name) { + match := map[string]interface{}{ + "fqn": fqn, + "file": node.File, + "line": node.LineNumber, + "type": node.Type, + } + + // Add optional fields if available. + if node.ReturnType != "" { + match["return_type"] = node.ReturnType + } + if len(node.MethodArgumentsType) > 0 { + match["parameters"] = node.MethodArgumentsType + } + if node.Modifier != "" { + match["modifier"] = node.Modifier + } + if len(node.Annotation) > 0 { + match["decorators"] = node.Annotation + } + if node.SuperClass != "" { + match["superclass"] = node.SuperClass + } + + matches = append(matches, match) + } + } + + if len(matches) == 0 { + return fmt.Sprintf(`{"error": "Symbol not found: %s", "suggestion": "Try a partial name or check spelling"}`, name), true + } + + result := map[string]interface{}{ + "query": name, + "matches": matches, + "total": len(matches), + } + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false +} + +// toolGetCallers finds all callers of a function. +func (s *Server) toolGetCallers(function string) (string, bool) { + if function == "" { + return `{"error": "function parameter is required"}`, true + } + + fqns := s.findMatchingFQNs(function) + if len(fqns) == 0 { + return fmt.Sprintf(`{"error": "Function not found: %s"}`, function), true + } + + // Use first match. + targetFQN := fqns[0] + targetNode := s.callGraph.Functions[targetFQN] + + // Get callers from reverse edges. + callerFQNs := s.callGraph.ReverseEdges[targetFQN] + + callers := make([]map[string]interface{}, 0, len(callerFQNs)) + for _, callerFQN := range callerFQNs { + callerNode := s.callGraph.Functions[callerFQN] + if callerNode == nil { + continue + } + + caller := map[string]interface{}{ + "fqn": callerFQN, + "name": getShortName(callerFQN), + "file": callerNode.File, + "line": callerNode.LineNumber, + } + + // Find the specific call site location. + for _, cs := range s.callGraph.CallSites[callerFQN] { + if cs.TargetFQN == targetFQN || cs.Target == getShortName(targetFQN) { + caller["call_line"] = cs.Location.Line + caller["call_column"] = cs.Location.Column + break + } + } + + callers = append(callers, caller) + } + + result := map[string]interface{}{ + "target": map[string]interface{}{ + "fqn": targetFQN, + "name": getShortName(targetFQN), + "file": targetNode.File, + "line": targetNode.LineNumber, + }, + "callers": callers, + "total_callers": len(callers), + } + + if len(fqns) > 1 { + result["note"] = fmt.Sprintf("Multiple matches found (%d). Showing callers for first match. Other matches: %v", len(fqns), fqns[1:]) + } + + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false +} + +// toolGetCallees finds all functions called by a function. +func (s *Server) toolGetCallees(function string) (string, bool) { + if function == "" { + return `{"error": "function parameter is required"}`, true + } + + fqns := s.findMatchingFQNs(function) + if len(fqns) == 0 { + return fmt.Sprintf(`{"error": "Function not found: %s"}`, function), true + } + + sourceFQN := fqns[0] + sourceNode := s.callGraph.Functions[sourceFQN] + + // Get call sites for this function. + callSites := s.callGraph.CallSites[sourceFQN] + + callees := make([]map[string]interface{}, 0, len(callSites)) + resolvedCount := 0 + unresolvedCount := 0 + + for _, cs := range callSites { + callee := map[string]interface{}{ + "target": cs.Target, + "call_line": cs.Location.Line, + "resolved": cs.Resolved, + } + + if cs.Resolved { + resolvedCount++ + callee["target_fqn"] = cs.TargetFQN + + // Try to get file info for resolved target. + if targetNode := s.callGraph.Functions[cs.TargetFQN]; targetNode != nil { + callee["target_file"] = targetNode.File + callee["target_line"] = targetNode.LineNumber + } + } else { + unresolvedCount++ + if cs.FailureReason != "" { + callee["failure_reason"] = cs.FailureReason + } + } + + // Include type inference info if used. + if cs.ResolvedViaTypeInference { + callee["type_inference"] = map[string]interface{}{ + "inferred_type": cs.InferredType, + "type_confidence": cs.TypeConfidence, + } + } + + callees = append(callees, callee) + } + + result := map[string]interface{}{ + "source": map[string]interface{}{ + "fqn": sourceFQN, + "name": getShortName(sourceFQN), + "file": sourceNode.File, + "line": sourceNode.LineNumber, + }, + "callees": callees, + "total_callees": len(callees), + "resolved_count": resolvedCount, + "unresolved_count": unresolvedCount, + } + + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false +} + +// toolGetCallDetails gets detailed info about a specific call site. +func (s *Server) toolGetCallDetails(callerName, calleeName string) (string, bool) { + if callerName == "" || calleeName == "" { + return `{"error": "caller and callee parameters are required"}`, true + } + + callerFQNs := s.findMatchingFQNs(callerName) + if len(callerFQNs) == 0 { + return fmt.Sprintf(`{"error": "Caller function not found: %s"}`, callerName), true + } + + callerFQN := callerFQNs[0] + callSites := s.callGraph.CallSites[callerFQN] + + // Find matching call site. + for _, cs := range callSites { + if strings.Contains(cs.Target, calleeName) || strings.Contains(cs.TargetFQN, calleeName) { + callSite := map[string]interface{}{ + "caller_fqn": callerFQN, + "target": cs.Target, + "target_fqn": cs.TargetFQN, + "location": map[string]interface{}{ + "file": cs.Location.File, + "line": cs.Location.Line, + "column": cs.Location.Column, + }, + "resolved": cs.Resolved, + } + + // Add arguments if available. + if len(cs.Arguments) > 0 { + args := make([]map[string]interface{}, len(cs.Arguments)) + for i, arg := range cs.Arguments { + args[i] = map[string]interface{}{ + "position": arg.Position, + "value": arg.Value, + } + } + callSite["arguments"] = args + } + + // Add resolution info. + resolution := map[string]interface{}{ + "resolved": cs.Resolved, + } + if !cs.Resolved && cs.FailureReason != "" { + resolution["failure_reason"] = cs.FailureReason + } + if cs.ResolvedViaTypeInference { + resolution["via_type_inference"] = true + resolution["inferred_type"] = cs.InferredType + resolution["type_confidence"] = cs.TypeConfidence + resolution["type_source"] = cs.TypeSource + } + callSite["resolution"] = resolution + + result := map[string]interface{}{ + "call_site": callSite, + } + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false + } + } + + return fmt.Sprintf(`{"error": "Call site not found: %s -> %s", "suggestion": "Check that the caller actually calls the callee"}`, callerName, calleeName), true +} + +// toolResolveImport resolves an import path to file location. +func (s *Server) toolResolveImport(importPath string) (string, bool) { + if importPath == "" { + return `{"error": "import parameter is required"}`, true + } + + // Try exact match first. + if filePath, ok := s.moduleRegistry.Modules[importPath]; ok { + result := map[string]interface{}{ + "import": importPath, + "resolved": true, + "file_path": filePath, + "module_fqn": importPath, + "match_type": "exact", + "alternatives": []interface{}{}, + } + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false + } + + // Try short name lookup. + shortName := getShortName(importPath) + if files, ok := s.moduleRegistry.ShortNames[shortName]; ok && len(files) > 0 { + if len(files) == 1 { + // Unique match. + filePath := files[0] + moduleFQN := s.moduleRegistry.FileToModule[filePath] + result := map[string]interface{}{ + "import": importPath, + "resolved": true, + "file_path": filePath, + "module_fqn": moduleFQN, + "match_type": "short_name", + "alternatives": []interface{}{}, + } + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false + } + + // Multiple matches - return alternatives. + alternatives := make([]map[string]string, len(files)) + for i, f := range files { + alternatives[i] = map[string]string{ + "fqn": s.moduleRegistry.FileToModule[f], + "file": f, + } + } + result := map[string]interface{}{ + "import": importPath, + "resolved": false, + "match_type": "ambiguous", + "alternatives": alternatives, + "suggestion": "Multiple modules match. Use fully qualified import path.", + } + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false + } + + // Try partial match. + var partialMatches []map[string]string + for moduleFQN, filePath := range s.moduleRegistry.Modules { + if strings.Contains(moduleFQN, importPath) { + partialMatches = append(partialMatches, map[string]string{ + "fqn": moduleFQN, + "file": filePath, + }) + } + } + + if len(partialMatches) > 0 { + result := map[string]interface{}{ + "import": importPath, + "resolved": false, + "match_type": "partial", + "alternatives": partialMatches, + "suggestion": "No exact match. Did you mean one of these?", + } + bytes, _ := json.MarshalIndent(result, "", " ") + return string(bytes), false + } + + return fmt.Sprintf(`{"error": "Import not found: %s", "suggestion": "Check if the module is in the indexed project path"}`, importPath), true +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +// findMatchingFQNs finds all FQNs matching a name. +func (s *Server) findMatchingFQNs(name string) []string { + var matches []string + for fqn := range s.callGraph.Functions { + shortName := getShortName(fqn) + if shortName == name || strings.HasSuffix(fqn, "."+name) || fqn == name { + matches = append(matches, fqn) + } + } + return matches +} + +// getShortName extracts the last part of a FQN. +func getShortName(fqn string) string { + parts := strings.Split(fqn, ".") + if len(parts) == 0 { + return fqn + } + return parts[len(parts)-1] +} diff --git a/sast-engine/mcp/types.go b/sast-engine/mcp/types.go new file mode 100644 index 00000000..f255a88a --- /dev/null +++ b/sast-engine/mcp/types.go @@ -0,0 +1,138 @@ +package mcp + +import "encoding/json" + +// ============================================================================ +// JSON-RPC 2.0 Types +// ============================================================================ + +// JSONRPCRequest represents a JSON-RPC 2.0 request. +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +// JSONRPCResponse represents a JSON-RPC 2.0 response. +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Result interface{} `json:"result,omitempty"` + Error *RPCError `json:"error,omitempty"` +} + +// RPCError represents a JSON-RPC 2.0 error. +type RPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// ============================================================================ +// MCP Protocol Types +// ============================================================================ + +// InitializeParams contains initialization parameters from the client. +type InitializeParams struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo ClientInfo `json:"clientInfo"` +} + +// ClientInfo identifies the MCP client. +type ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// InitializeResult is returned to the client after initialization. +type InitializeResult struct { + ProtocolVersion string `json:"protocolVersion"` + ServerInfo ServerInfo `json:"serverInfo"` + Capabilities Capabilities `json:"capabilities"` +} + +// ServerInfo identifies this MCP server. +type ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// Capabilities advertises server features. +type Capabilities struct { + Tools *ToolsCapability `json:"tools,omitempty"` +} + +// ToolsCapability describes tool support capabilities. +type ToolsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` +} + +// ============================================================================ +// Tool Types +// ============================================================================ + +// Tool defines a tool for tools/list response. +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema InputSchema `json:"inputSchema"` +} + +// InputSchema describes tool parameters. +type InputSchema struct { + Type string `json:"type"` + Properties map[string]Property `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +// Property describes a single parameter. +type Property struct { + Type string `json:"type"` + Description string `json:"description"` +} + +// ToolsListResult is returned for tools/list requests. +type ToolsListResult struct { + Tools []Tool `json:"tools"` +} + +// ToolCallParams contains parameters for tools/call requests. +type ToolCallParams struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +// ToolResult is returned for tools/call responses. +type ToolResult struct { + Content []ContentBlock `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// ContentBlock represents a content block for tool output. +type ContentBlock struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +// SuccessResponse creates a successful JSON-RPC response. +func SuccessResponse(id interface{}, result interface{}) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: id, + Result: result, + } +} + +// ErrorResponse creates an error JSON-RPC response. +func ErrorResponse(id interface{}, code int, message string) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: id, + Error: &RPCError{Code: code, Message: message}, + } +}