diff --git a/sast-engine/mcp/errors.go b/sast-engine/mcp/errors.go new file mode 100644 index 00000000..d319333c --- /dev/null +++ b/sast-engine/mcp/errors.go @@ -0,0 +1,187 @@ +package mcp + +import ( + "encoding/json" + "fmt" +) + +// Standard JSON-RPC 2.0 error codes. +const ( + ErrCodeParseError = -32700 + ErrCodeInvalidRequest = -32600 + ErrCodeMethodNotFound = -32601 + ErrCodeInvalidParams = -32602 + ErrCodeInternalError = -32603 + + // Custom server error codes (-32000 to -32099). + ErrCodeSymbolNotFound = -32001 + ErrCodeIndexNotReady = -32002 + ErrCodeQueryTimeout = -32003 + ErrCodeResultsTruncated = -32004 +) + +// errorMessages maps error codes to default messages. +var errorMessages = map[int]string{ + ErrCodeParseError: "Parse error", + ErrCodeInvalidRequest: "Invalid Request", + ErrCodeMethodNotFound: "Method not found", + ErrCodeInvalidParams: "Invalid params", + ErrCodeInternalError: "Internal error", + ErrCodeSymbolNotFound: "Symbol not found", + ErrCodeIndexNotReady: "Index not ready", + ErrCodeQueryTimeout: "Query timeout", + ErrCodeResultsTruncated: "Results truncated", +} + +// Error implements the error interface for RPCError. +func (e *RPCError) Error() string { + return fmt.Sprintf("[%d] %s", e.Code, e.Message) +} + +// NewRPCError creates a new RPC error with optional data. +func NewRPCError(code int, data interface{}) *RPCError { + message := errorMessages[code] + if message == "" { + message = "Unknown error" + } + return &RPCError{ + Code: code, + Message: message, + Data: data, + } +} + +// NewRPCErrorWithMessage creates an RPC error with custom message. +func NewRPCErrorWithMessage(code int, message string, data interface{}) *RPCError { + return &RPCError{ + Code: code, + Message: message, + Data: data, + } +} + +// ParseError creates a parse error response. +func ParseError(detail string) *RPCError { + return NewRPCErrorWithMessage(ErrCodeParseError, "Parse error: "+detail, nil) +} + +// InvalidRequestError creates an invalid request error. +func InvalidRequestError(detail string) *RPCError { + return NewRPCErrorWithMessage(ErrCodeInvalidRequest, "Invalid request: "+detail, nil) +} + +// MethodNotFoundError creates a method not found error. +func MethodNotFoundError(method string) *RPCError { + return NewRPCErrorWithMessage(ErrCodeMethodNotFound, + fmt.Sprintf("Method not found: %s", method), + map[string]string{"method": method}) +} + +// InvalidParamsError creates an invalid params error. +func InvalidParamsError(detail string) *RPCError { + return NewRPCErrorWithMessage(ErrCodeInvalidParams, "Invalid params: "+detail, nil) +} + +// InternalError creates an internal error. +func InternalError(detail string) *RPCError { + return NewRPCErrorWithMessage(ErrCodeInternalError, "Internal error: "+detail, nil) +} + +// SymbolNotFoundError creates a symbol not found error with suggestions. +func SymbolNotFoundError(symbol string, suggestions []string) *RPCError { + data := map[string]interface{}{ + "symbol": symbol, + } + if len(suggestions) > 0 { + data["suggestions"] = suggestions + } + return NewRPCErrorWithMessage(ErrCodeSymbolNotFound, + fmt.Sprintf("Symbol not found: %s", symbol), data) +} + +// IndexNotReadyError creates an index not ready error. +func IndexNotReadyError() *RPCError { + return NewRPCError(ErrCodeIndexNotReady, nil) +} + +// QueryTimeoutError creates a query timeout error. +func QueryTimeoutError(timeout string) *RPCError { + return NewRPCErrorWithMessage(ErrCodeQueryTimeout, + fmt.Sprintf("Query timed out after %s", timeout), + map[string]string{"timeout": timeout}) +} + +// MakeErrorResponse creates a JSON-RPC error response from an RPCError. +func MakeErrorResponse(id interface{}, err *RPCError) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: id, + Error: err, + } +} + +// ToolError represents a structured tool error response. +type ToolError struct { + Error string `json:"error"` + Code int `json:"code,omitempty"` + Details interface{} `json:"details,omitempty"` +} + +// NewToolError creates a JSON-formatted tool error response. +func NewToolError(message string, code int, details interface{}) string { + te := ToolError{ + Error: message, + Code: code, + Details: details, + } + bytes, _ := json.Marshal(te) + return string(bytes) +} + +// ValidateRequiredParams checks for required parameters. +func ValidateRequiredParams(args map[string]interface{}, required []string) *RPCError { + missing := []string{} + for _, param := range required { + if _, ok := args[param]; !ok { + missing = append(missing, param) + } + } + if len(missing) > 0 { + return InvalidParamsError(fmt.Sprintf("missing required parameters: %v", missing)) + } + return nil +} + +// ValidateStringParam validates a string parameter. +func ValidateStringParam(args map[string]interface{}, name string) (string, *RPCError) { + val, ok := args[name] + if !ok { + return "", InvalidParamsError(fmt.Sprintf("missing required parameter: %s", name)) + } + str, ok := val.(string) + if !ok { + return "", InvalidParamsError(fmt.Sprintf("parameter %s must be a string", name)) + } + if str == "" { + return "", InvalidParamsError(fmt.Sprintf("parameter %s cannot be empty", name)) + } + return str, nil +} + +// ValidateIntParam validates an integer parameter with optional default. +func ValidateIntParam(args map[string]interface{}, name string, defaultVal int) (int, *RPCError) { + val, ok := args[name] + if !ok { + return defaultVal, nil + } + + // JSON numbers come as float64. + switch v := val.(type) { + case float64: + return int(v), nil + case int: + return v, nil + default: + return 0, InvalidParamsError(fmt.Sprintf("parameter %s must be a number", name)) + } +} diff --git a/sast-engine/mcp/errors_test.go b/sast-engine/mcp/errors_test.go new file mode 100644 index 00000000..18953ba9 --- /dev/null +++ b/sast-engine/mcp/errors_test.go @@ -0,0 +1,292 @@ +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRPCError_Error(t *testing.T) { + err := &RPCError{Code: -32600, Message: "Invalid Request"} + assert.Equal(t, "[-32600] Invalid Request", err.Error()) +} + +func TestNewRPCError(t *testing.T) { + err := NewRPCError(ErrCodeParseError, nil) + + assert.Equal(t, ErrCodeParseError, err.Code) + assert.Equal(t, "Parse error", err.Message) + assert.Nil(t, err.Data) +} + +func TestNewRPCError_WithData(t *testing.T) { + data := map[string]string{"detail": "unexpected EOF"} + err := NewRPCError(ErrCodeParseError, data) + + assert.Equal(t, data, err.Data) +} + +func TestNewRPCError_UnknownCode(t *testing.T) { + err := NewRPCError(-99999, nil) + assert.Equal(t, "Unknown error", err.Message) +} + +func TestNewRPCErrorWithMessage(t *testing.T) { + err := NewRPCErrorWithMessage(-32600, "Custom message", nil) + + assert.Equal(t, -32600, err.Code) + assert.Equal(t, "Custom message", err.Message) +} + +func TestParseError(t *testing.T) { + err := ParseError("unexpected token") + + assert.Equal(t, ErrCodeParseError, err.Code) + assert.Contains(t, err.Message, "unexpected token") +} + +func TestInvalidRequestError(t *testing.T) { + err := InvalidRequestError("missing jsonrpc field") + + assert.Equal(t, ErrCodeInvalidRequest, err.Code) + assert.Contains(t, err.Message, "missing jsonrpc field") +} + +func TestMethodNotFoundError(t *testing.T) { + err := MethodNotFoundError("unknown/method") + + assert.Equal(t, ErrCodeMethodNotFound, err.Code) + assert.Contains(t, err.Message, "unknown/method") + + data := err.Data.(map[string]string) + assert.Equal(t, "unknown/method", data["method"]) +} + +func TestInvalidParamsError(t *testing.T) { + err := InvalidParamsError("symbol is required") + + assert.Equal(t, ErrCodeInvalidParams, err.Code) + assert.Contains(t, err.Message, "symbol is required") +} + +func TestInternalError(t *testing.T) { + err := InternalError("database connection failed") + + assert.Equal(t, ErrCodeInternalError, err.Code) + assert.Contains(t, err.Message, "database connection failed") +} + +func TestSymbolNotFoundError(t *testing.T) { + err := SymbolNotFoundError("MyClass", []string{"MyClass2", "MyClassHelper"}) + + assert.Equal(t, ErrCodeSymbolNotFound, err.Code) + assert.Contains(t, err.Message, "MyClass") + + data := err.Data.(map[string]interface{}) + assert.Equal(t, "MyClass", data["symbol"]) + assert.Contains(t, data["suggestions"], "MyClass2") +} + +func TestSymbolNotFoundError_NoSuggestions(t *testing.T) { + err := SymbolNotFoundError("xyz", nil) + + data := err.Data.(map[string]interface{}) + _, hasSuggestions := data["suggestions"] + assert.False(t, hasSuggestions) +} + +func TestIndexNotReadyError(t *testing.T) { + err := IndexNotReadyError() + + assert.Equal(t, ErrCodeIndexNotReady, err.Code) + assert.Equal(t, "Index not ready", err.Message) +} + +func TestQueryTimeoutError(t *testing.T) { + err := QueryTimeoutError("30s") + + assert.Equal(t, ErrCodeQueryTimeout, err.Code) + assert.Contains(t, err.Message, "30s") +} + +func TestMakeErrorResponse(t *testing.T) { + rpcErr := ParseError("bad json") + resp := MakeErrorResponse(1, rpcErr) + + assert.Equal(t, "2.0", resp.JSONRPC) + assert.Equal(t, 1, resp.ID) + assert.Equal(t, rpcErr, resp.Error) + assert.Nil(t, resp.Result) +} + +func TestMakeErrorResponse_NilID(t *testing.T) { + rpcErr := ParseError("bad json") + resp := MakeErrorResponse(nil, rpcErr) + + assert.Equal(t, "2.0", resp.JSONRPC) + assert.Nil(t, resp.ID) + assert.Equal(t, rpcErr, resp.Error) +} + +func TestNewToolError(t *testing.T) { + result := NewToolError("Symbol not found", ErrCodeSymbolNotFound, map[string]string{"symbol": "foo"}) + + var parsed ToolError + err := json.Unmarshal([]byte(result), &parsed) + require.NoError(t, err) + + assert.Equal(t, "Symbol not found", parsed.Error) + assert.Equal(t, ErrCodeSymbolNotFound, parsed.Code) +} + +func TestNewToolError_NoDetails(t *testing.T) { + result := NewToolError("Internal error", ErrCodeInternalError, nil) + + var parsed ToolError + err := json.Unmarshal([]byte(result), &parsed) + require.NoError(t, err) + + assert.Equal(t, "Internal error", parsed.Error) + assert.Nil(t, parsed.Details) +} + +func TestValidateRequiredParams(t *testing.T) { + args := map[string]interface{}{ + "name": "test", + } + + // Has required param. + err := ValidateRequiredParams(args, []string{"name"}) + assert.Nil(t, err) + + // Missing required param. + err = ValidateRequiredParams(args, []string{"name", "value"}) + assert.NotNil(t, err) + assert.Equal(t, ErrCodeInvalidParams, err.Code) + assert.Contains(t, err.Message, "value") +} + +func TestValidateRequiredParams_AllPresent(t *testing.T) { + args := map[string]interface{}{ + "name": "test", + "value": 123, + } + + err := ValidateRequiredParams(args, []string{"name", "value"}) + assert.Nil(t, err) +} + +func TestValidateRequiredParams_Empty(t *testing.T) { + args := map[string]interface{}{} + + err := ValidateRequiredParams(args, []string{"name"}) + assert.NotNil(t, err) + assert.Contains(t, err.Message, "name") +} + +func TestValidateStringParam(t *testing.T) { + args := map[string]interface{}{ + "name": "test", + "number": 123, + "empty": "", + } + + // Valid string. + val, err := ValidateStringParam(args, "name") + assert.Nil(t, err) + assert.Equal(t, "test", val) + + // Missing param. + _, err = ValidateStringParam(args, "missing") + assert.NotNil(t, err) + assert.Contains(t, err.Message, "missing") + + // Wrong type. + _, err = ValidateStringParam(args, "number") + assert.NotNil(t, err) + assert.Contains(t, err.Message, "must be a string") + + // Empty string. + _, err = ValidateStringParam(args, "empty") + assert.NotNil(t, err) + assert.Contains(t, err.Message, "cannot be empty") +} + +func TestValidateIntParam(t *testing.T) { + args := map[string]interface{}{ + "count": float64(10), + "limit": 100, + "name": "test", + } + + // Valid float64 (from JSON). + val, err := ValidateIntParam(args, "count", 5) + assert.Nil(t, err) + assert.Equal(t, 10, val) + + // Valid int. + val, err = ValidateIntParam(args, "limit", 50) + assert.Nil(t, err) + assert.Equal(t, 100, val) + + // Missing - use default. + val, err = ValidateIntParam(args, "missing", 25) + assert.Nil(t, err) + assert.Equal(t, 25, val) + + // Wrong type. + _, err = ValidateIntParam(args, "name", 0) + assert.NotNil(t, err) + assert.Contains(t, err.Message, "must be a number") +} + +func TestErrorCodes_AreCorrect(t *testing.T) { + // Verify standard JSON-RPC 2.0 codes. + assert.Equal(t, -32700, ErrCodeParseError) + assert.Equal(t, -32600, ErrCodeInvalidRequest) + assert.Equal(t, -32601, ErrCodeMethodNotFound) + assert.Equal(t, -32602, ErrCodeInvalidParams) + assert.Equal(t, -32603, ErrCodeInternalError) + + // Verify custom codes are in valid range. + assert.True(t, ErrCodeSymbolNotFound >= -32099 && ErrCodeSymbolNotFound <= -32000) + assert.True(t, ErrCodeIndexNotReady >= -32099 && ErrCodeIndexNotReady <= -32000) + assert.True(t, ErrCodeQueryTimeout >= -32099 && ErrCodeQueryTimeout <= -32000) + assert.True(t, ErrCodeResultsTruncated >= -32099 && ErrCodeResultsTruncated <= -32000) +} + +func TestRPCError_JSONSerialization(t *testing.T) { + err := SymbolNotFoundError("foo", []string{"foobar"}) + + bytes, jsonErr := json.Marshal(err) + require.NoError(t, jsonErr) + + var parsed RPCError + jsonErr = json.Unmarshal(bytes, &parsed) + require.NoError(t, jsonErr) + + assert.Equal(t, err.Code, parsed.Code) + assert.Equal(t, err.Message, parsed.Message) +} + +func TestAllErrorMessages_Defined(t *testing.T) { + // Verify all error codes have messages. + codes := []int{ + ErrCodeParseError, + ErrCodeInvalidRequest, + ErrCodeMethodNotFound, + ErrCodeInvalidParams, + ErrCodeInternalError, + ErrCodeSymbolNotFound, + ErrCodeIndexNotReady, + ErrCodeQueryTimeout, + ErrCodeResultsTruncated, + } + + for _, code := range codes { + msg := errorMessages[code] + assert.NotEmpty(t, msg, "Error code %d should have a message", code) + } +} diff --git a/sast-engine/mcp/server.go b/sast-engine/mcp/server.go index 489079b0..f0316758 100644 --- a/sast-engine/mcp/server.go +++ b/sast-engine/mcp/server.go @@ -66,7 +66,7 @@ func (s *Server) ServeStdio() error { // Parse JSON-RPC request. var request JSONRPCRequest if err := json.Unmarshal([]byte(line), &request); err != nil { - s.sendResponse(ErrorResponse(nil, -32700, "Parse error: "+err.Error())) + s.sendResponse(MakeErrorResponse(nil, ParseError(err.Error()))) continue } @@ -92,6 +92,16 @@ func (s *Server) sendResponse(resp *JSONRPCResponse) { func (s *Server) handleRequest(req *JSONRPCRequest) *JSONRPCResponse { startTime := time.Now() + // Validate JSON-RPC version. + if req.JSONRPC != "2.0" { + return MakeErrorResponse(req.ID, InvalidRequestError("jsonrpc must be '2.0'")) + } + + // Validate method exists. + if req.Method == "" { + return MakeErrorResponse(req.ID, InvalidRequestError("method is required")) + } + var response *JSONRPCResponse switch req.Method { @@ -111,7 +121,7 @@ func (s *Server) handleRequest(req *JSONRPCRequest) *JSONRPCResponse { case "ping": response = SuccessResponse(req.ID, map[string]string{"status": "ok"}) default: - response = ErrorResponse(req.ID, -32601, fmt.Sprintf("Method not found: %s", req.Method)) + response = MakeErrorResponse(req.ID, MethodNotFoundError(req.Method)) } // Log request timing. @@ -156,7 +166,11 @@ func (s *Server) handleToolsList(req *JSONRPCRequest) *JSONRPCResponse { func (s *Server) handleToolsCall(req *JSONRPCRequest) *JSONRPCResponse { var params ToolCallParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { - return ErrorResponse(req.ID, -32602, "Invalid params: "+err.Error()) + return MakeErrorResponse(req.ID, InvalidParamsError(err.Error())) + } + + if params.Name == "" { + return MakeErrorResponse(req.ID, InvalidParamsError("tool name is required")) } fmt.Fprintf(os.Stderr, "Tool call: %s\n", params.Name) diff --git a/sast-engine/mcp/server_test.go b/sast-engine/mcp/server_test.go index 9aa51652..1941f4ae 100644 --- a/sast-engine/mcp/server_test.go +++ b/sast-engine/mcp/server_test.go @@ -511,3 +511,80 @@ func TestSendResponse(t *testing.T) { server.sendResponse(resp) }) } + +// ============================================================================ +// Error Handling Tests (PR-02) +// ============================================================================ + +func TestHandleRequest_InvalidJSONRPCVersion(t *testing.T) { + server := createTestServer() + + req := &JSONRPCRequest{ + JSONRPC: "1.0", // Invalid version + ID: 1, + Method: "ping", + } + + resp := server.handleRequest(req) + + assert.NotNil(t, resp) + assert.NotNil(t, resp.Error) + assert.Equal(t, ErrCodeInvalidRequest, resp.Error.Code) + assert.Contains(t, resp.Error.Message, "jsonrpc must be '2.0'") +} + +func TestHandleRequest_MissingMethod(t *testing.T) { + server := createTestServer() + + req := &JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "", // Missing method + } + + resp := server.handleRequest(req) + + assert.NotNil(t, resp) + assert.NotNil(t, resp.Error) + assert.Equal(t, ErrCodeInvalidRequest, resp.Error.Code) + assert.Contains(t, resp.Error.Message, "method is required") +} + +func TestHandleToolsCall_MissingToolName(t *testing.T) { + server := createTestServer() + + req := &JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/call", + Params: json.RawMessage(`{"name":"","arguments":{}}`), + } + + resp := server.handleToolsCall(req) + + assert.NotNil(t, resp) + assert.NotNil(t, resp.Error) + assert.Equal(t, ErrCodeInvalidParams, resp.Error.Code) + assert.Contains(t, resp.Error.Message, "tool name is required") +} + +func TestHandleRequest_MethodNotFound_WithErrorData(t *testing.T) { + server := createTestServer() + + req := &JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "unknown/method", + } + + resp := server.handleRequest(req) + + assert.NotNil(t, resp) + assert.NotNil(t, resp.Error) + assert.Equal(t, ErrCodeMethodNotFound, resp.Error.Code) + assert.Contains(t, resp.Error.Message, "unknown/method") + + // Verify error has method in data. + data := resp.Error.Data.(map[string]string) + assert.Equal(t, "unknown/method", data["method"]) +}