diff --git a/sast-engine/mcp/pagination.go b/sast-engine/mcp/pagination.go new file mode 100644 index 00000000..63f97db6 --- /dev/null +++ b/sast-engine/mcp/pagination.go @@ -0,0 +1,150 @@ +package mcp + +import ( + "encoding/base64" + "encoding/json" + "fmt" +) + +// Default and max limits. +const ( + DefaultLimit = 50 + MaxLimit = 500 +) + +// PaginationParams holds pagination parameters from request. +type PaginationParams struct { + Limit int `json:"limit"` + Cursor string `json:"cursor"` +} + +// PaginationInfo holds pagination metadata for response. +type PaginationInfo struct { + Total int `json:"total"` + Returned int `json:"returned"` + HasMore bool `json:"hasMore"` + NextCursor string `json:"nextCursor,omitempty"` +} + +// Cursor represents an opaque pagination cursor. +type Cursor struct { + Offset int `json:"o"` + Query string `json:"q,omitempty"` +} + +// EncodeCursor creates an opaque cursor string. +func EncodeCursor(offset int, query string) string { + c := Cursor{Offset: offset, Query: query} + bytes, _ := json.Marshal(c) + return base64.URLEncoding.EncodeToString(bytes) +} + +// DecodeCursor parses a cursor string. +func DecodeCursor(cursor string) (*Cursor, error) { + if cursor == "" { + return &Cursor{Offset: 0}, nil + } + + bytes, err := base64.URLEncoding.DecodeString(cursor) + if err != nil { + return nil, fmt.Errorf("invalid cursor: %w", err) + } + + var c Cursor + if err := json.Unmarshal(bytes, &c); err != nil { + return nil, fmt.Errorf("invalid cursor format: %w", err) + } + + return &c, nil +} + +// ExtractPaginationParams extracts and validates pagination params. +func ExtractPaginationParams(args map[string]interface{}) (*PaginationParams, *RPCError) { + params := &PaginationParams{ + Limit: DefaultLimit, + } + + // Extract limit. + if limitVal, ok := args["limit"]; ok { + switch v := limitVal.(type) { + case float64: + params.Limit = int(v) + case int: + params.Limit = v + default: + return nil, InvalidParamsError("limit must be a number") + } + } + + // Validate limit. + if params.Limit <= 0 { + params.Limit = DefaultLimit + } + if params.Limit > MaxLimit { + params.Limit = MaxLimit + } + + // Extract cursor. + if cursorVal, ok := args["cursor"].(string); ok { + params.Cursor = cursorVal + } + + return params, nil +} + +// PaginateSlice applies pagination to a slice of any type. +func PaginateSlice[T any](items []T, params *PaginationParams) ([]T, *PaginationInfo) { + total := len(items) + + // Decode cursor to get offset. + cursor, err := DecodeCursor(params.Cursor) + if err != nil { + cursor = &Cursor{Offset: 0} + } + + offset := cursor.Offset + limit := params.Limit + + // Bounds check. + if offset >= total { + return []T{}, &PaginationInfo{ + Total: total, + Returned: 0, + HasMore: false, + } + } + + end := offset + limit + if end > total { + end = total + } + + result := items[offset:end] + hasMore := end < total + + info := &PaginationInfo{ + Total: total, + Returned: len(result), + HasMore: hasMore, + } + + if hasMore { + info.NextCursor = EncodeCursor(end, cursor.Query) + } + + return result, info +} + +// PaginatedResult wraps results with pagination info. +type PaginatedResult struct { + Items interface{} `json:"items"` + Pagination PaginationInfo `json:"pagination"` +} + +// NewPaginatedResult creates a paginated result. +func NewPaginatedResult(items interface{}, info *PaginationInfo) *PaginatedResult { + return &PaginatedResult{ + Items: items, + Pagination: *info, + } +} diff --git a/sast-engine/mcp/pagination_test.go b/sast-engine/mcp/pagination_test.go new file mode 100644 index 00000000..81199e65 --- /dev/null +++ b/sast-engine/mcp/pagination_test.go @@ -0,0 +1,288 @@ +package mcp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeCursor(t *testing.T) { + cursor := EncodeCursor(100, "test") + + assert.NotEmpty(t, cursor) + // Should be URL-safe base64. + assert.NotContains(t, cursor, "+") + assert.NotContains(t, cursor, "/") +} + +func TestDecodeCursor(t *testing.T) { + // Encode then decode. + encoded := EncodeCursor(50, "query") + decoded, err := DecodeCursor(encoded) + + require.NoError(t, err) + assert.Equal(t, 50, decoded.Offset) + assert.Equal(t, "query", decoded.Query) +} + +func TestDecodeCursor_Empty(t *testing.T) { + decoded, err := DecodeCursor("") + + require.NoError(t, err) + assert.Equal(t, 0, decoded.Offset) +} + +func TestDecodeCursor_Invalid(t *testing.T) { + _, err := DecodeCursor("not-valid-base64!!!") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid cursor") +} + +func TestDecodeCursor_InvalidJSON(t *testing.T) { + // Valid base64 but invalid JSON. + _, err := DecodeCursor("bm90anNvbg==") // "notjson" + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid cursor format") +} + +func TestExtractPaginationParams_Defaults(t *testing.T) { + args := map[string]interface{}{} + params, err := ExtractPaginationParams(args) + + assert.Nil(t, err) + assert.Equal(t, DefaultLimit, params.Limit) + assert.Empty(t, params.Cursor) +} + +func TestExtractPaginationParams_WithLimit(t *testing.T) { + args := map[string]interface{}{ + "limit": float64(25), + } + params, err := ExtractPaginationParams(args) + + assert.Nil(t, err) + assert.Equal(t, 25, params.Limit) +} + +func TestExtractPaginationParams_LimitInt(t *testing.T) { + args := map[string]interface{}{ + "limit": 30, + } + params, err := ExtractPaginationParams(args) + + assert.Nil(t, err) + assert.Equal(t, 30, params.Limit) +} + +func TestExtractPaginationParams_LimitCapped(t *testing.T) { + args := map[string]interface{}{ + "limit": float64(10000), + } + params, err := ExtractPaginationParams(args) + + assert.Nil(t, err) + assert.Equal(t, MaxLimit, params.Limit) +} + +func TestExtractPaginationParams_InvalidLimit(t *testing.T) { + args := map[string]interface{}{ + "limit": "not a number", + } + _, err := ExtractPaginationParams(args) + + assert.NotNil(t, err) + assert.Equal(t, ErrCodeInvalidParams, err.Code) +} + +func TestExtractPaginationParams_NegativeLimit(t *testing.T) { + args := map[string]interface{}{ + "limit": float64(-5), + } + params, err := ExtractPaginationParams(args) + + assert.Nil(t, err) + assert.Equal(t, DefaultLimit, params.Limit) // Reset to default. +} + +func TestExtractPaginationParams_ZeroLimit(t *testing.T) { + args := map[string]interface{}{ + "limit": float64(0), + } + params, err := ExtractPaginationParams(args) + + assert.Nil(t, err) + assert.Equal(t, DefaultLimit, params.Limit) // Reset to default. +} + +func TestExtractPaginationParams_WithCursor(t *testing.T) { + cursor := EncodeCursor(100, "") + args := map[string]interface{}{ + "cursor": cursor, + } + params, err := ExtractPaginationParams(args) + + assert.Nil(t, err) + assert.Equal(t, cursor, params.Cursor) +} + +func TestPaginateSlice_FirstPage(t *testing.T) { + items := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} + params := &PaginationParams{Limit: 3, Cursor: ""} + + result, info := PaginateSlice(items, params) + + assert.Equal(t, []string{"a", "b", "c"}, result) + assert.Equal(t, 10, info.Total) + assert.Equal(t, 3, info.Returned) + assert.True(t, info.HasMore) + assert.NotEmpty(t, info.NextCursor) +} + +func TestPaginateSlice_MiddlePage(t *testing.T) { + items := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} + cursor := EncodeCursor(3, "") + params := &PaginationParams{Limit: 3, Cursor: cursor} + + result, info := PaginateSlice(items, params) + + assert.Equal(t, []string{"d", "e", "f"}, result) + assert.Equal(t, 10, info.Total) + assert.Equal(t, 3, info.Returned) + assert.True(t, info.HasMore) +} + +func TestPaginateSlice_LastPage(t *testing.T) { + items := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} + cursor := EncodeCursor(9, "") + params := &PaginationParams{Limit: 3, Cursor: cursor} + + result, info := PaginateSlice(items, params) + + assert.Equal(t, []string{"j"}, result) + assert.Equal(t, 10, info.Total) + assert.Equal(t, 1, info.Returned) + assert.False(t, info.HasMore) + assert.Empty(t, info.NextCursor) +} + +func TestPaginateSlice_PastEnd(t *testing.T) { + items := []string{"a", "b", "c"} + cursor := EncodeCursor(100, "") + params := &PaginationParams{Limit: 10, Cursor: cursor} + + result, info := PaginateSlice(items, params) + + assert.Empty(t, result) + assert.Equal(t, 3, info.Total) + assert.Equal(t, 0, info.Returned) + assert.False(t, info.HasMore) +} + +func TestPaginateSlice_ExactFit(t *testing.T) { + items := []string{"a", "b", "c"} + params := &PaginationParams{Limit: 3, Cursor: ""} + + result, info := PaginateSlice(items, params) + + assert.Equal(t, []string{"a", "b", "c"}, result) + assert.False(t, info.HasMore) +} + +func TestPaginateSlice_EmptySlice(t *testing.T) { + items := []string{} + params := &PaginationParams{Limit: 10, Cursor: ""} + + result, info := PaginateSlice(items, params) + + assert.Empty(t, result) + assert.Equal(t, 0, info.Total) + assert.False(t, info.HasMore) +} + +func TestPaginateSlice_InvalidCursor(t *testing.T) { + items := []string{"a", "b", "c"} + params := &PaginationParams{Limit: 2, Cursor: "invalid!!!"} + + result, info := PaginateSlice(items, params) + + // Should start from beginning on invalid cursor. + assert.Equal(t, []string{"a", "b"}, result) + assert.Equal(t, 3, info.Total) +} + +func TestNewPaginatedResult(t *testing.T) { + items := []string{"a", "b"} + info := &PaginationInfo{Total: 10, Returned: 2, HasMore: true} + + result := NewPaginatedResult(items, info) + + assert.Equal(t, items, result.Items) + assert.Equal(t, *info, result.Pagination) +} + +func TestPaginationConstants(t *testing.T) { + assert.Equal(t, 50, DefaultLimit) + assert.Equal(t, 500, MaxLimit) + assert.True(t, MaxLimit > DefaultLimit) +} + +func TestPaginateSlice_IntSlice(t *testing.T) { + items := []int{1, 2, 3, 4, 5} + params := &PaginationParams{Limit: 2, Cursor: ""} + + result, info := PaginateSlice(items, params) + + assert.Equal(t, []int{1, 2}, result) + assert.Equal(t, 5, info.Total) + assert.True(t, info.HasMore) +} + +func TestPaginateSlice_StructSlice(t *testing.T) { + type Item struct { + Name string + } + items := []Item{{Name: "a"}, {Name: "b"}, {Name: "c"}} + params := &PaginationParams{Limit: 2, Cursor: ""} + + result, info := PaginateSlice(items, params) + + assert.Len(t, result, 2) + assert.Equal(t, "a", result[0].Name) + assert.Equal(t, 3, info.Total) +} + +func TestCursorRoundTrip(t *testing.T) { + // Test that encoding and decoding preserves values. + tests := []struct { + offset int + query string + }{ + {0, ""}, + {100, "test"}, + {999999, "complex query with spaces"}, + {1, ""}, + } + + for _, tt := range tests { + encoded := EncodeCursor(tt.offset, tt.query) + decoded, err := DecodeCursor(encoded) + + require.NoError(t, err) + assert.Equal(t, tt.offset, decoded.Offset) + assert.Equal(t, tt.query, decoded.Query) + } +} + +func TestPaginateSlice_CursorPreservesQuery(t *testing.T) { + items := []string{"a", "b", "c", "d", "e"} + cursor := EncodeCursor(0, "myquery") + params := &PaginationParams{Limit: 2, Cursor: cursor} + + _, info := PaginateSlice(items, params) + + // Decode the next cursor to verify query is preserved. + nextCursor, err := DecodeCursor(info.NextCursor) + require.NoError(t, err) + assert.Equal(t, "myquery", nextCursor.Query) +} diff --git a/sast-engine/mcp/tools.go b/sast-engine/mcp/tools.go index 806b10d3..11dbff9e 100644 --- a/sast-engine/mcp/tools.go +++ b/sast-engine/mcp/tools.go @@ -23,9 +23,9 @@ Use when: Starting analysis, understanding project size, or verifying the index }, { Name: "find_symbol", - Description: `Search for functions, classes, or methods by name. Supports partial matching. + Description: `Search for functions, classes, or methods by name. Supports partial matching. Results are paginated. -Returns: List of matches with FQN (fully qualified name like 'myapp.auth.login'), file path, line number, type, and metadata (return_type, parameters, decorators, superclass if available). +Returns: List of matches with FQN (fully qualified name like 'myapp.auth.login'), file path, line number, type, and metadata (return_type, parameters, decorators, superclass if available). Includes pagination info. Use when: Looking for a specific function, exploring what functions exist, or finding where something is defined. @@ -36,16 +36,18 @@ Examples: InputSchema: InputSchema{ Type: "object", Properties: map[string]Property{ - "name": {Type: "string", Description: "Symbol name to find. Can be: short name ('login'), partial name ('auth'), or FQN ('myapp.auth.login')"}, + "name": {Type: "string", Description: "Symbol name to find. Can be: short name ('login'), partial name ('auth'), or FQN ('myapp.auth.login')"}, + "limit": {Type: "integer", Description: "Max results to return (default: 50, max: 500)"}, + "cursor": {Type: "string", Description: "Pagination cursor from previous response"}, }, Required: []string{"name"}, }, }, { Name: "get_callers", - Description: `Find all functions that CALL a given function (reverse call graph / incoming edges). Answer: "Who uses this function?" + Description: `Find all functions that CALL a given function (reverse call graph / incoming edges). Answer: "Who uses this function?" Results are paginated. -Returns: Target function info and list of callers with their FQN, file, line number, and the specific call site location. +Returns: Target function info and list of callers with their FQN, file, line number, and the specific call site location. Includes pagination info. Use when: Understanding function usage, impact analysis before refactoring, finding entry points, or tracing how data flows INTO a function. @@ -56,15 +58,17 @@ Examples: Type: "object", Properties: map[string]Property{ "function": {Type: "string", Description: "Function to find callers for. Use short name ('login') or FQN ('myapp.auth.login')"}, + "limit": {Type: "integer", Description: "Max results to return (default: 50, max: 500)"}, + "cursor": {Type: "string", Description: "Pagination cursor from previous response"}, }, Required: []string{"function"}, }, }, { Name: "get_callees", - Description: `Find all functions CALLED BY a given function (forward call graph / outgoing edges). Answer: "What does this function depend on?" + Description: `Find all functions CALLED BY a given function (forward call graph / outgoing edges). Answer: "What does this function depend on?" Results are paginated. -Returns: Source function info, list of callees with target name, call line, resolution status (resolved/unresolved), and type inference info if available. +Returns: Source function info, list of callees with target name, call line, resolution status (resolved/unresolved), and type inference info if available. Includes pagination info. Use when: Understanding function dependencies, analyzing what a function does, tracing data flow FROM a function, or finding unresolved external calls. @@ -75,6 +79,8 @@ Examples: Type: "object", Properties: map[string]Property{ "function": {Type: "string", Description: "Function to find callees for. Use short name ('process') or FQN ('myapp.payment.process')"}, + "limit": {Type: "integer", Description: "Max results to return (default: 50, max: 500)"}, + "cursor": {Type: "string", Description: "Pagination cursor from previous response"}, }, Required: []string{"function"}, }, @@ -128,14 +134,11 @@ func (s *Server) executeTool(name string, args map[string]interface{}) (string, case "get_index_info": return s.toolGetIndexInfo() case "find_symbol": - symbolName, _ := args["name"].(string) - return s.toolFindSymbol(symbolName) + return s.toolFindSymbol(args) case "get_callers": - function, _ := args["function"].(string) - return s.toolGetCallers(function) + return s.toolGetCallers(args) case "get_callees": - function, _ := args["function"].(string) - return s.toolGetCallees(function) + return s.toolGetCallees(args) case "get_call_details": caller, _ := args["caller"].(string) callee, _ := args["callee"].(string) @@ -171,13 +174,20 @@ func (s *Server) toolGetIndexInfo() (string, bool) { return string(bytes), false } -// toolFindSymbol finds symbols by name. -func (s *Server) toolFindSymbol(name string) (string, bool) { +// toolFindSymbol finds symbols by name with pagination support. +func (s *Server) toolFindSymbol(args map[string]interface{}) (string, bool) { + name, _ := args["name"].(string) if name == "" { return `{"error": "name parameter is required"}`, true } - var matches []map[string]interface{} + // Extract pagination params. + pageParams, err := ExtractPaginationParams(args) + if err != nil { + return NewToolError(err.Message, err.Code, err.Data), true + } + + var allMatches []map[string]interface{} for fqn, node := range s.callGraph.Functions { shortName := getShortName(fqn) @@ -206,29 +216,39 @@ func (s *Server) toolFindSymbol(name string) (string, bool) { match["superclass"] = node.SuperClass } - matches = append(matches, match) + allMatches = append(allMatches, match) } } - if len(matches) == 0 { + if len(allMatches) == 0 { return fmt.Sprintf(`{"error": "Symbol not found: %s", "suggestion": "Try a partial name or check spelling"}`, name), true } + // Apply pagination. + matches, pageInfo := PaginateSlice(allMatches, pageParams) + result := map[string]interface{}{ - "query": name, - "matches": matches, - "total": len(matches), + "query": name, + "matches": matches, + "pagination": pageInfo, } bytes, _ := json.MarshalIndent(result, "", " ") return string(bytes), false } -// toolGetCallers finds all callers of a function. -func (s *Server) toolGetCallers(function string) (string, bool) { +// toolGetCallers finds all callers of a function with pagination support. +func (s *Server) toolGetCallers(args map[string]interface{}) (string, bool) { + function, _ := args["function"].(string) if function == "" { return `{"error": "function parameter is required"}`, true } + // Extract pagination params. + pageParams, err := ExtractPaginationParams(args) + if err != nil { + return NewToolError(err.Message, err.Code, err.Data), true + } + fqns := s.findMatchingFQNs(function) if len(fqns) == 0 { return fmt.Sprintf(`{"error": "Function not found: %s"}`, function), true @@ -241,7 +261,7 @@ func (s *Server) toolGetCallers(function string) (string, bool) { // Get callers from reverse edges. callerFQNs := s.callGraph.ReverseEdges[targetFQN] - callers := make([]map[string]interface{}, 0, len(callerFQNs)) + allCallers := make([]map[string]interface{}, 0, len(callerFQNs)) for _, callerFQN := range callerFQNs { callerNode := s.callGraph.Functions[callerFQN] if callerNode == nil { @@ -264,9 +284,12 @@ func (s *Server) toolGetCallers(function string) (string, bool) { } } - callers = append(callers, caller) + allCallers = append(allCallers, caller) } + // Apply pagination. + callers, pageInfo := PaginateSlice(allCallers, pageParams) + result := map[string]interface{}{ "target": map[string]interface{}{ "fqn": targetFQN, @@ -274,8 +297,8 @@ func (s *Server) toolGetCallers(function string) (string, bool) { "file": targetNode.File, "line": targetNode.LineNumber, }, - "callers": callers, - "total_callers": len(callers), + "callers": callers, + "pagination": pageInfo, } if len(fqns) > 1 { @@ -287,11 +310,18 @@ func (s *Server) toolGetCallers(function string) (string, bool) { } // toolGetCallees finds all functions called by a function. -func (s *Server) toolGetCallees(function string) (string, bool) { +func (s *Server) toolGetCallees(args map[string]interface{}) (string, bool) { + function, _ := args["function"].(string) if function == "" { return `{"error": "function parameter is required"}`, true } + // Extract pagination params. + pageParams, rpcErr := ExtractPaginationParams(args) + if rpcErr != nil { + return fmt.Sprintf(`{"error": "%s"}`, rpcErr.Message), true + } + fqns := s.findMatchingFQNs(function) if len(fqns) == 0 { return fmt.Sprintf(`{"error": "Function not found: %s"}`, function), true @@ -303,7 +333,7 @@ func (s *Server) toolGetCallees(function string) (string, bool) { // Get call sites for this function. callSites := s.callGraph.CallSites[sourceFQN] - callees := make([]map[string]interface{}, 0, len(callSites)) + allCallees := make([]map[string]interface{}, 0, len(callSites)) resolvedCount := 0 unresolvedCount := 0 @@ -338,9 +368,12 @@ func (s *Server) toolGetCallees(function string) (string, bool) { } } - callees = append(callees, callee) + allCallees = append(allCallees, callee) } + // Apply pagination. + callees, pageInfo := PaginateSlice(allCallees, pageParams) + result := map[string]interface{}{ "source": map[string]interface{}{ "fqn": sourceFQN, @@ -349,7 +382,7 @@ func (s *Server) toolGetCallees(function string) (string, bool) { "line": sourceNode.LineNumber, }, "callees": callees, - "total_callees": len(callees), + "pagination": pageInfo, "resolved_count": resolvedCount, "unresolved_count": unresolvedCount, } diff --git a/sast-engine/mcp/tools_test.go b/sast-engine/mcp/tools_test.go index e961f0c7..57a89372 100644 --- a/sast-engine/mcp/tools_test.go +++ b/sast-engine/mcp/tools_test.go @@ -38,7 +38,7 @@ func TestToolGetIndexInfo(t *testing.T) { func TestToolFindSymbol_Found(t *testing.T) { server := createTestServer() - result, isError := server.toolFindSymbol("validate_user") + result, isError := server.toolFindSymbol(map[string]interface{}{"name": "validate_user"}) assert.False(t, isError) assert.Contains(t, result, "validate_user") @@ -49,7 +49,7 @@ func TestToolFindSymbol_Found(t *testing.T) { func TestToolFindSymbol_PartialMatch(t *testing.T) { server := createTestServer() - result, isError := server.toolFindSymbol("validate") + result, isError := server.toolFindSymbol(map[string]interface{}{"name": "validate"}) // Should find validate_user via partial match. assert.False(t, isError) @@ -59,7 +59,7 @@ func TestToolFindSymbol_PartialMatch(t *testing.T) { func TestToolFindSymbol_MultipleMatches(t *testing.T) { server := createTestServer() - result, isError := server.toolFindSymbol("log") + result, isError := server.toolFindSymbol(map[string]interface{}{"name": "log"}) // Should find both login and logout. assert.False(t, isError) @@ -70,7 +70,7 @@ func TestToolFindSymbol_MultipleMatches(t *testing.T) { func TestToolFindSymbol_NotFound(t *testing.T) { server := createTestServer() - result, isError := server.toolFindSymbol("nonexistent_function_xyz") + result, isError := server.toolFindSymbol(map[string]interface{}{"name": "nonexistent_function_xyz"}) assert.True(t, isError) assert.Contains(t, result, "not found") @@ -80,7 +80,7 @@ func TestToolFindSymbol_NotFound(t *testing.T) { func TestToolFindSymbol_EmptyName(t *testing.T) { server := createTestServer() - result, isError := server.toolFindSymbol("") + result, isError := server.toolFindSymbol(map[string]interface{}{"name": ""}) assert.True(t, isError) assert.Contains(t, result, "required") @@ -89,30 +89,30 @@ func TestToolFindSymbol_EmptyName(t *testing.T) { func TestToolGetCallers_Found(t *testing.T) { server := createTestServer() - result, isError := server.toolGetCallers("validate_user") + result, isError := server.toolGetCallers(map[string]interface{}{"function": "validate_user"}) assert.False(t, isError) assert.Contains(t, result, "callers") assert.Contains(t, result, "login") assert.Contains(t, result, "target") - assert.Contains(t, result, "total_callers") + assert.Contains(t, result, "pagination") } func TestToolGetCallers_NoCallers(t *testing.T) { server := createTestServer() // login has no callers in our test data. - result, isError := server.toolGetCallers("login") + result, isError := server.toolGetCallers(map[string]interface{}{"function": "login"}) assert.False(t, isError) - assert.Contains(t, result, "total_callers") - assert.Contains(t, result, `"total_callers": 0`) + assert.Contains(t, result, "pagination") + assert.Contains(t, result, `"total": 0`) } func TestToolGetCallers_NotFound(t *testing.T) { server := createTestServer() - result, isError := server.toolGetCallers("nonexistent_function") + result, isError := server.toolGetCallers(map[string]interface{}{"function": "nonexistent_function"}) assert.True(t, isError) assert.Contains(t, result, "not found") @@ -121,7 +121,7 @@ func TestToolGetCallers_NotFound(t *testing.T) { func TestToolGetCallers_EmptyName(t *testing.T) { server := createTestServer() - result, isError := server.toolGetCallers("") + result, isError := server.toolGetCallers(map[string]interface{}{"function": ""}) assert.True(t, isError) assert.Contains(t, result, "required") @@ -130,7 +130,7 @@ func TestToolGetCallers_EmptyName(t *testing.T) { func TestToolGetCallees_Found(t *testing.T) { server := createTestServer() - result, isError := server.toolGetCallees("login") + result, isError := server.toolGetCallees(map[string]interface{}{"function": "login"}) assert.False(t, isError) assert.Contains(t, result, "callees") @@ -143,16 +143,16 @@ func TestToolGetCallees_NoCallees(t *testing.T) { server := createTestServer() // validate_user has no callees in our test data. - result, isError := server.toolGetCallees("validate_user") + result, isError := server.toolGetCallees(map[string]interface{}{"function": "validate_user"}) assert.False(t, isError) - assert.Contains(t, result, "total_callees") + assert.Contains(t, result, "pagination") } func TestToolGetCallees_NotFound(t *testing.T) { server := createTestServer() - result, isError := server.toolGetCallees("nonexistent_function") + result, isError := server.toolGetCallees(map[string]interface{}{"function": "nonexistent_function"}) assert.True(t, isError) assert.Contains(t, result, "not found") @@ -161,7 +161,7 @@ func TestToolGetCallees_NotFound(t *testing.T) { func TestToolGetCallees_EmptyName(t *testing.T) { server := createTestServer() - result, isError := server.toolGetCallees("") + result, isError := server.toolGetCallees(map[string]interface{}{"function": ""}) assert.True(t, isError) assert.Contains(t, result, "required") @@ -387,7 +387,7 @@ func TestGetToolDefinitions(t *testing.T) { func TestToolFindSymbol_WithAllFields(t *testing.T) { server := createExtendedTestServer() - result, isError := server.toolFindSymbol("validate_user") + result, isError := server.toolFindSymbol(map[string]interface{}{"name": "validate_user"}) assert.False(t, isError) @@ -447,7 +447,7 @@ func TestToolGetCallDetails_TypeInference(t *testing.T) { func TestToolGetCallees_WithUnresolvedCalls(t *testing.T) { server := createExtendedTestServer() - result, isError := server.toolGetCallees("login") + result, isError := server.toolGetCallees(map[string]interface{}{"function": "login"}) assert.False(t, isError) assert.Contains(t, result, "callees") @@ -516,12 +516,12 @@ func TestToolResolveImport_PartialContainsMatch(t *testing.T) { func TestToolGetCallers_WithMultipleCallSites(t *testing.T) { server := createExtendedTestServer() - result, isError := server.toolGetCallers("validate_user") + result, isError := server.toolGetCallers(map[string]interface{}{"function": "validate_user"}) assert.False(t, isError) assert.Contains(t, result, "callers") assert.Contains(t, result, "login") - assert.Contains(t, result, "total_callers") + assert.Contains(t, result, "pagination") } func TestToolGetCallers_MultipleMatches(t *testing.T) { @@ -547,7 +547,7 @@ func TestToolGetCallers_MultipleMatches(t *testing.T) { Modules: map[string]string{}, FileToModule: map[string]string{}, ShortNames: map[string][]string{}, }, nil, time.Second) - result, isError := server.toolGetCallers("handler") + result, isError := server.toolGetCallers(map[string]interface{}{"function": "handler"}) assert.False(t, isError) // Should have a note about multiple matches. @@ -569,19 +569,19 @@ func TestToolGetCallers_NilCallerNode(t *testing.T) { Modules: map[string]string{}, FileToModule: map[string]string{}, ShortNames: map[string][]string{}, }, nil, time.Second) - result, isError := server.toolGetCallers("func") + result, isError := server.toolGetCallers(map[string]interface{}{"function": "func"}) // Should still succeed but skip the nil caller. assert.False(t, isError) - assert.Contains(t, result, "total_callers") - assert.Contains(t, result, `"total_callers": 0`) + assert.Contains(t, result, "pagination") + assert.Contains(t, result, `"total": 0`) } func TestToolFindSymbol_SubstringMatch(t *testing.T) { server := createExtendedTestServer() // "valid" should match "validate_user" via substring. - result, isError := server.toolFindSymbol("valid") + result, isError := server.toolFindSymbol(map[string]interface{}{"name": "valid"}) assert.False(t, isError) assert.Contains(t, result, "validate_user") @@ -591,7 +591,7 @@ func TestToolFindSymbol_FQNSubstringMatch(t *testing.T) { server := createExtendedTestServer() // "myapp.auth" should match via FQN substring. - result, isError := server.toolFindSymbol("myapp.auth") + result, isError := server.toolFindSymbol(map[string]interface{}{"name": "myapp.auth"}) assert.False(t, isError) assert.Contains(t, result, "validate_user")