diff --git a/.gitignore b/.gitignore index 1e48113..8e7c504 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ tmp/ *.log # Go -vendor \ No newline at end of file +vendor +.cursor/ diff --git a/.licenserc.yaml b/.licenserc.yaml new file mode 100644 index 0000000..5190619 --- /dev/null +++ b/.licenserc.yaml @@ -0,0 +1,40 @@ +# Copyright 2024 StreamNative +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +header: + license: + spdx-id: Apache-2.0 + copyright-owner: StreamNative + + paths-ignore: + - 'dist' + - 'licenses' + - '**/*.md' + - 'LICENSE' + - 'NOTICE' + - '.github/**' + - 'PROJECT' + - '**/go.mod' + - '**/go.work' + - '**/go.work.sum' + - '**/go.sum' + - '**/*.json' + - 'sdk/**' + - '**/*.yaml' + - '**/*.yml' + - 'Makefile' + - '.gitignore' + + comment: on-failure diff --git a/Makefile b/Makefile index 05f493f..569cb67 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,11 @@ build: -X ${VERSION_PATH}.date=${BUILD_DATE}" \ -o bin/snmcp cmd/streamnative-mcp-server/main.go -# go install github.com/elastic/go-licenser@latest -.PHONY: fix-license -fix-license: - go-licenser -license ASL2 -exclude sdk +.PHONY: license-check +license-check: + license-eye header check + +# go install github.com/apache/skywalking-eyes/cmd/license-eye@latest +.PHONY: license-fix +license-fix: + license-eye header fix diff --git a/README.md b/README.md index e72a5a7..39277e4 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,7 @@ The StreamNative MCP Server allows you to enable or disable specific groups of f | Feature | Description | Docs | |---------------------|------------------------------------------------------------------|------| | `streamnative-cloud`| Manage StreamNative Cloud context and check resource logs | [streamnative_cloud.md](docs/tools/streamnative_cloud.md) | +| `functions-as-tools` | Dynamically exposes deployed Pulsar Functions as invokable MCP tools, with automatic input/output schema handling. | [functions_as_tools.md](docs/tools/functions_as_tools.md) | You can combine these features as needed using the `--features` flag. For example, to enable only Pulsar client features: ```bash diff --git a/docs/tools/functions_as_tools.md b/docs/tools/functions_as_tools.md new file mode 100644 index 0000000..81a3e91 --- /dev/null +++ b/docs/tools/functions_as_tools.md @@ -0,0 +1,107 @@ +# Functions as Tools + +The "Functions as Tools" feature allows the StreamNative MCP Server to dynamically discover Apache Pulsar Functions deployed in your cluster and expose them as invokable MCP tools for AI agents. This significantly enhances the capabilities of AI agents by allowing them to interact with custom business logic encapsulated in Pulsar Functions without manual tool registration for each function. + +## How it Works + +### 1. Function Discovery +The MCP Server automatically discovers Pulsar Functions available in the connected Pulsar cluster. It periodically polls for functions and identifies those suitable for exposure as tools. + +By default, if no custom name is provided (see Customizing Tool Properties), the MCP tool name might be derived from the Function's Fully Qualified Name (FQN), such as `pulsar_function_$tenant_$namespace_$name`. + +### 2. Schema Conversion +For each discovered function, the MCP Server attempts to extract its input and output schema definitions. Pulsar Functions can be defined with various schema types for their inputs and outputs (e.g., primitive types, AVRO, JSON). + +The server then converts these native Pulsar schemas into a format compatible with MCP tools. This allows the AI agent to understand the expected input parameters and the structure of the output. + +Supported Pulsar schema types for automatic conversion include: +* Primitive types (String, Boolean, Numbers like INT8, INT16, INT32, INT64, FLOAT, DOUBLE) +* AVRO +* JSON + +If a function uses an unsupported schema type for its input or output, or if schemas are not clearly defined, it might not be exposed as an MCP tool. + +## Enabling the Feature +To enable this functionality, you need to specific the default `--pulsar-instance` and `--pulsar-cluster`, and include `functions-as-tools` in the `--features` flag when starting the StreamNative MCP Server. + +Example: +```bash +snmcp sse --organization my-org --key-file /path/to/key-file.json --features pulsar-admin,pulsar-client,functions-as-tools --pulsar-instance instance --pulsar-cluster cluster +``` +If `functions-as-tools` is part of a broader feature set like `all` and `streamnative-cloud`, enabling `all` or `streamnative-cloud` would also activate this feature. + +## Customizing Tool Properties +You can customize how your Pulsar Functions appear as MCP tools (their name and description) by providing specific runtime options when deploying or updating your functions. This is done using the `--custom-runtime-options` flag with `pulsar-admin functions create` or `pulsar-admin functions update`. + +The MCP Server looks for the following environment variables within the custom runtime options: +* `MCP_TOOL_NAME`: Specifies the desired name for the MCP tool. +* `MCP_TOOL_DESCRIPTION`: Provides a description for the MCP tool, which helps the AI agent understand its purpose. + +**Format for `--custom-runtime-options`**: +The options should be a JSON string where you define an `env` map containing `MCP_TOOL_NAME` and `MCP_TOOL_DESCRIPTION`. + +**Example**: +When deploying a Pulsar Function, you can set these properties as follows: +```bash +pulsar-admin functions create \ + --tenant public \ + --namespace default \ + --name my-custom-logic-function \ + --inputs "persistent://public/default/input-topic" \ + --output "persistent://public/default/output-topic" \ + --py my_function.py \ + --classname my_function.MyFunction \ + --custom-runtime-options \ + ''' + { + "env": { + "MCP_TOOL_NAME": "CustomObjectFunction", + "MCP_TOOL_DESCRIPTION": "Takes an input number and returns the value incremented by 100." + } + } + ''' +``` +In this example: +- The MCP tool derived from `my-custom-logic-function` will be named `CustomObjectFunction`. +- Its description will be "Takes an input number and returns the value incremented by 100." + +If these custom options are not provided, the MCP tool name might default to a derivative of the function's FQN, and the description might be generic and cannot help AI Agent to understand the purpose of the MCP tool. + +## Server-Side Configuration via Environment Variables + +Beyond customizing individual tool properties at the function deployment level, you can also configure the overall behavior of the "Functions as Tools" feature on the StreamNative MCP Server side using the following environment variables. These variables are typically set when starting the MCP server. + +* `FUNCTIONS_AS_TOOLS_POLL_INTERVAL` + * **Description**: Controls how frequently the MCP Server polls the Pulsar cluster to discover or update available Pulsar Functions. Setting a lower value means functions are discovered faster, but it may increase the load on the Pulsar cluster. + * **Unit**: Seconds + * **Default**: Defaults to the value specified in `pftools.DefaultManagerOptions()`. Refer to the `pkg/pftools` package for the precise default (e.g., if the internal default is 60 seconds, it will be `60`). +* `FUNCTIONS_AS_TOOLS_TIMEOUT` + * **Description**: Sets the default timeout for invoking a Pulsar Function as an MCP tool. If a function execution exceeds this duration, the call will be considered timed out. + * **Unit**: Seconds + * **Default**: Defaults to the value specified in `pftools.DefaultManagerOptions()` (e.g., if the internal default is 30 seconds, it will be `30`). +* `FUNCTIONS_AS_TOOLS_FAILURE_THRESHOLD` + * **Description**: Defines the number of consecutive failures for a specific Pulsar Function tool before it is temporarily moved to a "circuit breaker open" state. In this state, further calls to this specific function tool will be immediately rejected without attempting to execute the function, until the `FUNCTIONS_AS_TOOLS_RESET_TIMEOUT` is reached. + * **Unit**: Integer (number of failures) + * **Default**: Defaults to the value specified in `pftools.DefaultManagerOptions()` (e.g., if the internal default is 5, it will be `5`). +* `FUNCTIONS_AS_TOOLS_RESET_TIMEOUT` + * **Description**: Specifies the duration for which a Pulsar Function tool remains in the "circuit breaker open" state (due to exceeding the failure threshold) before the MCP server attempts to reset the circuit and allow calls again. + * **Unit**: Seconds + * **Default**: Defaults to the value specified in `pftools.DefaultManagerOptions()` (e.g., if the internal default is 60 seconds, it will be `60`). +* `FUNCTIONS_AS_TOOLS_TENANT_NAMESPACES` + * **Description**: A comma-separated list of Pulsar `tenant/namespace` strings that the MCP Server should scan for Pulsar Functions. This allows you to restrict function discovery to specific namespaces. If not set, the server might attempt to discover functions from all namespaces it has access to, as permitted by its Pulsar client configuration. + * **Format**: `tenant1/namespace1,tenant2/namespace2` + * **Example**: `public/default,my-tenant/app-functions` + * **Default**: Empty (meaning discover from all accessible namespaces (only on StreamNative Cloud)). +* `FUNCTIONS_AS_TOOLS_STRICT_EXPORT` + * **Description**: Only export functions with `MCP_TOOL_NAME` and `MCP_TOOL_DESCRIPTION` defined. + * **Format**: `true` or `false` + * **Example**: `false` + * **Default**: `true` + +## Considerations and Limitations + +* **Schema Definition**: For reliable schema conversion, ensure your Pulsar Functions have clearly defined input and output schemas using Pulsar's schema registry capabilities. Functions with ambiguous or `BYTES` schemas might not be converted effectively or might default to generic byte array inputs/outputs. +* **Function State**: This feature primarily focuses on the stateless request/response invocation pattern of functions. +* **Discovery Latency**: There might be a slight delay between deploying/updating a function and it appearing as an MCP tool, due to the server's polling interval for function discovery. +* **Error Handling**: The MCP Server will attempt to relay errors from function executions, but the specifics might vary. +* **Security**: Ensure that only intended functions are exposed by managing permissions within your Pulsar cluster. The MCP Server will operate with the permissions of its Pulsar client. diff --git a/go.mod b/go.mod index 923ebc3..899189a 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,9 @@ require ( github.com/99designs/keyring v1.2.2 github.com/apache/pulsar-client-go v0.13.1 github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/google/go-cmp v0.7.0 github.com/hamba/avro/v2 v2.28.0 - github.com/mark3labs/mcp-go v0.27.0 + github.com/mark3labs/mcp-go v0.28.0 github.com/mitchellh/go-homedir v1.1.0 github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.9.3 @@ -16,6 +17,7 @@ require ( github.com/streamnative/pulsarctl v0.4.3-0.20250312214758-e472faec284b github.com/streamnative/streamnative-mcp-server/sdk/sdk-apiserver v0.0.0-20250506174209-b67ea08ddd82 github.com/streamnative/streamnative-mcp-server/sdk/sdk-kafkaconnect v0.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.10.0 github.com/twmb/franz-go v1.18.1 github.com/twmb/franz-go/pkg/kadm v1.16.0 github.com/twmb/franz-go/pkg/sr v1.3.0 diff --git a/go.sum b/go.sum index 9570fdb..835b7fd 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,3 @@ -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= @@ -91,9 +90,8 @@ github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -131,8 +129,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mark3labs/mcp-go v0.27.0 h1:iok9kU4DUIU2/XVLgFS2Q9biIDqstC0jY4EQTK2Erzc= -github.com/mark3labs/mcp-go v0.27.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.28.0 h1:7yl4y5D1KYU2f/9Uxp7xfLIggfunHoESCRbrjcytcLM= +github.com/mark3labs/mcp-go v0.28.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= diff --git a/go.work.sum b/go.work.sum index fc10c18..22423cd 100644 --- a/go.work.sum +++ b/go.work.sum @@ -138,6 +138,8 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+ github.com/lithammer/dedent v1.1.0/go.mod h1:jrXYCQtgg0nJiN+StA2KgR7w6CiQNv9Fd/Z9BP0jIOc= github.com/mark3labs/mcp-go v0.23.1 h1:RzTzZ5kJ+HxwnutKA4rll8N/pKV6Wh5dhCmiJUu5S9I= github.com/mark3labs/mcp-go v0.23.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.28.0 h1:7yl4y5D1KYU2f/9Uxp7xfLIggfunHoESCRbrjcytcLM= +github.com/mark3labs/mcp-go v0.28.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y= github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= diff --git a/pkg/cmd/mcp/sse.go b/pkg/cmd/mcp/sse.go index f46431c..3222c61 100644 --- a/pkg/cmd/mcp/sse.go +++ b/pkg/cmd/mcp/sse.go @@ -31,6 +31,7 @@ import ( "github.com/mark3labs/mcp-go/server" "github.com/pkg/errors" "github.com/spf13/cobra" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/mcp" ) @@ -68,24 +69,29 @@ func runSseServer(configOpts *ServerOptions) error { } // 3. Create a new MCP server - ctx = context.WithValue(ctx, mcp.OptionsKey, configOpts.Options) - mcpServer := server.NewSSEServer( - newMcpServer(configOpts, logger), + ctx = context.WithValue(ctx, common.OptionsKey, configOpts.Options) + mcpServer := newMcpServer(configOpts, logger) + + // add Pulsar Functions as MCP tools + mcp.PulsarFunctionManagedMcpTools(mcpServer, false, configOpts.Features) + + sseServer := server.NewSSEServer( + mcpServer, server.WithStaticBasePath(configOpts.HTTPPath), server.WithHTTPContextFunc(func(ctx context.Context, _ *http.Request) context.Context { - return context.WithValue(ctx, mcp.OptionsKey, configOpts.Options) + return context.WithValue(ctx, common.OptionsKey, configOpts.Options) }), ) // 4. Expose the full SSE URL to the user - ssePath := mcpServer.CompleteSsePath() + ssePath := sseServer.CompleteSsePath() fmt.Fprintf(os.Stderr, "StreamNative Cloud MCP Server listening on http://%s%s\n", configOpts.HTTPAddr, ssePath) // 5. Run the HTTP listener in a goroutine errCh := make(chan error, 1) go func() { - if err := mcpServer.Start(configOpts.HTTPAddr); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := sseServer.Start(configOpts.HTTPAddr); err != nil && !errors.Is(err, http.ErrServerClosed) { errCh <- err // bubble up real crashes } }() @@ -108,7 +114,7 @@ func runSseServer(configOpts *ServerOptions) error { defer cancel() // First try to shut down the SSE server - if err := mcpServer.Shutdown(shCtx); err != nil { + if err := sseServer.Shutdown(shCtx); err != nil { if !errors.Is(err, http.ErrServerClosed) { logger.Errorf("Error shutting down SSE server: %v", err) } diff --git a/pkg/cmd/mcp/stdio.go b/pkg/cmd/mcp/stdio.go index f447599..23e3ef1 100644 --- a/pkg/cmd/mcp/stdio.go +++ b/pkg/cmd/mcp/stdio.go @@ -30,8 +30,8 @@ import ( "github.com/mark3labs/mcp-go/server" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/log" - "github.com/streamnative/streamnative-mcp-server/pkg/mcp" ) func NewCmdMcpStdioServer(configOpts *ServerOptions) *cobra.Command { @@ -63,7 +63,7 @@ func runStdioServer(configOpts *ServerOptions) error { } // Create a new MCP server - ctx = context.WithValue(ctx, mcp.OptionsKey, configOpts.Options) + ctx = context.WithValue(ctx, common.OptionsKey, configOpts.Options) stdLogger := stdlog.New(logger.Writer(), "snmcp-server", 0) stdioServer := server.NewStdioServer(newMcpServer(configOpts, logger)) diff --git a/pkg/mcp/utils.go b/pkg/common/utils.go similarity index 74% rename from pkg/mcp/utils.go rename to pkg/common/utils.go index fef2ca3..e607701 100644 --- a/pkg/mcp/utils.go +++ b/pkg/common/utils.go @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package mcp +package common import ( "context" @@ -37,12 +37,12 @@ const ( TokenRefreshWindow = 5 * time.Minute ) -// requiredParam is a helper function that can be used to fetch a requested parameter from the request. +// RequiredParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request. // 2. Checks if the parameter is of the expected type. // 3. Checks if the parameter is not empty, i.e: non-zero value -func requiredParam[T comparable](arguments map[string]interface{}, p string) (T, error) { +func RequiredParam[T comparable](arguments map[string]interface{}, p string) (T, error) { var zero T // Check if the parameter is present in the request @@ -55,16 +55,29 @@ func requiredParam[T comparable](arguments map[string]interface{}, p string) (T, return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) } - if arguments[p].(T) == zero { - return zero, fmt.Errorf("missing required parameter: %s", p) - + _, isBool := interface{}(zero).(bool) + _, isInt := interface{}(zero).(int) + _, isInt8 := interface{}(zero).(int8) + _, isInt16 := interface{}(zero).(int16) + _, isInt32 := interface{}(zero).(int32) + _, isInt64 := interface{}(zero).(int64) + _, isFloat32 := interface{}(zero).(float32) + _, isFloat64 := interface{}(zero).(float64) + _, isUint8 := interface{}(zero).(uint8) + _, isUint16 := interface{}(zero).(uint16) + _, isUint32 := interface{}(zero).(uint32) + _, isUint64 := interface{}(zero).(uint64) + if !isBool && !isInt && !isInt8 && !isInt16 && !isInt32 && !isInt64 && !isFloat32 && !isFloat64 && !isUint8 && !isUint16 && !isUint32 && !isUint64 { + if arguments[p].(T) == zero { + return zero, fmt.Errorf("missing required parameter: %s", p) + } } return arguments[p].(T), nil } // Helper function to get an optional parameter from the request -func optionalParam[T any](arguments map[string]interface{}, paramName string) (T, bool) { +func OptionalParam[T any](arguments map[string]interface{}, paramName string) (T, bool) { var empty T param, ok := arguments[paramName] if !ok { @@ -80,7 +93,7 @@ func optionalParam[T any](arguments map[string]interface{}, paramName string) (T } // Helper function to get a required array parameter from the request -func requiredParamArray[T any](arguments map[string]interface{}, paramName string) ([]T, error) { +func RequiredParamArray[T any](arguments map[string]interface{}, paramName string) ([]T, error) { var empty []T param, ok := arguments[paramName] if !ok { @@ -104,7 +117,7 @@ func requiredParamArray[T any](arguments map[string]interface{}, paramName strin return result, nil } -func optionalParamArray[T any](arguments map[string]interface{}, paramName string) ([]T, bool) { +func OptionalParamArray[T any](arguments map[string]interface{}, paramName string) ([]T, bool) { var empty []T param, ok := arguments[paramName] if !ok { @@ -128,8 +141,8 @@ func optionalParamArray[T any](arguments map[string]interface{}, paramName strin return result, true } -// optionalParamConfigs gets an optional parameter as a list of key=value strings -func optionalParamConfigs(arguments map[string]interface{}, paramName string) ([]string, bool) { +// OptionalParamConfigs gets an optional parameter as a list of key=value strings +func OptionalParamConfigs(arguments map[string]interface{}, paramName string) ([]string, bool) { param, ok := arguments[paramName] if !ok { @@ -151,8 +164,8 @@ func optionalParamConfigs(arguments map[string]interface{}, paramName string) ([ return result, true } -// requiredParamObject gets a required object parameter from the request -func requiredParamObject(arguments map[string]interface{}, name string) (map[string]interface{}, error) { +// RequiredParamObject gets a required object parameter from the request +func RequiredParamObject(arguments map[string]interface{}, name string) (map[string]interface{}, error) { // Get the parameter value paramValue, found := arguments[name] if !found || paramValue == nil { @@ -167,12 +180,12 @@ func requiredParamObject(arguments map[string]interface{}, name string) (map[str return nil, fmt.Errorf("%s parameter must be an object", name) } -func getOptions(ctx context.Context) *config.Options { +func GetOptions(ctx context.Context) *config.Options { return ctx.Value(OptionsKey).(*config.Options) } -// isClusterAvailable checks if a PulsarCluster is available (ready) -func isClusterAvailable(cluster sncloud.ComGithubStreamnativeCloudApiServerPkgApisCloudV1alpha1PulsarCluster) bool { +// IsClusterAvailable checks if a PulsarCluster is available (ready) +func IsClusterAvailable(cluster sncloud.ComGithubStreamnativeCloudApiServerPkgApisCloudV1alpha1PulsarCluster) bool { // Check if broker has ready replicas if cluster.Status.Broker == nil || cluster.Status.Broker.ReadyReplicas == nil || *cluster.Status.Broker.ReadyReplicas == 0 { return false @@ -187,8 +200,8 @@ func isClusterAvailable(cluster sncloud.ComGithubStreamnativeCloudApiServerPkgAp return false } -// getEngineType returns the Pulsar cluster is an Ursa engine or a Classic engine -func getEngineType(cluster sncloud.ComGithubStreamnativeCloudApiServerPkgApisCloudV1alpha1PulsarCluster) string { +// GetEngineType returns the Pulsar cluster is an Ursa engine or a Classic engine +func GetEngineType(cluster sncloud.ComGithubStreamnativeCloudApiServerPkgApisCloudV1alpha1PulsarCluster) string { if cluster.Metadata.Annotations != nil { if v, has := (*cluster.Metadata.Annotations)[AnnotationStreamNativeCloudEngine]; has && v == "ursa" { return "ursa" @@ -197,7 +210,7 @@ func getEngineType(cluster sncloud.ComGithubStreamnativeCloudApiServerPkgApisClo return "classic" } -func convertToMapInterface(m map[string]string) map[string]interface{} { +func ConvertToMapInterface(m map[string]string) map[string]interface{} { result := make(map[string]interface{}) for k, v := range m { result[k] = v @@ -205,7 +218,7 @@ func convertToMapInterface(m map[string]string) map[string]interface{} { return result } -func convertToMapString(m map[string]interface{}) map[string]string { +func ConvertToMapString(m map[string]interface{}) map[string]string { result := make(map[string]string) for k, v := range m { result[k] = fmt.Sprintf("%v", v) @@ -213,15 +226,15 @@ func convertToMapString(m map[string]interface{}) map[string]string { return result } -// isInstanceValid checks if PulsarInstance has valid OAuth2 authentication configuration -func isInstanceValid(instance sncloud.ComGithubStreamnativeCloudApiServerPkgApisCloudV1alpha1PulsarInstance) bool { +// IsInstanceValid checks if PulsarInstance has valid OAuth2 authentication configuration +func IsInstanceValid(instance sncloud.ComGithubStreamnativeCloudApiServerPkgApisCloudV1alpha1PulsarInstance) bool { return instance.Status != nil && (instance.Status.Auth.Type == "oauth2" || instance.Status.Auth.Type == "apikey") && instance.Status.Auth.Oauth2.IssuerURL != "" && instance.Status.Auth.Oauth2.Audience != "" } -func hasCachedValidToken(cachedGrant *auth.AuthorizationGrant) (bool, error) { +func HasCachedValidToken(cachedGrant *auth.AuthorizationGrant) (bool, error) { if cachedGrant == nil || cachedGrant.Token == nil { return false, nil } @@ -230,7 +243,7 @@ func hasCachedValidToken(cachedGrant *auth.AuthorizationGrant) (bool, error) { return cachedGrant.Token.Valid(), nil } -func isTokenAboutToExpire(cachedGrant *auth.AuthorizationGrant, window time.Duration) (bool, error) { +func IsTokenAboutToExpire(cachedGrant *auth.AuthorizationGrant, window time.Duration) (bool, error) { if cachedGrant == nil || cachedGrant.Token == nil { return true, nil } @@ -246,8 +259,8 @@ func isTokenAboutToExpire(cachedGrant *auth.AuthorizationGrant, window time.Dura return timeUntilExpiry <= window, nil } -// parseMessageConfigs parses a list of key=value strings into a map -func parseMessageConfigs(configs []string) (map[string]*string, error) { +// ParseMessageConfigs parses a list of key=value strings into a map +func ParseMessageConfigs(configs []string) (map[string]*string, error) { result := make(map[string]*string) for _, config := range configs { diff --git a/pkg/mcp/context_tools.go b/pkg/mcp/context_tools.go index bf459fa..fbf112b 100644 --- a/pkg/mcp/context_tools.go +++ b/pkg/mcp/context_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/streamnative-mcp-server/pkg/auth/store" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/config" ) @@ -61,7 +62,7 @@ func RegisterContextTools(s *server.MCPServer, features []string) { // handleWhoami handles the whoami tool request func handleWhoami(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { - options := ctx.Value(OptionsKey).(*config.Options) + options := ctx.Value(common.OptionsKey).(*config.Options) issuer := options.LoadConfigOrDie().Auth.Issuer() userName, err := options.WhoAmI(issuer.Audience) @@ -90,14 +91,14 @@ func handleWhoami(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResu // handleSetContext handles the set-context tool request func handleSetContext(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - options := ctx.Value(OptionsKey).(*config.Options) + options := ctx.Value(common.OptionsKey).(*config.Options) - instanceName, err := requiredParam[string](request.Params.Arguments, "instanceName") + instanceName, err := common.RequiredParam[string](request.Params.Arguments, "instanceName") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get instance name: %v", err)), nil } - clusterName, err := requiredParam[string](request.Params.Arguments, "clusterName") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "clusterName") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get cluster name: %v", err)), nil } diff --git a/pkg/mcp/context_utils.go b/pkg/mcp/context_utils.go index 6fcbb2a..598d19b 100644 --- a/pkg/mcp/context_utils.go +++ b/pkg/mcp/context_utils.go @@ -21,6 +21,7 @@ import ( "context" "fmt" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/config" "github.com/streamnative/streamnative-mcp-server/pkg/kafka" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" @@ -52,7 +53,7 @@ func SetContext(options *config.Options, instanceName, clusterName string) error foundInstance := false for _, i := range instances.Items { if *i.Metadata.Name == instanceName { - if isInstanceValid(i) { + if common.IsInstanceValid(i) { instance = i foundInstance = true break @@ -73,7 +74,7 @@ func SetContext(options *config.Options, instanceName, clusterName string) error foundCluster := false for _, c := range clusters.Items { if *c.Metadata.Name == clusterName && c.Spec.InstanceName == instanceName { - if isClusterAvailable(c) { + if common.IsClusterAvailable(c) { cluster = c foundCluster = true break @@ -110,13 +111,13 @@ func SetContext(options *config.Options, instanceName, clusterName string) error cachedGrant, err := options.AuthOptions.LoadGrant(tokenKey) if err == nil && cachedGrant != nil { - cacheHasValidToken, err := hasCachedValidToken(cachedGrant) + cacheHasValidToken, err := common.HasCachedValidToken(cachedGrant) if err != nil { cacheHasValidToken = false } if cacheHasValidToken { - tokenAboutToExpire, err := isTokenAboutToExpire(cachedGrant, TokenRefreshWindow) + tokenAboutToExpire, err := common.IsTokenAboutToExpire(cachedGrant, common.TokenRefreshWindow) if err != nil { tokenAboutToExpire = true } diff --git a/pkg/mcp/features.go b/pkg/mcp/features.go index f7c81ef..1a7ab8c 100644 --- a/pkg/mcp/features.go +++ b/pkg/mcp/features.go @@ -47,4 +47,5 @@ const ( FeaturePulsarAdminTopicPolicy Feature = "pulsar-admin-topic-policy" FeaturePulsarClient Feature = "pulsar-client" FeatureStreamNativeCloud Feature = "streamnative-cloud" + FeatureFunctionsAsTools Feature = "functions-as-tools" ) diff --git a/pkg/mcp/kafka_admin_connect_tools.go b/pkg/mcp/kafka_admin_connect_tools.go index 7b4c9cf..df5878d 100644 --- a/pkg/mcp/kafka_admin_connect_tools.go +++ b/pkg/mcp/kafka_admin_connect_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/kafka" ) @@ -150,12 +151,12 @@ func KafkaAdminAddKafkaConnectTools(s *server.MCPServer, readOnly bool, features func handleKafkaConnectTool(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -255,7 +256,7 @@ func handleKafkaConnectorsList(ctx context.Context, admin kafka.Connect, _ mcp.C func handleKafkaConnectorGet(ctx context.Context, admin kafka.Connect, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get a specific connector - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get connector name: %v", err)), nil } @@ -275,17 +276,17 @@ func handleKafkaConnectorGet(ctx context.Context, admin kafka.Connect, request m func handleKafkaConnectorCreate(ctx context.Context, admin kafka.Connect, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Create a new connector - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get connector name: %v", err)), nil } - configMap, err := requiredParamObject(request.Params.Arguments, "config") + configMap, err := common.RequiredParamObject(request.Params.Arguments, "config") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get config: %v", err)), nil } - config := convertToMapString(configMap) + config := common.ConvertToMapString(configMap) config["name"] = name @@ -304,17 +305,17 @@ func handleKafkaConnectorCreate(ctx context.Context, admin kafka.Connect, reques func handleKafkaConnectorUpdate(ctx context.Context, admin kafka.Connect, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Update a connector - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get connector name: %v", err)), nil } - configMap, err := requiredParamObject(request.Params.Arguments, "config") + configMap, err := common.RequiredParamObject(request.Params.Arguments, "config") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get config: %v", err)), nil } - config := convertToMapString(configMap) + config := common.ConvertToMapString(configMap) config["name"] = name @@ -333,7 +334,7 @@ func handleKafkaConnectorUpdate(ctx context.Context, admin kafka.Connect, reques func handleKafkaConnectorDelete(ctx context.Context, admin kafka.Connect, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Delete a connector - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get connector name: %v", err)), nil } @@ -348,7 +349,7 @@ func handleKafkaConnectorDelete(ctx context.Context, admin kafka.Connect, reques func handleKafkaConnectorRestart(ctx context.Context, admin kafka.Connect, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Restart a connector - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get connector name: %v", err)), nil } @@ -363,7 +364,7 @@ func handleKafkaConnectorRestart(ctx context.Context, admin kafka.Connect, reque func handleKafkaConnectorPause(ctx context.Context, admin kafka.Connect, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Pause a connector - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get connector name: %v", err)), nil } @@ -378,7 +379,7 @@ func handleKafkaConnectorPause(ctx context.Context, admin kafka.Connect, request func handleKafkaConnectorResume(ctx context.Context, admin kafka.Connect, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Resume a connector - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get connector name: %v", err)), nil } diff --git a/pkg/mcp/kafka_admin_groups_tools.go b/pkg/mcp/kafka_admin_groups_tools.go index b4ab028..363f199 100644 --- a/pkg/mcp/kafka_admin_groups_tools.go +++ b/pkg/mcp/kafka_admin_groups_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/kafka" "github.com/twmb/franz-go/pkg/kadm" ) @@ -83,7 +84,6 @@ func KafkaAdminAddGroupsTools(s *server.MCPServer, readOnly bool, features []str " topic: \"my-topic\"\n" + " partition: 0\n" + " offset: 1000\n\n" + - "This tool requires Kafka super-user permissions." kafkaGroupsTool := mcp.NewTool("kafka_admin_groups", @@ -125,12 +125,12 @@ func KafkaAdminAddGroupsTools(s *server.MCPServer, readOnly bool, features []str func handleKafkaGroupsTool(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -181,7 +181,7 @@ func handleKafkaGroupsTool(readOnly bool) func(context.Context, mcp.CallToolRequ } func handleKafkaGroupDescribe(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - groupName, err := requiredParam[string](request.Params.Arguments, "group") + groupName, err := common.RequiredParam[string](request.Params.Arguments, "group") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get group name: %v", err)), nil } @@ -210,12 +210,12 @@ func handleKafkaGroupDescribe(ctx context.Context, admin *kadm.Client, request m } func handleKafkaGroupRemoveMembers(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - groupName, err := requiredParam[string](request.Params.Arguments, "group") + groupName, err := common.RequiredParam[string](request.Params.Arguments, "group") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get group name: %v", err)), nil } - members, err := requiredParam[string](request.Params.Arguments, "members") + members, err := common.RequiredParam[string](request.Params.Arguments, "members") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get members: %v", err)), nil } @@ -248,7 +248,7 @@ func handleKafkaGroupsList(ctx context.Context, admin *kadm.Client, _ mcp.CallTo } func handleKafkaGroupOffsets(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - groupName, err := requiredParam[string](request.Params.Arguments, "group") + groupName, err := common.RequiredParam[string](request.Params.Arguments, "group") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get group name: %v", err)), nil } @@ -267,12 +267,12 @@ func handleKafkaGroupOffsets(ctx context.Context, admin *kadm.Client, request mc } func handleKafkaGroupDeleteOffset(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - groupName, err := requiredParam[string](request.Params.Arguments, "group") + groupName, err := common.RequiredParam[string](request.Params.Arguments, "group") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get group name: %v", err)), nil } - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } @@ -303,23 +303,23 @@ func handleKafkaGroupDeleteOffset(ctx context.Context, admin *kadm.Client, reque func handleKafkaGroupSetOffset(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - groupName, err := requiredParam[string](request.Params.Arguments, "group") + groupName, err := common.RequiredParam[string](request.Params.Arguments, "group") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get group name: %v", err)), nil } - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } - partition, err := requiredParam[float64](request.Params.Arguments, "partition") + partition, err := common.RequiredParam[float64](request.Params.Arguments, "partition") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get partition number: %v", err)), nil } partitionInt := int32(partition) - offset, err := requiredParam[float64](request.Params.Arguments, "offset") + offset, err := common.RequiredParam[float64](request.Params.Arguments, "offset") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get offset value: %v", err)), nil } diff --git a/pkg/mcp/kafka_admin_partitions_tools.go b/pkg/mcp/kafka_admin_partitions_tools.go index a1a3780..9448602 100644 --- a/pkg/mcp/kafka_admin_partitions_tools.go +++ b/pkg/mcp/kafka_admin_partitions_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/kafka" "github.com/twmb/franz-go/pkg/kadm" ) @@ -87,12 +88,12 @@ func KafkaAdminAddPartitionsTools(s *server.MCPServer, readOnly bool, features [ func handleKafkaPartitionsTool(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -128,12 +129,12 @@ func handleKafkaPartitionsTool(readOnly bool) func(context.Context, mcp.CallTool } func handleKafkaPartitionUpdate(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - topicName, err := requiredParam[string](request.Params.Arguments, "topic") + topicName, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } - newTotal, err := requiredParam[int](request.Params.Arguments, "new-total") + newTotal, err := common.RequiredParam[int](request.Params.Arguments, "new-total") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get new total: %v", err)), nil } diff --git a/pkg/mcp/kafka_admin_sr_tools.go b/pkg/mcp/kafka_admin_sr_tools.go index cdec21a..eab524e 100644 --- a/pkg/mcp/kafka_admin_sr_tools.go +++ b/pkg/mcp/kafka_admin_sr_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/kafka" "github.com/twmb/franz-go/pkg/sr" ) @@ -149,12 +150,12 @@ func KafkaAdminAddSchemaRegistryTools(s *server.MCPServer, readOnly bool, featur func handleKafkaSchemaRegistryTool(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -245,7 +246,7 @@ func handleKafkaSubjectsList(ctx context.Context, admin *sr.Client, _ mcp.CallTo } func handleKafkaVersionsList(ctx context.Context, admin *sr.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - subject, err := requiredParam[string](request.Params.Arguments, "subject") + subject, err := common.RequiredParam[string](request.Params.Arguments, "subject") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subject: %v", err)), nil } @@ -278,7 +279,7 @@ func handleKafkaTypesList(ctx context.Context, admin *sr.Client, _ mcp.CallToolR } func handleKafkaSubjectGet(ctx context.Context, admin *sr.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - subject, err := requiredParam[string](request.Params.Arguments, "subject") + subject, err := common.RequiredParam[string](request.Params.Arguments, "subject") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subject: %v", err)), nil } @@ -297,17 +298,17 @@ func handleKafkaSubjectGet(ctx context.Context, admin *sr.Client, request mcp.Ca } func handleKafkaSubjectCreate(ctx context.Context, admin *sr.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - subject, err := requiredParam[string](request.Params.Arguments, "subject") + subject, err := common.RequiredParam[string](request.Params.Arguments, "subject") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subject: %v", err)), nil } - schema, err := requiredParam[string](request.Params.Arguments, "schema") + schema, err := common.RequiredParam[string](request.Params.Arguments, "schema") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get schema: %v", err)), nil } - typeString, err := requiredParam[string](request.Params.Arguments, "type") + typeString, err := common.RequiredParam[string](request.Params.Arguments, "type") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get schema type: %v", err)), nil } @@ -337,12 +338,12 @@ func handleKafkaSubjectCreate(ctx context.Context, admin *sr.Client, request mcp } func handleKafkaSubjectDelete(ctx context.Context, admin *sr.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - subject, err := requiredParam[string](request.Params.Arguments, "subject") + subject, err := common.RequiredParam[string](request.Params.Arguments, "subject") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subject: %v", err)), nil } - version, err := requiredParam[int](request.Params.Arguments, "version") + version, err := common.RequiredParam[int](request.Params.Arguments, "version") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get version: %v", err)), nil } @@ -356,12 +357,12 @@ func handleKafkaSubjectDelete(ctx context.Context, admin *sr.Client, request mcp } func handleKafkaVersionGet(ctx context.Context, admin *sr.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - subject, err := requiredParam[string](request.Params.Arguments, "subject") + subject, err := common.RequiredParam[string](request.Params.Arguments, "subject") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subject: %v", err)), nil } - version, err := requiredParam[int](request.Params.Arguments, "version") + version, err := common.RequiredParam[int](request.Params.Arguments, "version") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get version: %v", err)), nil } @@ -380,7 +381,7 @@ func handleKafkaVersionGet(ctx context.Context, admin *sr.Client, request mcp.Ca } func handleKafkaCompatibilityGet(ctx context.Context, admin *sr.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - subject, err := requiredParam[string](request.Params.Arguments, "subject") + subject, err := common.RequiredParam[string](request.Params.Arguments, "subject") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subject: %v", err)), nil } @@ -396,12 +397,12 @@ func handleKafkaCompatibilityGet(ctx context.Context, admin *sr.Client, request } func handleKafkaCompatibilitySet(ctx context.Context, admin *sr.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - subject, err := requiredParam[string](request.Params.Arguments, "subject") + subject, err := common.RequiredParam[string](request.Params.Arguments, "subject") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subject: %v", err)), nil } - compatibility, err := requiredParam[string](request.Params.Arguments, "compatibility") + compatibility, err := common.RequiredParam[string](request.Params.Arguments, "compatibility") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get compatibility: %v", err)), nil } diff --git a/pkg/mcp/kafka_admin_topics_tools.go b/pkg/mcp/kafka_admin_topics_tools.go index adea483..c83be6f 100644 --- a/pkg/mcp/kafka_admin_topics_tools.go +++ b/pkg/mcp/kafka_admin_topics_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/kafka" "github.com/twmb/franz-go/pkg/kadm" ) @@ -127,12 +128,12 @@ func KafkaAdminAddTopicTools(s *server.MCPServer, readOnly bool, features []stri func handleKafkaTopicTool(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -182,7 +183,7 @@ func handleKafkaTopicTool(readOnly bool) func(context.Context, mcp.CallToolReque func handleKafkaTopicGet(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topicName, err := requiredParam[string](request.Params.Arguments, "name") + topicName, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } @@ -204,27 +205,27 @@ func handleKafkaTopicGet(ctx context.Context, admin *kadm.Client, request mcp.Ca func handleKafkaTopicCreate(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topicName, err := requiredParam[string](request.Params.Arguments, "name") + topicName, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } // Get optional parameters - partitions, ok := optionalParam[float64](request.Params.Arguments, "partitions") + partitions, ok := common.OptionalParam[float64](request.Params.Arguments, "partitions") if !ok { partitions = 1 // Default to 1 partition } - replicationFactor, ok := optionalParam[float64](request.Params.Arguments, "replication-factor") + replicationFactor, ok := common.OptionalParam[float64](request.Params.Arguments, "replication-factor") if !ok { replicationFactor = 1 // Default to replication factor 1 } // Get configs if provided var configEntries map[string]*string - configsArray, ok := optionalParamConfigs(request.Params.Arguments, "configs") + configsArray, ok := common.OptionalParamConfigs(request.Params.Arguments, "configs") if ok { - configEntries, err = parseMessageConfigs(configsArray) + configEntries, err = common.ParseMessageConfigs(configsArray) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to parse configs: %v", err)), nil } @@ -241,7 +242,7 @@ func handleKafkaTopicCreate(ctx context.Context, admin *kadm.Client, request mcp func handleKafkaTopicDelete(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topicName, err := requiredParam[string](request.Params.Arguments, "name") + topicName, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } @@ -257,7 +258,7 @@ func handleKafkaTopicDelete(ctx context.Context, admin *kadm.Client, request mcp func handleKafkaTopicsList(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - includeInternal, ok := optionalParam[bool](request.Params.Arguments, "include-internal") + includeInternal, ok := common.OptionalParam[bool](request.Params.Arguments, "include-internal") if !ok { includeInternal = false } @@ -292,7 +293,7 @@ func handleKafkaTopicsList(ctx context.Context, admin *kadm.Client, request mcp. func handleKafkaTopicMetadata(ctx context.Context, admin *kadm.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topicName, err := requiredParam[string](request.Params.Arguments, "name") + topicName, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } diff --git a/pkg/mcp/kafka_client_consume_tools.go b/pkg/mcp/kafka_client_consume_tools.go index 10f907c..10d9c46 100644 --- a/pkg/mcp/kafka_client_consume_tools.go +++ b/pkg/mcp/kafka_client_consume_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/sirupsen/logrus" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/kafka" "github.com/twmb/franz-go/pkg/kgo" "github.com/twmb/franz-go/pkg/sr" @@ -114,7 +115,7 @@ func KafkaClientAddConsumeTools(s *server.MCPServer, _ bool, logrusLogger *logru func handleKafkaConsume(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { opts := []kgo.Opt{} // Get required parameters - topicName, err := requiredParam[string](request.Params.Arguments, "topic") + topicName, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } @@ -124,17 +125,17 @@ func handleKafkaConsume(ctx context.Context, request mcp.CallToolRequest) (*mcp. opts = append(opts, kgo.KeepRetryableFetchErrors()) w := logger.Writer() opts = append(opts, kgo.WithLogger(kgo.BasicLogger(w, kgo.LogLevelInfo, nil))) - maxMessages, hasMaxMessages := optionalParam[float64](request.Params.Arguments, "max-messages") + maxMessages, hasMaxMessages := common.OptionalParam[float64](request.Params.Arguments, "max-messages") if !hasMaxMessages { maxMessages = 10 // Default to 10 messages } - timeoutSec, hasTimeout := optionalParam[float64](request.Params.Arguments, "timeout") + timeoutSec, hasTimeout := common.OptionalParam[float64](request.Params.Arguments, "timeout") if !hasTimeout { timeoutSec = 10 // Default to 10 seconds } - group, hasGroup := optionalParam[string](request.Params.Arguments, "group") + group, hasGroup := common.OptionalParam[string](request.Params.Arguments, "group") if !hasGroup { group = "" } @@ -142,7 +143,7 @@ func handleKafkaConsume(ctx context.Context, request mcp.CallToolRequest) (*mcp. opts = append(opts, kgo.ConsumerGroup(group)) } - offsetStr, hasOffset := optionalParam[string](request.Params.Arguments, "offset") + offsetStr, hasOffset := common.OptionalParam[string](request.Params.Arguments, "offset") if !hasOffset { offsetStr = "atstart" // Default to starting at the beginning } diff --git a/pkg/mcp/kafka_client_produce_tools.go b/pkg/mcp/kafka_client_produce_tools.go index 8e8a8ab..0b2b5c5 100644 --- a/pkg/mcp/kafka_client_produce_tools.go +++ b/pkg/mcp/kafka_client_produce_tools.go @@ -27,6 +27,7 @@ import ( "github.com/hamba/avro/v2" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/kafka" "github.com/twmb/franz-go/pkg/kgo" "github.com/twmb/franz-go/pkg/sr" @@ -112,23 +113,23 @@ func KafkaClientAddProduceTools(s *server.MCPServer, readOnly bool, features []s // handleKafkaProduce handles producing messages to a Kafka topic func handleKafkaProduce(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topicName, err := requiredParam[string](request.Params.Arguments, "topic") + topicName, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } // Handle single message case // Get value from parameter or file - value, err := requiredParam[string](request.Params.Arguments, "value") + value, err := common.RequiredParam[string](request.Params.Arguments, "value") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get value: %v", err)), nil } // Get optional parameters - key, hasKey := optionalParam[string](request.Params.Arguments, "key") - headers, hasHeaders := optionalParam[[]interface{}](request.Params.Arguments, "headers") + key, hasKey := common.OptionalParam[string](request.Params.Arguments, "key") + headers, hasHeaders := common.OptionalParam[[]interface{}](request.Params.Arguments, "headers") sync := true - if syncVal, hasSync := optionalParam[bool](request.Params.Arguments, "sync"); hasSync { + if syncVal, hasSync := common.OptionalParam[bool](request.Params.Arguments, "sync"); hasSync { sync = syncVal } diff --git a/pkg/mcp/prompts.go b/pkg/mcp/prompts.go index 1a12faf..28252f3 100644 --- a/pkg/mcp/prompts.go +++ b/pkg/mcp/prompts.go @@ -25,6 +25,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/config" sncloud "github.com/streamnative/streamnative-mcp-server/sdk/sdk-apiserver" "k8s.io/utils/ptr" @@ -81,7 +82,7 @@ func RegisterPrompts(s *server.MCPServer) { } func handleListPulsarClusters(ctx context.Context, _ mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { - options := getOptions(ctx) + options := common.GetOptions(ctx) apiClient, err := config.GetAPIClient() if err != nil { return nil, fmt.Errorf("failed to get API client: %v", err) @@ -118,11 +119,11 @@ func handleListPulsarClusters(ctx context.Context, _ mcp.GetPromptRequest) (*mcp } status := "Not Ready" - if isClusterAvailable(cluster) { + if common.IsClusterAvailable(cluster) { status = "Ready" } - engineType := getEngineType(cluster) + engineType := common.GetEngineType(cluster) messages[i+1] = mcp.PromptMessage{ Content: mcp.TextContent{ @@ -147,13 +148,13 @@ func handleListPulsarClusters(ctx context.Context, _ mcp.GetPromptRequest) (*mcp } func handleReadPulsarCluster(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { - options := getOptions(ctx) + options := common.GetOptions(ctx) apiClient, err := config.GetAPIClient() if err != nil { return nil, fmt.Errorf("failed to get API client: %v", err) } - name, err := requiredParam[string](convertToMapInterface(request.Params.Arguments), "name") + name, err := common.RequiredParam[string](common.ConvertToMapInterface(request.Params.Arguments), "name") if err != nil { return nil, fmt.Errorf("failed to get name: %v", err) } @@ -204,24 +205,24 @@ func handleReadPulsarCluster(ctx context.Context, request mcp.GetPromptRequest) } func handleBuildServerlessPulsarCluster(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { - options := getOptions(ctx) + options := common.GetOptions(ctx) apiClient, err := config.GetAPIClient() if err != nil { return nil, fmt.Errorf("failed to get API client: %v", err) } - arguments := convertToMapInterface(request.Params.Arguments) + arguments := common.ConvertToMapInterface(request.Params.Arguments) - instanceName, err := requiredParam[string](arguments, "instance-name") + instanceName, err := common.RequiredParam[string](arguments, "instance-name") if err != nil { return nil, fmt.Errorf("failed to get instance name: %v", err) } - clusterName, err := requiredParam[string](arguments, "cluster-name") + clusterName, err := common.RequiredParam[string](arguments, "cluster-name") if err != nil { return nil, fmt.Errorf("failed to get cluster name: %v", err) } - provider, hasProvider := optionalParam[string](arguments, "provider") + provider, hasProvider := common.OptionalParam[string](arguments, "provider") if !hasProvider { provider = "" } diff --git a/pkg/mcp/pulsar_admin_brokers_stats_tools.go b/pkg/mcp/pulsar_admin_brokers_stats_tools.go index 6b04e51..830b0ef 100644 --- a/pkg/mcp/pulsar_admin_brokers_stats_tools.go +++ b/pkg/mcp/pulsar_admin_brokers_stats_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -70,7 +71,7 @@ func handleBrokerStats(_ bool) func(context.Context, mcp.CallToolRequest) (*mcp. } // Get required resource parameter - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'resource'. " + "Please specify one of: monitoring_metrics, mbeans, topics, allocator_stats, load_report.")), nil @@ -85,7 +86,7 @@ func handleBrokerStats(_ bool) func(context.Context, mcp.CallToolRequest) (*mcp. case "topics": return handleTopics(client) case "allocator_stats": - allocatorName, err := requiredParam[string](request.Params.Arguments, "allocator_name") + allocatorName, err := common.RequiredParam[string](request.Params.Arguments, "allocator_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'allocator_name' for allocator_stats resource. " + "Please provide the name of the allocator to get statistics for.")), nil diff --git a/pkg/mcp/pulsar_admin_brokers_tools.go b/pkg/mcp/pulsar_admin_brokers_tools.go index f343f5d..388233c 100644 --- a/pkg/mcp/pulsar_admin_brokers_tools.go +++ b/pkg/mcp/pulsar_admin_brokers_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -96,13 +97,13 @@ func handleBrokerTool(readOnly bool) func(context.Context, mcp.CallToolRequest) } // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required resource parameter. " + "Please specify one of: brokers, health, config, namespaces.")), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required operation parameter. " + "Please specify one of: list, get, update, delete based on the resource type.")), nil @@ -162,7 +163,7 @@ func validateResourceOperation(resource, operation string) (bool, string) { func handleBrokersResource(client cmdutils.Client, operation string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { switch operation { case "list": - clusterName, err := requiredParam[string](request.Params.Arguments, "clusterName") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "clusterName") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'clusterName'. " + "Please provide the name of the Pulsar cluster to list brokers for.")), nil @@ -207,7 +208,7 @@ func handleHealthResource(client cmdutils.Client, operation string, _ mcp.CallTo func handleConfigResource(client cmdutils.Client, operation string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { switch operation { case "get": - configType, err := requiredParam[string](request.Params.Arguments, "configType") + configType, err := common.RequiredParam[string](request.Params.Arguments, "configType") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'configType'. " + "Please specify one of: dynamic, runtime, internal, all_dynamic.")), nil @@ -242,13 +243,13 @@ func handleConfigResource(client cmdutils.Client, operation string, request mcp. return mcp.NewToolResultText(string(resultJSON)), nil case "update": - configName, err := requiredParam[string](request.Params.Arguments, "configName") + configName, err := common.RequiredParam[string](request.Params.Arguments, "configName") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'configName'. " + "Please provide the name of the configuration parameter to update.")), nil } - configValue, err := requiredParam[string](request.Params.Arguments, "configValue") + configValue, err := common.RequiredParam[string](request.Params.Arguments, "configValue") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'configValue'. " + "Please provide the new value for the configuration parameter.")), nil @@ -264,7 +265,7 @@ func handleConfigResource(client cmdutils.Client, operation string, request mcp. configName, configValue)), nil case "delete": - configName, err := requiredParam[string](request.Params.Arguments, "configName") + configName, err := common.RequiredParam[string](request.Params.Arguments, "configName") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'configName'. " + "Please provide the name of the configuration parameter to delete.")), nil @@ -288,13 +289,13 @@ func handleConfigResource(client cmdutils.Client, operation string, request mcp. func handleNamespacesResource(client cmdutils.Client, operation string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { switch operation { case "get": - clusterName, err := requiredParam[string](request.Params.Arguments, "clusterName") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "clusterName") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'clusterName'. " + "Please provide the name of the Pulsar cluster.")), nil } - brokerURL, err := requiredParam[string](request.Params.Arguments, "brokerUrl") + brokerURL, err := common.RequiredParam[string](request.Params.Arguments, "brokerUrl") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'brokerUrl'. " + "Please provide the URL of the broker (e.g., '127.0.0.1:8080').")), nil diff --git a/pkg/mcp/pulsar_admin_cluster_tools.go b/pkg/mcp/pulsar_admin_cluster_tools.go index dfc04a9..cc89fb6 100644 --- a/pkg/mcp/pulsar_admin_cluster_tools.go +++ b/pkg/mcp/pulsar_admin_cluster_tools.go @@ -27,6 +27,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -113,13 +114,13 @@ func handleClusterTool(readOnly bool) func(context.Context, mcp.CallToolRequest) } // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required resource parameter. " + "Please specify one of: cluster, peer_clusters, failure_domain.")), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required operation parameter. " + "Please specify one of: list, get, create, update, delete based on the resource type.")), nil @@ -179,7 +180,7 @@ func handleClusterResource(client cmdutils.Client, operation string, request mcp case "list": return handleClusterList(client) case "get": - clusterName, err := requiredParam[string](request.Params.Arguments, "cluster_name") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "cluster_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'cluster_name'. " + "Please provide the name of the cluster to get information for.")), nil @@ -190,7 +191,7 @@ func handleClusterResource(client cmdutils.Client, operation string, request mcp case "update": return updateCluster(client, request) case "delete": - clusterName, err := requiredParam[string](request.Params.Arguments, "cluster_name") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "cluster_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'cluster_name'. " + "Please provide the name of the cluster to delete.")), nil @@ -204,7 +205,7 @@ func handleClusterResource(client cmdutils.Client, operation string, request mcp // Handle peer_clusters resource operations func handlePeerClustersResource(client cmdutils.Client, operation string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName, err := requiredParam[string](request.Params.Arguments, "cluster_name") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "cluster_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'cluster_name'. " + "Please provide the name of the cluster to operate on.")), nil @@ -214,7 +215,7 @@ func handlePeerClustersResource(client cmdutils.Client, operation string, reques case "get": return getPeerClusters(client, clusterName) case "update": - peerClusters, err := requiredParamArray[string](request.Params.Arguments, "peer_cluster_names") + peerClusters, err := common.RequiredParamArray[string](request.Params.Arguments, "peer_cluster_names") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'peer_cluster_names'. " + "Please provide an array of peer cluster names to set.")), nil @@ -228,7 +229,7 @@ func handlePeerClustersResource(client cmdutils.Client, operation string, reques // Handle failure_domain resource operations func handleFailureDomainResource(client cmdutils.Client, operation string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName, err := requiredParam[string](request.Params.Arguments, "cluster_name") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "cluster_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'cluster_name'. " + "Please provide the name of the cluster to operate on.")), nil @@ -238,7 +239,7 @@ func handleFailureDomainResource(client cmdutils.Client, operation string, reque case "list": return listFailureDomains(client, clusterName) case "get": - domainName, err := requiredParam[string](request.Params.Arguments, "domain_name") + domainName, err := common.RequiredParam[string](request.Params.Arguments, "domain_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'domain_name'. " + "Please provide the name of the failure domain to get.")), nil @@ -249,7 +250,7 @@ func handleFailureDomainResource(client cmdutils.Client, operation string, reque case "update": return updateFailureDomain(client, request) case "delete": - domainName, err := requiredParam[string](request.Params.Arguments, "domain_name") + domainName, err := common.RequiredParam[string](request.Params.Arguments, "domain_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'domain_name'. " + "Please provide the name of the failure domain to delete.")), nil @@ -296,7 +297,7 @@ func getClusterData(client cmdutils.Client, clusterName string) (*mcp.CallToolRe // createCluster creates a new Pulsar cluster func createCluster(client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName, err := requiredParam[string](request.Params.Arguments, "cluster_name") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "cluster_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'cluster_name'. " + "Please provide the name of the cluster to create.")), nil @@ -307,19 +308,19 @@ func createCluster(client cmdutils.Client, request mcp.CallToolRequest) (*mcp.Ca } // Set optional parameters if provided - if serviceURL, ok := optionalParam[string](request.Params.Arguments, "service_url"); ok { + if serviceURL, ok := common.OptionalParam[string](request.Params.Arguments, "service_url"); ok { clusterData.ServiceURL = serviceURL } - if serviceURLTls, ok := optionalParam[string](request.Params.Arguments, "service_url_tls"); ok { + if serviceURLTls, ok := common.OptionalParam[string](request.Params.Arguments, "service_url_tls"); ok { clusterData.ServiceURLTls = serviceURLTls } - if brokerServiceURL, ok := optionalParam[string](request.Params.Arguments, "broker_service_url"); ok { + if brokerServiceURL, ok := common.OptionalParam[string](request.Params.Arguments, "broker_service_url"); ok { clusterData.BrokerServiceURL = brokerServiceURL } - if brokerServiceURLTls, ok := optionalParam[string](request.Params.Arguments, "broker_service_url_tls"); ok { + if brokerServiceURLTls, ok := common.OptionalParam[string](request.Params.Arguments, "broker_service_url_tls"); ok { clusterData.BrokerServiceURLTls = brokerServiceURLTls } - if peerClusters, ok := optionalParamArray[string](request.Params.Arguments, "peer_cluster_names"); ok { + if peerClusters, ok := common.OptionalParamArray[string](request.Params.Arguments, "peer_cluster_names"); ok { clusterData.PeerClusterNames = peerClusters } @@ -333,7 +334,7 @@ func createCluster(client cmdutils.Client, request mcp.CallToolRequest) (*mcp.Ca // updateCluster updates an existing Pulsar cluster func updateCluster(client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName, err := requiredParam[string](request.Params.Arguments, "cluster_name") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "cluster_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'cluster_name'. " + "Please provide the name of the cluster to update.")), nil @@ -344,19 +345,19 @@ func updateCluster(client cmdutils.Client, request mcp.CallToolRequest) (*mcp.Ca } // Set optional parameters if provided - if serviceURL, ok := optionalParam[string](request.Params.Arguments, "service_url"); ok { + if serviceURL, ok := common.OptionalParam[string](request.Params.Arguments, "service_url"); ok { clusterData.ServiceURL = serviceURL } - if serviceURLTls, ok := optionalParam[string](request.Params.Arguments, "service_url_tls"); ok { + if serviceURLTls, ok := common.OptionalParam[string](request.Params.Arguments, "service_url_tls"); ok { clusterData.ServiceURLTls = serviceURLTls } - if brokerServiceURL, ok := optionalParam[string](request.Params.Arguments, "broker_service_url"); ok { + if brokerServiceURL, ok := common.OptionalParam[string](request.Params.Arguments, "broker_service_url"); ok { clusterData.BrokerServiceURL = brokerServiceURL } - if brokerServiceURLTls, ok := optionalParam[string](request.Params.Arguments, "broker_service_url_tls"); ok { + if brokerServiceURLTls, ok := common.OptionalParam[string](request.Params.Arguments, "broker_service_url_tls"); ok { clusterData.BrokerServiceURLTls = brokerServiceURLTls } - if peerClusters, ok := optionalParamArray[string](request.Params.Arguments, "peer_cluster_names"); ok { + if peerClusters, ok := common.OptionalParamArray[string](request.Params.Arguments, "peer_cluster_names"); ok { clusterData.PeerClusterNames = peerClusters } @@ -438,19 +439,19 @@ func listFailureDomains(client cmdutils.Client, clusterName string) (*mcp.CallTo // createFailureDomain creates a new failure domain in the specified cluster func createFailureDomain(client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName, err := requiredParam[string](request.Params.Arguments, "cluster_name") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "cluster_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'cluster_name'. " + "Please provide the name of the cluster.")), nil } - domainName, err := requiredParam[string](request.Params.Arguments, "domain_name") + domainName, err := common.RequiredParam[string](request.Params.Arguments, "domain_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'domain_name'. " + "Please provide the name of the failure domain to create.")), nil } - brokers, err := requiredParamArray[string](request.Params.Arguments, "brokers") + brokers, err := common.RequiredParamArray[string](request.Params.Arguments, "brokers") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'brokers'. " + "Please provide an array of broker names to include in this failure domain.")), nil @@ -472,19 +473,19 @@ func createFailureDomain(client cmdutils.Client, request mcp.CallToolRequest) (* // updateFailureDomain updates an existing failure domain in the specified cluster func updateFailureDomain(client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName, err := requiredParam[string](request.Params.Arguments, "cluster_name") + clusterName, err := common.RequiredParam[string](request.Params.Arguments, "cluster_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'cluster_name'. " + "Please provide the name of the cluster.")), nil } - domainName, err := requiredParam[string](request.Params.Arguments, "domain_name") + domainName, err := common.RequiredParam[string](request.Params.Arguments, "domain_name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'domain_name'. " + "Please provide the name of the failure domain to update.")), nil } - brokers, err := requiredParamArray[string](request.Params.Arguments, "brokers") + brokers, err := common.RequiredParamArray[string](request.Params.Arguments, "brokers") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'brokers'. " + "Please provide an array of broker names to include in this failure domain.")), nil diff --git a/pkg/mcp/pulsar_admin_functions_tools.go b/pkg/mcp/pulsar_admin_functions_tools.go index 2783c10..cd7bdab 100644 --- a/pkg/mcp/pulsar_admin_functions_tools.go +++ b/pkg/mcp/pulsar_admin_functions_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" ) // PulsarAdminAddFunctionsTools adds a unified function-related tool to the MCP server @@ -166,7 +167,7 @@ func handleFunctionsTool(readOnly bool) func(ctx context.Context, request mcp.Ca client := cmdutils.NewPulsarClientWithAPIVersion(config.V3) // Extract and validate operation parameter - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'operation': %v", err)), nil } @@ -193,12 +194,12 @@ func handleFunctionsTool(readOnly bool) func(ctx context.Context, request mcp.Ca } // Extract common parameters - tenant, err := requiredParam[string](request.Params.Arguments, "tenant") + tenant, err := common.RequiredParam[string](request.Params.Arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'tenant': %v. A tenant is required for all Pulsar Functions operations.", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'namespace': %v. A namespace is required for all Pulsar Functions operations.", err)), nil } @@ -206,7 +207,7 @@ func handleFunctionsTool(readOnly bool) func(ctx context.Context, request mcp.Ca // For all operations except 'list', name is required var name string if operation != "list" { - name, err = requiredParam[string](request.Params.Arguments, "name") + name, err = common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'name' for operation '%s': %v. The function name must be specified for this operation.", operation, err)), nil } @@ -223,7 +224,7 @@ func handleFunctionsTool(readOnly bool) func(ctx context.Context, request mcp.Ca case "stats": return handleFunctionStats(ctx, client, tenant, namespace, name) case "querystate": - key, err := requiredParam[string](request.Params.Arguments, "key") + key, err := common.RequiredParam[string](request.Params.Arguments, "key") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'key' for operation 'querystate': %v. A key is required to look up state in the function's state store.", err)), nil } @@ -241,18 +242,18 @@ func handleFunctionsTool(readOnly bool) func(ctx context.Context, request mcp.Ca case "restart": return handleFunctionRestart(ctx, client, tenant, namespace, name) case "putstate": - key, err := requiredParam[string](request.Params.Arguments, "key") + key, err := common.RequiredParam[string](request.Params.Arguments, "key") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'key' for operation 'putstate': %v. A key is required to store state in the function's state store.", err)), nil } - value, err := requiredParam[string](request.Params.Arguments, "value") + value, err := common.RequiredParam[string](request.Params.Arguments, "value") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'value' for operation 'putstate': %v. A value is required to store state in the function's state store.", err)), nil } return handleFunctionPutstate(ctx, client, tenant, namespace, name, key, value) case "trigger": - topic, _ := optionalParam[string](request.Params.Arguments, "topic") - triggerValue, _ := optionalParam[string](request.Params.Arguments, "triggerValue") + topic, _ := common.OptionalParam[string](request.Params.Arguments, "topic") + triggerValue, _ := common.OptionalParam[string](request.Params.Arguments, "triggerValue") return handleFunctionTrigger(ctx, client, tenant, namespace, name, topic, triggerValue) default: // This should never happen due to the valid operations check above @@ -354,18 +355,18 @@ func handleFunctionQuerystate(_ context.Context, client cmdutils.Client, tenant, // handleFunctionCreate handles creating a new function func handleFunctionCreate(_ context.Context, client cmdutils.Client, tenant, namespace, name string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { // Extract required parameters - classname, err := requiredParam[string](arguments, "classname") + classname, err := common.RequiredParam[string](arguments, "classname") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get classname (required for function creation): %v. The classname must specify the fully qualified class implementing the function.", err)), nil } // Extract optional parameters - inputTopics, _ := optionalParamArray[string](arguments, "inputs") - output, _ := optionalParam[string](arguments, "output") - jar, _ := optionalParam[string](arguments, "jar") - py, _ := optionalParam[string](arguments, "py") - goPath, _ := optionalParam[string](arguments, "go") - parallelismFloat, _ := optionalParam[float64](arguments, "parallelism") + inputTopics, _ := common.OptionalParamArray[string](arguments, "inputs") + output, _ := common.OptionalParam[string](arguments, "output") + jar, _ := common.OptionalParam[string](arguments, "jar") + py, _ := common.OptionalParam[string](arguments, "py") + goPath, _ := common.OptionalParam[string](arguments, "go") + parallelismFloat, _ := common.OptionalParam[float64](arguments, "parallelism") parallelism := int(parallelismFloat) // Get user config if available @@ -440,13 +441,13 @@ func handleFunctionCreate(_ context.Context, client cmdutils.Client, tenant, nam // handleFunctionUpdate handles updating an existing function func handleFunctionUpdate(_ context.Context, client cmdutils.Client, tenant, namespace, name string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { // Extract optional parameters - classname, _ := optionalParam[string](arguments, "classname") - inputTopics, _ := optionalParamArray[string](arguments, "inputs") - output, _ := optionalParam[string](arguments, "output") - jar, _ := optionalParam[string](arguments, "jar") - py, _ := optionalParam[string](arguments, "py") - goPath, _ := optionalParam[string](arguments, "go") - parallelismFloat, _ := optionalParam[float64](arguments, "parallelism") + classname, _ := common.OptionalParam[string](arguments, "classname") + inputTopics, _ := common.OptionalParamArray[string](arguments, "inputs") + output, _ := common.OptionalParam[string](arguments, "output") + jar, _ := common.OptionalParam[string](arguments, "jar") + py, _ := common.OptionalParam[string](arguments, "py") + goPath, _ := common.OptionalParam[string](arguments, "go") + parallelismFloat, _ := common.OptionalParam[float64](arguments, "parallelism") parallelism := int(parallelismFloat) // Get user config if available diff --git a/pkg/mcp/pulsar_admin_functions_worker_tools.go b/pkg/mcp/pulsar_admin_functions_worker_tools.go index e5ed9db..bf84f49 100644 --- a/pkg/mcp/pulsar_admin_functions_worker_tools.go +++ b/pkg/mcp/pulsar_admin_functions_worker_tools.go @@ -26,6 +26,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -66,7 +67,7 @@ func handleFunctionsWorkerTool(_ bool) func(context.Context, mcp.CallToolRequest admin := pulsar.AdminClient // Get required resource parameter - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'resource'. " + "Please specify one of: function_stats, monitoring_metrics, cluster, cluster_leader, function_assignments.")), nil diff --git a/pkg/mcp/pulsar_admin_namespace_policy_tools.go b/pkg/mcp/pulsar_admin_namespace_policy_tools.go index ddf37b7..1372533 100644 --- a/pkg/mcp/pulsar_admin_namespace_policy_tools.go +++ b/pkg/mcp/pulsar_admin_namespace_policy_tools.go @@ -30,6 +30,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" pulsarctlutils "github.com/streamnative/pulsarctl/pkg/ctl/utils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -331,7 +332,7 @@ func handleNamespaceGetPolicies(_ context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -359,12 +360,12 @@ func handleSetMessageTTL(_ context.Context, request mcp.CallToolRequest) (*mcp.C return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - ttlStr, err := requiredParam[string](request.Params.Arguments, "ttl") + ttlStr, err := common.RequiredParam[string](request.Params.Arguments, "ttl") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get TTL: %v", err)), nil } @@ -391,13 +392,13 @@ func handleSetRetention(_ context.Context, request mcp.CallToolRequest) (*mcp.Ca return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - timeStr, hasTime := optionalParam[string](request.Params.Arguments, "time") - sizeStr, hasSize := optionalParam[string](request.Params.Arguments, "size") + timeStr, hasTime := common.OptionalParam[string](request.Params.Arguments, "time") + sizeStr, hasSize := common.OptionalParam[string](request.Params.Arguments, "size") if !hasTime && !hasSize { return mcp.NewToolResultError("At least one of 'time' or 'size' must be specified"), nil @@ -470,17 +471,17 @@ func handleGrantPermission(_ context.Context, request mcp.CallToolRequest) (*mcp return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - role, err := requiredParam[string](request.Params.Arguments, "role") + role, err := common.RequiredParam[string](request.Params.Arguments, "role") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get role: %v", err)), nil } - actions, err := requiredParamArray[string](request.Params.Arguments, "actions") + actions, err := common.RequiredParamArray[string](request.Params.Arguments, "actions") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get actions: %v", err)), nil } @@ -524,12 +525,12 @@ func handleRevokePermission(_ context.Context, request mcp.CallToolRequest) (*mc return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - role, err := requiredParam[string](request.Params.Arguments, "role") + role, err := common.RequiredParam[string](request.Params.Arguments, "role") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get role: %v", err)), nil } @@ -556,12 +557,12 @@ func handleSetReplicationClusters(_ context.Context, request mcp.CallToolRequest return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - clusters, err := requiredParamArray[string](request.Params.Arguments, "clusters") + clusters, err := common.RequiredParamArray[string](request.Params.Arguments, "clusters") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get clusters: %v", err)), nil } @@ -589,17 +590,17 @@ func handleSetBacklogQuota(_ context.Context, request mcp.CallToolRequest) (*mcp return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - limitSizeStr, err := requiredParam[string](request.Params.Arguments, "limit-size") + limitSizeStr, err := common.RequiredParam[string](request.Params.Arguments, "limit-size") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get limit size: %v", err)), nil } - policyStr, err := requiredParam[string](request.Params.Arguments, "policy") + policyStr, err := common.RequiredParam[string](request.Params.Arguments, "policy") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get policy: %v", err)), nil } @@ -611,7 +612,7 @@ func handleSetBacklogQuota(_ context.Context, request mcp.CallToolRequest) (*mcp } // Parse time limit (optional) - limitTimeStr, hasLimitTime := optionalParam[string](request.Params.Arguments, "limit-time") + limitTimeStr, hasLimitTime := common.OptionalParam[string](request.Params.Arguments, "limit-time") var limitTime int64 = -1 // Default to -1 (infinite) if hasLimitTime && limitTimeStr != "" { limitTimeVal, err := strconv.ParseInt(limitTimeStr, 10, 64) @@ -635,7 +636,7 @@ func handleSetBacklogQuota(_ context.Context, request mcp.CallToolRequest) (*mcp } // Parse quota type (optional) - quotaTypeStr, hasQuotaType := optionalParam[string](request.Params.Arguments, "type") + quotaTypeStr, hasQuotaType := common.OptionalParam[string](request.Params.Arguments, "type") quotaType := utils.DestinationStorage // Default if hasQuotaType && quotaTypeStr != "" { parsedType, err := utils.ParseBacklogQuotaType(quotaTypeStr) @@ -663,7 +664,7 @@ func handleRemoveBacklogQuota(_ context.Context, request mcp.CallToolRequest) (* return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -685,7 +686,7 @@ func handleSetTopicAutoCreation(_ context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -697,7 +698,7 @@ func handleSetTopicAutoCreation(_ context.Context, request mcp.CallToolRequest) } // Check if disabled - disableStr, hasDisable := optionalParam[string](request.Params.Arguments, "disable") + disableStr, hasDisable := common.OptionalParam[string](request.Params.Arguments, "disable") disable := false if hasDisable && disableStr == "true" { disable = true @@ -711,7 +712,7 @@ func handleSetTopicAutoCreation(_ context.Context, request mcp.CallToolRequest) // Only set topic type and partitions if not disabled if !disable { // Get topic type (optional) - topicTypeStr, hasType := optionalParam[string](request.Params.Arguments, "type") + topicTypeStr, hasType := common.OptionalParam[string](request.Params.Arguments, "type") if hasType && topicTypeStr != "" { parsedType, err := utils.ParseTopicType(topicTypeStr) if err != nil { @@ -721,7 +722,7 @@ func handleSetTopicAutoCreation(_ context.Context, request mcp.CallToolRequest) } // Get partitions (optional) - partitionsStr, hasPartitions := optionalParam[string](request.Params.Arguments, "partitions") + partitionsStr, hasPartitions := common.OptionalParam[string](request.Params.Arguments, "partitions") if hasPartitions && partitionsStr != "" { partitions, err := strconv.Atoi(partitionsStr) if err != nil { @@ -748,7 +749,7 @@ func handleRemoveTopicAutoCreation(_ context.Context, request mcp.CallToolReques return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -776,7 +777,7 @@ func handleSetSchemaValidationEnforced(_ context.Context, request mcp.CallToolRe return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -788,7 +789,7 @@ func handleSetSchemaValidationEnforced(_ context.Context, request mcp.CallToolRe } // Check if disabled - disableStr, hasDisable := optionalParam[string](request.Params.Arguments, "disable") + disableStr, hasDisable := common.OptionalParam[string](request.Params.Arguments, "disable") disable := false if hasDisable && disableStr == "true" { disable = true @@ -815,12 +816,12 @@ func handleSetSchemaAutoUpdateStrategy(_ context.Context, request mcp.CallToolRe return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - strategyStr, err := requiredParam[string](request.Params.Arguments, "compatibility") + strategyStr, err := common.RequiredParam[string](request.Params.Arguments, "compatibility") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get compatibility strategy: %v", err)), nil } @@ -856,7 +857,7 @@ func handleSetIsAllowAutoUpdateSchema(_ context.Context, request mcp.CallToolReq return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -868,8 +869,8 @@ func handleSetIsAllowAutoUpdateSchema(_ context.Context, request mcp.CallToolReq } // Check if enabled or disabled - enableStr, hasEnable := optionalParam[string](request.Params.Arguments, "enable") - disableStr, hasDisable := optionalParam[string](request.Params.Arguments, "disable") + enableStr, hasEnable := common.OptionalParam[string](request.Params.Arguments, "enable") + disableStr, hasDisable := common.OptionalParam[string](request.Params.Arguments, "disable") if (hasEnable && enableStr == "true") && (hasDisable && disableStr == "true") { return mcp.NewToolResultError("Specify only one of 'enable' or 'disable'"), nil @@ -906,12 +907,12 @@ func handleSetOffloadThreshold(_ context.Context, request mcp.CallToolRequest) ( return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - thresholdStr, err := requiredParam[string](request.Params.Arguments, "threshold") + thresholdStr, err := common.RequiredParam[string](request.Params.Arguments, "threshold") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get threshold: %v", err)), nil } @@ -947,12 +948,12 @@ func handleSetOffloadDeletionLag(_ context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - lagStr, err := requiredParam[string](request.Params.Arguments, "lag") + lagStr, err := common.RequiredParam[string](request.Params.Arguments, "lag") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get lag: %v", err)), nil } @@ -991,7 +992,7 @@ func handleClearOffloadDeletionLag(_ context.Context, request mcp.CallToolReques return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -1021,12 +1022,12 @@ func handleSetCompactionThreshold(_ context.Context, request mcp.CallToolRequest return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - thresholdStr, err := requiredParam[string](request.Params.Arguments, "threshold") + thresholdStr, err := common.RequiredParam[string](request.Params.Arguments, "threshold") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get threshold: %v", err)), nil } @@ -1062,12 +1063,12 @@ func handleSetMaxProducersPerTopic(_ context.Context, request mcp.CallToolReques return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - maxStr, err := requiredParam[string](request.Params.Arguments, "max") + maxStr, err := common.RequiredParam[string](request.Params.Arguments, "max") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get max value: %v", err)), nil } @@ -1107,12 +1108,12 @@ func handleSetMaxConsumersPerTopic(_ context.Context, request mcp.CallToolReques return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - maxStr, err := requiredParam[string](request.Params.Arguments, "max") + maxStr, err := common.RequiredParam[string](request.Params.Arguments, "max") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get max value: %v", err)), nil } @@ -1152,12 +1153,12 @@ func handleSetMaxConsumersPerSubscription(_ context.Context, request mcp.CallToo return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - maxStr, err := requiredParam[string](request.Params.Arguments, "max") + maxStr, err := common.RequiredParam[string](request.Params.Arguments, "max") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get max value: %v", err)), nil } @@ -1197,12 +1198,12 @@ func handleSetAntiAffinityGroup(_ context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - group, err := requiredParam[string](request.Params.Arguments, "group") + group, err := common.RequiredParam[string](request.Params.Arguments, "group") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get anti-affinity group: %v", err)), nil } @@ -1226,7 +1227,7 @@ func handleDeleteAntiAffinityGroup(_ context.Context, request mcp.CallToolReques return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -1250,22 +1251,22 @@ func handleSetPersistence(_ context.Context, request mcp.CallToolRequest) (*mcp. return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - ensembleSizeStr, err := requiredParam[string](request.Params.Arguments, "ensemble-size") + ensembleSizeStr, err := common.RequiredParam[string](request.Params.Arguments, "ensemble-size") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get ensemble size: %v", err)), nil } - writeQuorumSizeStr, err := requiredParam[string](request.Params.Arguments, "write-quorum-size") + writeQuorumSizeStr, err := common.RequiredParam[string](request.Params.Arguments, "write-quorum-size") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get write quorum size: %v", err)), nil } - ackQuorumSizeStr, err := requiredParam[string](request.Params.Arguments, "ack-quorum-size") + ackQuorumSizeStr, err := common.RequiredParam[string](request.Params.Arguments, "ack-quorum-size") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get ack quorum size: %v", err)), nil } @@ -1293,7 +1294,7 @@ func handleSetPersistence(_ context.Context, request mcp.CallToolRequest) (*mcp. // Parse optional rate parameter markDeleteMaxRate := 0.0 - markDeleteMaxRateStr, hasRate := optionalParam[string](request.Params.Arguments, "ml-mark-delete-max-rate") + markDeleteMaxRateStr, hasRate := common.OptionalParam[string](request.Params.Arguments, "ml-mark-delete-max-rate") if hasRate && markDeleteMaxRateStr != "" { rate, err := strconv.ParseFloat(markDeleteMaxRateStr, 64) if err != nil { @@ -1324,13 +1325,13 @@ func handleSetDeduplication(_ context.Context, request mcp.CallToolRequest) (*mc return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get enable flag, default to false if not specified - enable, _ := optionalParam[string](request.Params.Arguments, "enable") + enable, _ := common.OptionalParam[string](request.Params.Arguments, "enable") enableBool := false if enable == "true" { enableBool = true @@ -1355,13 +1356,13 @@ func handleSetEncryptionRequired(_ context.Context, request mcp.CallToolRequest) return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get disable flag, default to false if not specified - disable, _ := optionalParam[string](request.Params.Arguments, "disable") + disable, _ := common.OptionalParam[string](request.Params.Arguments, "disable") disableFlag := disable == "true" // Get namespace name @@ -1394,12 +1395,12 @@ func handleSetSubscriptionAuthMode(_ context.Context, request mcp.CallToolReques return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - mode, err := requiredParam[string](request.Params.Arguments, "mode") + mode, err := common.RequiredParam[string](request.Params.Arguments, "mode") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subscription auth mode: %v", err)), nil } @@ -1436,17 +1437,17 @@ func handleGrantSubscriptionPermission(_ context.Context, request mcp.CallToolRe return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - subscription, err := requiredParam[string](request.Params.Arguments, "subscription") + subscription, err := common.RequiredParam[string](request.Params.Arguments, "subscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subscription name: %v", err)), nil } - roles, err := requiredParamArray[string](request.Params.Arguments, "roles") + roles, err := common.RequiredParamArray[string](request.Params.Arguments, "roles") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get roles: %v", err)), nil } @@ -1481,17 +1482,17 @@ func handleRevokeSubscriptionPermission(_ context.Context, request mcp.CallToolR return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - subscription, err := requiredParam[string](request.Params.Arguments, "subscription") + subscription, err := common.RequiredParam[string](request.Params.Arguments, "subscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subscription name: %v", err)), nil } - role, err := requiredParam[string](request.Params.Arguments, "role") + role, err := common.RequiredParam[string](request.Params.Arguments, "role") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get role: %v", err)), nil } @@ -1522,13 +1523,13 @@ func handleSetDispatchRate(_ context.Context, request mcp.CallToolRequest) (*mcp return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get rate parameters - msgRateStr, _ := optionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInMsg") + msgRateStr, _ := common.OptionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInMsg") msgRate := -1 // Default value if msgRateStr != "" { parsedMsgRate, err := strconv.Atoi(msgRateStr) @@ -1538,7 +1539,7 @@ func handleSetDispatchRate(_ context.Context, request mcp.CallToolRequest) (*mcp msgRate = parsedMsgRate } - byteRateStr, _ := optionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInByte") + byteRateStr, _ := common.OptionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInByte") byteRate := int64(-1) // Default value if byteRateStr != "" { parsedByteRate, err := strconv.ParseInt(byteRateStr, 10, 64) @@ -1548,7 +1549,7 @@ func handleSetDispatchRate(_ context.Context, request mcp.CallToolRequest) (*mcp byteRate = parsedByteRate } - ratePeriodStr, _ := optionalParam[string](request.Params.Arguments, "ratePeriodInSecond") + ratePeriodStr, _ := common.OptionalParam[string](request.Params.Arguments, "ratePeriodInSecond") ratePeriod := 1 // Default value if ratePeriodStr != "" { parsedRatePeriod, err := strconv.Atoi(ratePeriodStr) @@ -1590,13 +1591,13 @@ func handleSetReplicatorDispatchRate(_ context.Context, request mcp.CallToolRequ return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get rate parameters - msgRateStr, _ := optionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInMsg") + msgRateStr, _ := common.OptionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInMsg") msgRate := -1 // Default value if msgRateStr != "" { parsedMsgRate, err := strconv.Atoi(msgRateStr) @@ -1606,7 +1607,7 @@ func handleSetReplicatorDispatchRate(_ context.Context, request mcp.CallToolRequ msgRate = parsedMsgRate } - byteRateStr, _ := optionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInByte") + byteRateStr, _ := common.OptionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInByte") byteRate := int64(-1) // Default value if byteRateStr != "" { parsedByteRate, err := strconv.ParseInt(byteRateStr, 10, 64) @@ -1616,7 +1617,7 @@ func handleSetReplicatorDispatchRate(_ context.Context, request mcp.CallToolRequ byteRate = parsedByteRate } - ratePeriodStr, _ := optionalParam[string](request.Params.Arguments, "ratePeriodInSecond") + ratePeriodStr, _ := common.OptionalParam[string](request.Params.Arguments, "ratePeriodInSecond") ratePeriod := 1 // Default value if ratePeriodStr != "" { parsedRatePeriod, err := strconv.Atoi(ratePeriodStr) @@ -1658,13 +1659,13 @@ func handleSetSubscribeRate(_ context.Context, request mcp.CallToolRequest) (*mc return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get rate parameters - subRateStr, _ := optionalParam[string](request.Params.Arguments, "subscribeThrottlingRatePerConsumer") + subRateStr, _ := common.OptionalParam[string](request.Params.Arguments, "subscribeThrottlingRatePerConsumer") subRate := -1 // Default value if subRateStr != "" { parsedSubRate, err := strconv.Atoi(subRateStr) @@ -1674,7 +1675,7 @@ func handleSetSubscribeRate(_ context.Context, request mcp.CallToolRequest) (*mc subRate = parsedSubRate } - periodStr, _ := optionalParam[string](request.Params.Arguments, "ratePeriodInSecond") + periodStr, _ := common.OptionalParam[string](request.Params.Arguments, "ratePeriodInSecond") period := 30 // Default value if periodStr != "" { parsedPeriod, err := strconv.Atoi(periodStr) @@ -1713,13 +1714,13 @@ func handleSetSubscriptionDispatchRate(_ context.Context, request mcp.CallToolRe return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get rate parameters - msgRateStr, _ := optionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInMsg") + msgRateStr, _ := common.OptionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInMsg") msgRate := -1 // Default value if msgRateStr != "" { parsedMsgRate, err := strconv.Atoi(msgRateStr) @@ -1729,7 +1730,7 @@ func handleSetSubscriptionDispatchRate(_ context.Context, request mcp.CallToolRe msgRate = parsedMsgRate } - byteRateStr, _ := optionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInByte") + byteRateStr, _ := common.OptionalParam[string](request.Params.Arguments, "dispatchThrottlingRateInByte") byteRate := int64(-1) // Default value if byteRateStr != "" { parsedByteRate, err := strconv.ParseInt(byteRateStr, 10, 64) @@ -1739,7 +1740,7 @@ func handleSetSubscriptionDispatchRate(_ context.Context, request mcp.CallToolRe byteRate = parsedByteRate } - periodStr, _ := optionalParam[string](request.Params.Arguments, "ratePeriodInSecond") + periodStr, _ := common.OptionalParam[string](request.Params.Arguments, "ratePeriodInSecond") period := 1 // Default value if periodStr != "" { parsedPeriod, err := strconv.Atoi(periodStr) @@ -1781,7 +1782,7 @@ func handleSetPublishRate(_ context.Context, request mcp.CallToolRequest) (*mcp. return mcp.NewToolResultError(fmt.Sprintf("Failed to get admin client: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -1799,7 +1800,7 @@ func handleSetPublishRate(_ context.Context, request mcp.CallToolRequest) (*mcp. } // Get message rate if provided - msgRateStr, hasMsgRate := optionalParam[string](request.Params.Arguments, "publishThrottlingRateInMsg") + msgRateStr, hasMsgRate := common.OptionalParam[string](request.Params.Arguments, "publishThrottlingRateInMsg") if hasMsgRate && msgRateStr != "" { msgRate, err := strconv.Atoi(msgRateStr) if err != nil { @@ -1809,7 +1810,7 @@ func handleSetPublishRate(_ context.Context, request mcp.CallToolRequest) (*mcp. } // Get byte rate if provided - byteRateStr, hasByteRate := optionalParam[string](request.Params.Arguments, "publishThrottlingRateInByte") + byteRateStr, hasByteRate := common.OptionalParam[string](request.Params.Arguments, "publishThrottlingRateInByte") if hasByteRate && byteRateStr != "" { byteRate, err := strconv.ParseInt(byteRateStr, 10, 64) if err != nil { @@ -1831,12 +1832,12 @@ func handleSetPublishRate(_ context.Context, request mcp.CallToolRequest) (*mcp. // handleNamespaceSetPolicy handles setting policies for a namespace using the unified tool func handleNamespaceSetPolicy(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - _, err := requiredParam[string](request.Params.Arguments, "namespace") + _, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - policyType, err := requiredParam[string](request.Params.Arguments, "policy") + policyType, err := common.RequiredParam[string](request.Params.Arguments, "policy") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get policy type: %v", err)), nil } @@ -1902,12 +1903,12 @@ func handleNamespaceSetPolicy(ctx context.Context, request mcp.CallToolRequest) // handleNamespaceRemovePolicy handles removing policies from a namespace using the unified tool func handleNamespaceRemovePolicy(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - _, err := requiredParam[string](request.Params.Arguments, "namespace") + _, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - policyType, err := requiredParam[string](request.Params.Arguments, "policy") + policyType, err := common.RequiredParam[string](request.Params.Arguments, "policy") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get policy type: %v", err)), nil } diff --git a/pkg/mcp/pulsar_admin_namespace_tools.go b/pkg/mcp/pulsar_admin_namespace_tools.go index 65e1542..4e35c9d 100644 --- a/pkg/mcp/pulsar_admin_namespace_tools.go +++ b/pkg/mcp/pulsar_admin_namespace_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -93,7 +94,7 @@ func PulsarAdminAddNamespaceTools(s *server.MCPServer, readOnly bool, features [ func handleNamespace(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get operation parameter - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -142,7 +143,7 @@ func handleNamespace(readOnly bool) func(context.Context, mcp.CallToolRequest) ( // handleNamespaceList handles listing namespaces for a tenant func handleNamespaceList(_ context.Context, client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](request.Params.Arguments, "tenant") + tenant, err := common.RequiredParam[string](request.Params.Arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get tenant name: %v", err)), nil } @@ -164,7 +165,7 @@ func handleNamespaceList(_ context.Context, client cmdutils.Client, request mcp. // handleNamespaceGetTopics handles getting topics for a namespace func handleNamespaceGetTopics(_ context.Context, client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -186,13 +187,13 @@ func handleNamespaceGetTopics(_ context.Context, client cmdutils.Client, request // handleNamespaceCreate handles creating a new namespace func handleNamespaceCreate(_ context.Context, client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get optional parameters - bundlesStr, hasBundles := optionalParam[string](request.Params.Arguments, "bundles") + bundlesStr, hasBundles := common.OptionalParam[string](request.Params.Arguments, "bundles") bundles := 0 if hasBundles && bundlesStr != "" { bundlesInt, err := strconv.Atoi(bundlesStr) @@ -202,7 +203,7 @@ func handleNamespaceCreate(_ context.Context, client cmdutils.Client, request mc bundles = bundlesInt } - clusters, _ := optionalParamArray[string](request.Params.Arguments, "clusters") + clusters, _ := common.OptionalParamArray[string](request.Params.Arguments, "clusters") // Prepare policies policies := utils.NewDefaultPolicies() @@ -238,7 +239,7 @@ func handleNamespaceCreate(_ context.Context, client cmdutils.Client, request mc // handleNamespaceDelete handles deleting a namespace func handleNamespaceDelete(_ context.Context, client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } @@ -254,15 +255,15 @@ func handleNamespaceDelete(_ context.Context, client cmdutils.Client, request mc // handleClearBacklog handles clearing the backlog for all topics in a namespace func handleClearBacklog(_ context.Context, client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get optional parameters - subscription, _ := optionalParam[string](request.Params.Arguments, "subscription") - bundle, _ := optionalParam[string](request.Params.Arguments, "bundle") - force, _ := optionalParam[string](request.Params.Arguments, "force") + subscription, _ := common.OptionalParam[string](request.Params.Arguments, "subscription") + bundle, _ := common.OptionalParam[string](request.Params.Arguments, "bundle") + force, _ := common.OptionalParam[string](request.Params.Arguments, "force") forceFlag := force == "true" // If not forced, return an error requiring explicit force flag @@ -304,18 +305,18 @@ func handleClearBacklog(_ context.Context, client cmdutils.Client, request mcp.C // handleUnsubscribe handles unsubscribing the specified subscription for all topics of a namespace func handleUnsubscribe(_ context.Context, client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - subscription, err := requiredParam[string](request.Params.Arguments, "subscription") + subscription, err := common.RequiredParam[string](request.Params.Arguments, "subscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subscription name: %v", err)), nil } // Get optional bundle - bundle, _ := optionalParam[string](request.Params.Arguments, "bundle") + bundle, _ := common.OptionalParam[string](request.Params.Arguments, "bundle") // Get namespace name ns, err := utils.GetNamespaceName(namespace) @@ -350,13 +351,13 @@ func handleUnsubscribe(_ context.Context, client cmdutils.Client, request mcp.Ca // handleUnload handles unloading a namespace from the current serving broker func handleUnload(_ context.Context, client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } // Get optional bundle - bundle, _ := optionalParam[string](request.Params.Arguments, "bundle") + bundle, _ := common.OptionalParam[string](request.Params.Arguments, "bundle") // Unload namespace var unloadErr error @@ -383,18 +384,18 @@ func handleUnload(_ context.Context, client cmdutils.Client, request mcp.CallToo // handleSplitBundle handles splitting a namespace bundle func handleSplitBundle(_ context.Context, client cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace name: %v", err)), nil } - bundle, err := requiredParam[string](request.Params.Arguments, "bundle") + bundle, err := common.RequiredParam[string](request.Params.Arguments, "bundle") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get bundle: %v", err)), nil } // Get optional unload flag - unloadStr, _ := optionalParam[string](request.Params.Arguments, "unload") + unloadStr, _ := common.OptionalParam[string](request.Params.Arguments, "unload") unload := unloadStr == "true" // Split namespace bundle diff --git a/pkg/mcp/pulsar_admin_nsisolationpolicy_tools.go b/pkg/mcp/pulsar_admin_nsisolationpolicy_tools.go index 191fb83..3da9cff 100644 --- a/pkg/mcp/pulsar_admin_nsisolationpolicy_tools.go +++ b/pkg/mcp/pulsar_admin_nsisolationpolicy_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -115,17 +116,17 @@ func handleNsIsolationPolicy(readOnly bool) func(context.Context, mcp.CallToolRe } // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } - cluster, err := requiredParam[string](request.Params.Arguments, "cluster") + cluster, err := common.RequiredParam[string](request.Params.Arguments, "cluster") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get cluster name: %v", err)), nil } @@ -157,7 +158,7 @@ func handleNsIsolationPolicy(readOnly bool) func(context.Context, mcp.CallToolRe func handlePolicyResource(client cmdutils.Client, operation, cluster string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { switch operation { case "get": - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'name' for policy.get: %v", err)), nil } @@ -192,7 +193,7 @@ func handlePolicyResource(client cmdutils.Client, operation, cluster string, req return mcp.NewToolResultText(string(policiesJSON)), nil case "delete": - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'name' for policy.delete: %v", err)), nil } @@ -206,7 +207,7 @@ func handlePolicyResource(client cmdutils.Client, operation, cluster string, req return mcp.NewToolResultText(fmt.Sprintf("Delete namespace isolation policy %s successfully", name)), nil case "set": - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'name' for policy.set: %v", err)), nil } @@ -221,11 +222,11 @@ func handlePolicyResource(client cmdutils.Client, operation, cluster string, req return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'primary' for policy.set: %v", err)), nil } - secondary, _ := optionalParamArray[string](request.Params.Arguments, "secondary") - autoFailoverPolicyType, _ := optionalParam[string](request.Params.Arguments, "autoFailoverPolicyType") + secondary, _ := common.OptionalParamArray[string](request.Params.Arguments, "secondary") + autoFailoverPolicyType, _ := common.OptionalParam[string](request.Params.Arguments, "autoFailoverPolicyType") // Parse autoFailoverPolicyParams as a map - autoFailoverPolicyParamsRaw, _ := optionalParam[map[string]interface{}](request.Params.Arguments, "autoFailoverPolicyParams") + autoFailoverPolicyParamsRaw, _ := common.OptionalParam[map[string]interface{}](request.Params.Arguments, "autoFailoverPolicyParams") autoFailoverPolicyParams := make(map[string]string) for k, v := range autoFailoverPolicyParamsRaw { if strVal, ok := v.(string); ok { @@ -257,7 +258,7 @@ func handlePolicyResource(client cmdutils.Client, operation, cluster string, req func handleBrokerResource(client cmdutils.Client, operation, cluster string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { switch operation { case "get": - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'name' for broker.get: %v", err)), nil } diff --git a/pkg/mcp/pulsar_admin_packages_tools.go b/pkg/mcp/pulsar_admin_packages_tools.go index 7c33f54..0ded0df 100644 --- a/pkg/mcp/pulsar_admin_packages_tools.go +++ b/pkg/mcp/pulsar_admin_packages_tools.go @@ -27,6 +27,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -112,12 +113,12 @@ func PulsarAdminAddPackagesTools(s *server.MCPServer, readOnly bool, features [] func handlePackageTool(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -151,7 +152,7 @@ func handlePackageTool(readOnly bool) func(context.Context, mcp.CallToolRequest) // handlePackageResource handles operations on a specific package func handlePackageResource(client cmdutils.Client, operation string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - packageName, err := requiredParam[string](request.Params.Arguments, "packageName") + packageName, err := common.RequiredParam[string](request.Params.Arguments, "packageName") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'packageName' for package operations: %v", err)), nil } @@ -188,12 +189,12 @@ func handlePackageResource(client cmdutils.Client, operation string, request mcp return mcp.NewToolResultText(string(metadataJSON)), nil case "update": - description, err := requiredParam[string](request.Params.Arguments, "description") + description, err := common.RequiredParam[string](request.Params.Arguments, "description") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'description' for package.update: %v", err)), nil } - contact, _ := optionalParam[string](request.Params.Arguments, "contact") + contact, _ := common.OptionalParam[string](request.Params.Arguments, "contact") properties := extractProperties(request.Params.Arguments) // Update package metadata @@ -214,7 +215,7 @@ func handlePackageResource(client cmdutils.Client, operation string, request mcp return mcp.NewToolResultText(fmt.Sprintf("The package '%s' deleted successfully", packageName)), nil case "download": - path, err := requiredParam[string](request.Params.Arguments, "path") + path, err := common.RequiredParam[string](request.Params.Arguments, "path") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'path' for package.download: %v", err)), nil } @@ -230,17 +231,17 @@ func handlePackageResource(client cmdutils.Client, operation string, request mcp ), nil case "upload": - path, err := requiredParam[string](request.Params.Arguments, "path") + path, err := common.RequiredParam[string](request.Params.Arguments, "path") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'path' for package.upload: %v", err)), nil } - description, err := requiredParam[string](request.Params.Arguments, "description") + description, err := common.RequiredParam[string](request.Params.Arguments, "description") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'description' for package.upload: %v", err)), nil } - contact, _ := optionalParam[string](request.Params.Arguments, "contact") + contact, _ := common.OptionalParam[string](request.Params.Arguments, "contact") properties := extractProperties(request.Params.Arguments) // Upload package @@ -262,12 +263,12 @@ func handlePackageResource(client cmdutils.Client, operation string, request mcp func handlePackagesResource(client cmdutils.Client, operation string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { switch operation { case "list": - packageType, err := requiredParam[string](request.Params.Arguments, "type") + packageType, err := common.RequiredParam[string](request.Params.Arguments, "type") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'type' for packages.list: %v", err)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'namespace' for packages.list: %v", err)), nil } diff --git a/pkg/mcp/pulsar_admin_resourcequotas_tools.go b/pkg/mcp/pulsar_admin_resourcequotas_tools.go index 896cf2d..059f63c 100644 --- a/pkg/mcp/pulsar_admin_resourcequotas_tools.go +++ b/pkg/mcp/pulsar_admin_resourcequotas_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -100,12 +101,12 @@ func PulsarAdminAddResourceQuotasTools(s *server.MCPServer, readOnly bool, featu func handleResourceQuotaTool(readOnly bool) func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -147,8 +148,8 @@ func handleResourceQuotaTool(readOnly bool) func(_ context.Context, request mcp. // handleQuotaGet handles getting a resource quota func handleQuotaGet(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get optional parameters - namespace, hasNamespace := optionalParam[string](request.Params.Arguments, "namespace") - bundle, hasBundle := optionalParam[string](request.Params.Arguments, "bundle") + namespace, hasNamespace := common.OptionalParam[string](request.Params.Arguments, "namespace") + bundle, hasBundle := common.OptionalParam[string](request.Params.Arguments, "bundle") // Check if both namespace and bundle are provided or neither is provided if (hasNamespace && !hasBundle) || (!hasNamespace && hasBundle) { @@ -191,35 +192,35 @@ func handleQuotaGet(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.Ca // handleQuotaSet handles setting a resource quota func handleQuotaSet(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters for set operation - msgRateIn, err := requiredParam[float64](request.Params.Arguments, "msgRateIn") + msgRateIn, err := common.RequiredParam[float64](request.Params.Arguments, "msgRateIn") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'msgRateIn' for quota.set: %v", err)), nil } - msgRateOut, err := requiredParam[float64](request.Params.Arguments, "msgRateOut") + msgRateOut, err := common.RequiredParam[float64](request.Params.Arguments, "msgRateOut") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'msgRateOut' for quota.set: %v", err)), nil } - bandwidthIn, err := requiredParam[float64](request.Params.Arguments, "bandwidthIn") + bandwidthIn, err := common.RequiredParam[float64](request.Params.Arguments, "bandwidthIn") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'bandwidthIn' for quota.set: %v", err)), nil } - bandwidthOut, err := requiredParam[float64](request.Params.Arguments, "bandwidthOut") + bandwidthOut, err := common.RequiredParam[float64](request.Params.Arguments, "bandwidthOut") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'bandwidthOut' for quota.set: %v", err)), nil } - memory, err := requiredParam[float64](request.Params.Arguments, "memory") + memory, err := common.RequiredParam[float64](request.Params.Arguments, "memory") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'memory' for quota.set: %v", err)), nil } // Get optional parameters - namespace, hasNamespace := optionalParam[string](request.Params.Arguments, "namespace") - bundle, hasBundle := optionalParam[string](request.Params.Arguments, "bundle") - dynamic, hasDynamic := optionalParam[bool](request.Params.Arguments, "dynamic") + namespace, hasNamespace := common.OptionalParam[string](request.Params.Arguments, "namespace") + bundle, hasBundle := common.OptionalParam[string](request.Params.Arguments, "bundle") + dynamic, hasDynamic := common.OptionalParam[bool](request.Params.Arguments, "dynamic") if !hasDynamic { dynamic = false @@ -263,12 +264,12 @@ func handleQuotaSet(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.Ca // handleQuotaReset handles resetting a resource quota func handleQuotaReset(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters for reset operation - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'namespace' for quota.reset: %v", err)), nil } - bundle, err := requiredParam[string](request.Params.Arguments, "bundle") + bundle, err := common.RequiredParam[string](request.Params.Arguments, "bundle") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'bundle' for quota.reset: %v", err)), nil } diff --git a/pkg/mcp/pulsar_admin_schemas_tools.go b/pkg/mcp/pulsar_admin_schemas_tools.go index e986e7c..eece294 100644 --- a/pkg/mcp/pulsar_admin_schemas_tools.go +++ b/pkg/mcp/pulsar_admin_schemas_tools.go @@ -30,6 +30,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -86,17 +87,17 @@ func PulsarAdminAddSchemasTools(s *server.MCPServer, readOnly bool, features []s func handleSchemaTool(readOnly bool) func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic'. Please provide the fully qualified topic name: %v", err)), nil } @@ -138,7 +139,7 @@ func handleSchemaTool(readOnly bool) func(_ context.Context, request mcp.CallToo // handleSchemaGet handles getting a schema func handleSchemaGet(admin cmdutils.Client, topic string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get optional version parameter - version, hasVersion := optionalParam[float64](request.Params.Arguments, "version") + version, hasVersion := common.OptionalParam[float64](request.Params.Arguments, "version") // Get schema info if hasVersion { @@ -194,7 +195,7 @@ func handleSchemaGet(admin cmdutils.Client, topic string, request mcp.CallToolRe // handleSchemaUpload handles uploading a schema func handleSchemaUpload(admin cmdutils.Client, topic string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - filename, err := requiredParam[string](request.Params.Arguments, "filename") + filename, err := common.RequiredParam[string](request.Params.Arguments, "filename") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'filename' for schema.upload. Please provide the path to the schema definition file: %v", err)), nil } diff --git a/pkg/mcp/pulsar_admin_sinks_tools.go b/pkg/mcp/pulsar_admin_sinks_tools.go index d03fe63..a098716 100644 --- a/pkg/mcp/pulsar_admin_sinks_tools.go +++ b/pkg/mcp/pulsar_admin_sinks_tools.go @@ -29,6 +29,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" ) // PulsarAdminAddSinksTools adds a unified sink-related tool to the MCP server @@ -135,7 +136,7 @@ func PulsarAdminAddSinksTools(s *server.MCPServer, readOnly bool, features []str func handleSinksTool(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract and validate operation parameter - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'operation': %v", err)), nil } @@ -169,12 +170,12 @@ func handleSinksTool(readOnly bool) func(context.Context, mcp.CallToolRequest) ( } // Extract common parameters (all operations except list-built-in require tenant and namespace) - tenant, err := requiredParam[string](request.Params.Arguments, "tenant") + tenant, err := common.RequiredParam[string](request.Params.Arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'tenant': %v. A tenant is required for operation '%s'.", err, operation)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'namespace': %v. A namespace is required for operation '%s'.", err, operation)), nil } @@ -182,7 +183,7 @@ func handleSinksTool(readOnly bool) func(context.Context, mcp.CallToolRequest) ( // For all operations except 'list', name is required var name string if operation != "list" { - name, err = requiredParam[string](request.Params.Arguments, "name") + name, err = common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'name' for operation '%s': %v. The sink name must be specified for this operation.", operation, err)), nil } @@ -268,17 +269,17 @@ func handleSinkStatus(_ context.Context, admin cmdutils.Client, tenant, namespac // handleSinkCreate handles creating a new sink func handleSinkCreate(_ context.Context, admin cmdutils.Client, arguments map[string]interface{}) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](arguments, "tenant") + tenant, err := common.RequiredParam[string](arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get tenant: %v", err)), nil } - namespace, err := requiredParam[string](arguments, "namespace") + namespace, err := common.RequiredParam[string](arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace: %v", err)), nil } - name, err := requiredParam[string](arguments, "name") + name, err := common.RequiredParam[string](arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get name: %v", err)), nil } @@ -292,32 +293,32 @@ func handleSinkCreate(_ context.Context, admin cmdutils.Client, arguments map[st } // Get optional parameters - archive, hasArchive := optionalParam[string](arguments, "archive") + archive, hasArchive := common.OptionalParam[string](arguments, "archive") if hasArchive && archive != "" { sinkData.Archive = archive } - sinkType, hasSinkType := optionalParam[string](arguments, "sink-type") + sinkType, hasSinkType := common.OptionalParam[string](arguments, "sink-type") if hasSinkType && sinkType != "" { sinkData.SinkType = sinkType } - inputsArray, hasInputs := optionalParamArray[string](arguments, "inputs") + inputsArray, hasInputs := common.OptionalParamArray[string](arguments, "inputs") if hasInputs && len(inputsArray) > 0 { sinkData.Inputs = strings.Join(inputsArray, ",") } - topicsPattern, hasTopicsPattern := optionalParam[string](arguments, "topics-pattern") + topicsPattern, hasTopicsPattern := common.OptionalParam[string](arguments, "topics-pattern") if hasTopicsPattern && topicsPattern != "" { sinkData.TopicsPattern = topicsPattern } - subsName, hasSubsName := optionalParam[string](arguments, "subs-name") + subsName, hasSubsName := common.OptionalParam[string](arguments, "subs-name") if hasSubsName && subsName != "" { sinkData.SubsName = subsName } - parallelismFloat, hasParallelism := optionalParam[float64](arguments, "parallelism") + parallelismFloat, hasParallelism := common.OptionalParam[float64](arguments, "parallelism") if hasParallelism { sinkData.Parallelism = int(parallelismFloat) } @@ -374,17 +375,17 @@ func handleSinkCreate(_ context.Context, admin cmdutils.Client, arguments map[st // handleSinkUpdate handles updating an existing sink func handleSinkUpdate(_ context.Context, admin cmdutils.Client, arguments map[string]interface{}) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](arguments, "tenant") + tenant, err := common.RequiredParam[string](arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get tenant: %v", err)), nil } - namespace, err := requiredParam[string](arguments, "namespace") + namespace, err := common.RequiredParam[string](arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace: %v", err)), nil } - name, err := requiredParam[string](arguments, "name") + name, err := common.RequiredParam[string](arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get name: %v", err)), nil } @@ -398,32 +399,32 @@ func handleSinkUpdate(_ context.Context, admin cmdutils.Client, arguments map[st } // Get optional parameters - archive, hasArchive := optionalParam[string](arguments, "archive") + archive, hasArchive := common.OptionalParam[string](arguments, "archive") if hasArchive && archive != "" { sinkData.Archive = archive } - sinkType, hasSinkType := optionalParam[string](arguments, "sink-type") + sinkType, hasSinkType := common.OptionalParam[string](arguments, "sink-type") if hasSinkType && sinkType != "" { sinkData.SinkType = sinkType } - inputsArray, hasInputs := optionalParamArray[string](arguments, "inputs") + inputsArray, hasInputs := common.OptionalParamArray[string](arguments, "inputs") if hasInputs && len(inputsArray) > 0 { sinkData.Inputs = strings.Join(inputsArray, ",") } - topicsPattern, hasTopicsPattern := optionalParam[string](arguments, "topics-pattern") + topicsPattern, hasTopicsPattern := common.OptionalParam[string](arguments, "topics-pattern") if hasTopicsPattern && topicsPattern != "" { sinkData.TopicsPattern = topicsPattern } - subsName, hasSubsName := optionalParam[string](arguments, "subs-name") + subsName, hasSubsName := common.OptionalParam[string](arguments, "subs-name") if hasSubsName && subsName != "" { sinkData.SubsName = subsName } - parallelismFloat, hasParallelism := optionalParam[float64](arguments, "parallelism") + parallelismFloat, hasParallelism := common.OptionalParam[float64](arguments, "parallelism") if hasParallelism { sinkData.Parallelism = int(parallelismFloat) } diff --git a/pkg/mcp/pulsar_admin_sources_tools.go b/pkg/mcp/pulsar_admin_sources_tools.go index 98cac07..664a9da 100644 --- a/pkg/mcp/pulsar_admin_sources_tools.go +++ b/pkg/mcp/pulsar_admin_sources_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" ) // PulsarAdminAddSourcesTools adds a unified source-related tool to the MCP server @@ -135,7 +136,7 @@ func PulsarAdminAddSourcesTools(s *server.MCPServer, readOnly bool, features []s func handleSourcesTool(readOnly bool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract and validate operation parameter - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'operation': %v", err)), nil } @@ -169,12 +170,12 @@ func handleSourcesTool(readOnly bool) func(context.Context, mcp.CallToolRequest) } // Extract common parameters (all operations except list-built-in require tenant and namespace) - tenant, err := requiredParam[string](request.Params.Arguments, "tenant") + tenant, err := common.RequiredParam[string](request.Params.Arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'tenant': %v. A tenant is required for operation '%s'.", err, operation)), nil } - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'namespace': %v. A namespace is required for operation '%s'.", err, operation)), nil } @@ -182,7 +183,7 @@ func handleSourcesTool(readOnly bool) func(context.Context, mcp.CallToolRequest) // For all operations except 'list', name is required var name string if operation != "list" { - name, err = requiredParam[string](request.Params.Arguments, "name") + name, err = common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'name' for operation '%s': %v. The source name must be specified for this operation.", operation, err)), nil } @@ -268,17 +269,17 @@ func handleSourceStatus(_ context.Context, admin cmdutils.Client, tenant, namesp // handleSourceCreate handles creating a new source func handleSourceCreate(_ context.Context, admin cmdutils.Client, arguments map[string]interface{}) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](arguments, "tenant") + tenant, err := common.RequiredParam[string](arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get tenant: %v", err)), nil } - namespace, err := requiredParam[string](arguments, "namespace") + namespace, err := common.RequiredParam[string](arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace: %v", err)), nil } - name, err := requiredParam[string](arguments, "name") + name, err := common.RequiredParam[string](arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get name: %v", err)), nil } @@ -292,42 +293,42 @@ func handleSourceCreate(_ context.Context, admin cmdutils.Client, arguments map[ } // Get optional parameters - archive, hasArchive := optionalParam[string](arguments, "archive") + archive, hasArchive := common.OptionalParam[string](arguments, "archive") if hasArchive && archive != "" { sourceData.Archive = archive } - sourceType, hasSourceType := optionalParam[string](arguments, "source-type") + sourceType, hasSourceType := common.OptionalParam[string](arguments, "source-type") if hasSourceType && sourceType != "" { sourceData.SourceType = sourceType } - destTopic, hasDestTopic := optionalParam[string](arguments, "destination-topic-name") + destTopic, hasDestTopic := common.OptionalParam[string](arguments, "destination-topic-name") if hasDestTopic && destTopic != "" { sourceData.DestinationTopicName = destTopic } - deserializationClassName, hasDeserialization := optionalParam[string](arguments, "deserialization-classname") + deserializationClassName, hasDeserialization := common.OptionalParam[string](arguments, "deserialization-classname") if hasDeserialization && deserializationClassName != "" { sourceData.DeserializationClassName = deserializationClassName } - schemaType, hasSchemaType := optionalParam[string](arguments, "schema-type") + schemaType, hasSchemaType := common.OptionalParam[string](arguments, "schema-type") if hasSchemaType && schemaType != "" { sourceData.SchemaType = schemaType } - className, hasClassName := optionalParam[string](arguments, "classname") + className, hasClassName := common.OptionalParam[string](arguments, "classname") if hasClassName && className != "" { sourceData.ClassName = className } - processingGuarantees, hasProcessingGuarantees := optionalParam[string](arguments, "processing-guarantees") + processingGuarantees, hasProcessingGuarantees := common.OptionalParam[string](arguments, "processing-guarantees") if hasProcessingGuarantees && processingGuarantees != "" { sourceData.ProcessingGuarantees = processingGuarantees } - parallelismFloat, hasParallelism := optionalParam[float64](arguments, "parallelism") + parallelismFloat, hasParallelism := common.OptionalParam[float64](arguments, "parallelism") if hasParallelism { sourceData.Parallelism = int(parallelismFloat) } @@ -384,17 +385,17 @@ func handleSourceCreate(_ context.Context, admin cmdutils.Client, arguments map[ // handleSourceUpdate handles updating an existing source func handleSourceUpdate(_ context.Context, admin cmdutils.Client, arguments map[string]interface{}) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](arguments, "tenant") + tenant, err := common.RequiredParam[string](arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get tenant: %v", err)), nil } - namespace, err := requiredParam[string](arguments, "namespace") + namespace, err := common.RequiredParam[string](arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get namespace: %v", err)), nil } - name, err := requiredParam[string](arguments, "name") + name, err := common.RequiredParam[string](arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get name: %v", err)), nil } @@ -408,42 +409,42 @@ func handleSourceUpdate(_ context.Context, admin cmdutils.Client, arguments map[ } // Get optional parameters - archive, hasArchive := optionalParam[string](arguments, "archive") + archive, hasArchive := common.OptionalParam[string](arguments, "archive") if hasArchive && archive != "" { sourceData.Archive = archive } - sourceType, hasSourceType := optionalParam[string](arguments, "source-type") + sourceType, hasSourceType := common.OptionalParam[string](arguments, "source-type") if hasSourceType && sourceType != "" { sourceData.SourceType = sourceType } - destTopic, hasDestTopic := optionalParam[string](arguments, "destination-topic-name") + destTopic, hasDestTopic := common.OptionalParam[string](arguments, "destination-topic-name") if hasDestTopic && destTopic != "" { sourceData.DestinationTopicName = destTopic } - deserializationClassName, hasDeserialization := optionalParam[string](arguments, "deserialization-classname") + deserializationClassName, hasDeserialization := common.OptionalParam[string](arguments, "deserialization-classname") if hasDeserialization && deserializationClassName != "" { sourceData.DeserializationClassName = deserializationClassName } - schemaType, hasSchemaType := optionalParam[string](arguments, "schema-type") + schemaType, hasSchemaType := common.OptionalParam[string](arguments, "schema-type") if hasSchemaType && schemaType != "" { sourceData.SchemaType = schemaType } - className, hasClassName := optionalParam[string](arguments, "classname") + className, hasClassName := common.OptionalParam[string](arguments, "classname") if hasClassName && className != "" { sourceData.ClassName = className } - processingGuarantees, hasProcessingGuarantees := optionalParam[string](arguments, "processing-guarantees") + processingGuarantees, hasProcessingGuarantees := common.OptionalParam[string](arguments, "processing-guarantees") if hasProcessingGuarantees && processingGuarantees != "" { sourceData.ProcessingGuarantees = processingGuarantees } - parallelismFloat, hasParallelism := optionalParam[float64](arguments, "parallelism") + parallelismFloat, hasParallelism := common.OptionalParam[float64](arguments, "parallelism") if hasParallelism { sourceData.Parallelism = int(parallelismFloat) } diff --git a/pkg/mcp/pulsar_admin_subscription_tools.go b/pkg/mcp/pulsar_admin_subscription_tools.go index 698d6e1..604765c 100644 --- a/pkg/mcp/pulsar_admin_subscription_tools.go +++ b/pkg/mcp/pulsar_admin_subscription_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -103,17 +104,17 @@ func PulsarAdminAddSubscriptionTools(s *server.MCPServer, readOnly bool, feature func handleSubscriptionTool(readOnly bool) func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic'. Please provide the fully qualified topic name: %v", err)), nil } @@ -185,13 +186,13 @@ func handleSubsList(admin cmdutils.Client, topicName *utils.TopicName) (*mcp.Cal // handleSubsCreate handles creating a new subscription func handleSubsCreate(admin cmdutils.Client, topicName *utils.TopicName, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameter - subscription, err := requiredParam[string](request.Params.Arguments, "subscription") + subscription, err := common.RequiredParam[string](request.Params.Arguments, "subscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'subscription' for subscription.create: %v", err)), nil } // Get optional messageID parameter (default is "latest") - messageID, hasMessageID := optionalParam[string](request.Params.Arguments, "messageId") + messageID, hasMessageID := common.OptionalParam[string](request.Params.Arguments, "messageId") if !hasMessageID { messageID = "latest" } @@ -230,13 +231,13 @@ func handleSubsCreate(admin cmdutils.Client, topicName *utils.TopicName, request // handleSubsDelete handles deleting a subscription func handleSubsDelete(admin cmdutils.Client, topicName *utils.TopicName, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameter - subscription, err := requiredParam[string](request.Params.Arguments, "subscription") + subscription, err := common.RequiredParam[string](request.Params.Arguments, "subscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'subscription' for subscription.delete: %v", err)), nil } // Get optional force parameter (default is false) - force, hasForce := optionalParam[bool](request.Params.Arguments, "force") + force, hasForce := common.OptionalParam[bool](request.Params.Arguments, "force") if !hasForce { force = false } @@ -267,12 +268,12 @@ func handleSubsDelete(admin cmdutils.Client, topicName *utils.TopicName, request // handleSubsSkip handles skipping messages for a subscription func handleSubsSkip(admin cmdutils.Client, topicName *utils.TopicName, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - subscription, err := requiredParam[string](request.Params.Arguments, "subscription") + subscription, err := common.RequiredParam[string](request.Params.Arguments, "subscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'subscription' for subscription.skip: %v", err)), nil } - count, err := requiredParam[float64](request.Params.Arguments, "count") + count, err := common.RequiredParam[float64](request.Params.Arguments, "count") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'count' for subscription.skip: %v", err)), nil } @@ -291,12 +292,12 @@ func handleSubsSkip(admin cmdutils.Client, topicName *utils.TopicName, request m // handleSubsExpire handles expiring messages for a subscription func handleSubsExpire(admin cmdutils.Client, topicName *utils.TopicName, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - subscription, err := requiredParam[string](request.Params.Arguments, "subscription") + subscription, err := common.RequiredParam[string](request.Params.Arguments, "subscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'subscription' for subscription.expire: %v", err)), nil } - expireTime, err := requiredParam[float64](request.Params.Arguments, "expireTimeInSeconds") + expireTime, err := common.RequiredParam[float64](request.Params.Arguments, "expireTimeInSeconds") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'expireTimeInSeconds' for subscription.expire: %v", err)), nil } @@ -317,12 +318,12 @@ func handleSubsExpire(admin cmdutils.Client, topicName *utils.TopicName, request // handleSubsResetCursor handles resetting a subscription cursor func handleSubsResetCursor(admin cmdutils.Client, topicName *utils.TopicName, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - subscription, err := requiredParam[string](request.Params.Arguments, "subscription") + subscription, err := common.RequiredParam[string](request.Params.Arguments, "subscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'subscription' for subscription.reset-cursor: %v", err)), nil } - messageID, err := requiredParam[string](request.Params.Arguments, "messageId") + messageID, err := common.RequiredParam[string](request.Params.Arguments, "messageId") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'messageId' for subscription.reset-cursor: %v", err)), nil } diff --git a/pkg/mcp/pulsar_admin_tenant_tools.go b/pkg/mcp/pulsar_admin_tenant_tools.go index 188fa21..f55c964 100644 --- a/pkg/mcp/pulsar_admin_tenant_tools.go +++ b/pkg/mcp/pulsar_admin_tenant_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -104,12 +105,12 @@ func PulsarAdminAddTenantTools(s *server.MCPServer, readOnly bool, features []st func handleTenantTool(readOnly bool) func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -171,7 +172,7 @@ func handleTenantsList(admin cmdutils.Client) (*mcp.CallToolResult, error) { // handleTenantGet handles getting a tenant's configuration func handleTenantGet(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](request.Params.Arguments, "tenant") + tenant, err := common.RequiredParam[string](request.Params.Arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'tenant' for tenant.get: %v", err)), nil } @@ -193,17 +194,17 @@ func handleTenantGet(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.C // handleTenantCreate handles creating a new tenant func handleTenantCreate(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](request.Params.Arguments, "tenant") + tenant, err := common.RequiredParam[string](request.Params.Arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'tenant' for tenant.create: %v", err)), nil } - adminRoles, hasAdminRoles := optionalParamArray[string](request.Params.Arguments, "adminRoles") + adminRoles, hasAdminRoles := common.OptionalParamArray[string](request.Params.Arguments, "adminRoles") if !hasAdminRoles { adminRoles = []string{} } - allowedClusters, hasAllowedClusters := optionalParamArray[string](request.Params.Arguments, "allowedClusters") + allowedClusters, hasAllowedClusters := common.OptionalParamArray[string](request.Params.Arguments, "allowedClusters") if !hasAllowedClusters { allowedClusters = []string{""} } @@ -238,7 +239,7 @@ func handleTenantCreate(admin cmdutils.Client, request mcp.CallToolRequest) (*mc // handleTenantUpdate handles updating an existing tenant func handleTenantUpdate(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](request.Params.Arguments, "tenant") + tenant, err := common.RequiredParam[string](request.Params.Arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'tenant' for tenant.update: %v", err)), nil } @@ -250,8 +251,8 @@ func handleTenantUpdate(admin cmdutils.Client, request mcp.CallToolRequest) (*mc } // Get update parameters - adminRoles, hasAdminRoles := optionalParamArray[string](request.Params.Arguments, "adminRoles") - allowedClusters, hasAllowedClusters := optionalParamArray[string](request.Params.Arguments, "allowedClusters") + adminRoles, hasAdminRoles := common.OptionalParamArray[string](request.Params.Arguments, "adminRoles") + allowedClusters, hasAllowedClusters := common.OptionalParamArray[string](request.Params.Arguments, "allowedClusters") // If parameters not provided, keep existing values if !hasAdminRoles { @@ -292,7 +293,7 @@ func handleTenantUpdate(admin cmdutils.Client, request mcp.CallToolRequest) (*mc // handleTenantDelete handles deleting an existing tenant func handleTenantDelete(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - tenant, err := requiredParam[string](request.Params.Arguments, "tenant") + tenant, err := common.RequiredParam[string](request.Params.Arguments, "tenant") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'tenant' for tenant.delete: %v", err)), nil } diff --git a/pkg/mcp/pulsar_admin_topic_policy_tools.go b/pkg/mcp/pulsar_admin_topic_policy_tools.go index ce5f813..fbc492a 100644 --- a/pkg/mcp/pulsar_admin_topic_policy_tools.go +++ b/pkg/mcp/pulsar_admin_topic_policy_tools.go @@ -27,6 +27,7 @@ import ( "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -336,7 +337,7 @@ func PulsarAdminAddTopicPolicyTools(s *server.MCPServer, readOnly bool, features // handleTopicsGetPublishRate gets the publish rate for a topic func handleTopicsGetPublishRate(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -376,7 +377,7 @@ func handleTopicsGetPublishRate(_ context.Context, request mcp.CallToolRequest) // handleTopicsSetPublishRate sets the publish rate for a topic func handleTopicsSetPublishRate(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -398,13 +399,13 @@ func handleTopicsSetPublishRate(_ context.Context, request mcp.CallToolRequest) publishRateInByte := int64(-1) // unlimited // Get publish rate in messages if provided - msgRateParam, hasMsgRate := optionalParam[float64](request.Params.Arguments, "publishThrottlingRateInMsg") + msgRateParam, hasMsgRate := common.OptionalParam[float64](request.Params.Arguments, "publishThrottlingRateInMsg") if hasMsgRate { publishRateInMsg = int64(msgRateParam) } // Get publish rate in bytes if provided - byteRateParam, hasByteRate := optionalParam[float64](request.Params.Arguments, "publishThrottlingRateInByte") + byteRateParam, hasByteRate := common.OptionalParam[float64](request.Params.Arguments, "publishThrottlingRateInByte") if hasByteRate { publishRateInByte = int64(byteRateParam) } @@ -447,7 +448,7 @@ func handleTopicsSetPublishRate(_ context.Context, request mcp.CallToolRequest) // handleTopicsRemovePublishRate removes the publish rate for a topic func handleTopicsRemovePublishRate(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -476,7 +477,7 @@ func handleTopicsRemovePublishRate(_ context.Context, request mcp.CallToolReques // handleTopicsGetPermissions gets the permissions on a topic func handleTopicsGetPermissions(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -511,17 +512,17 @@ func handleTopicsGetPermissions(_ context.Context, request mcp.CallToolRequest) // handleTopicsGrantPermissions grants a new permission to a role on a topic func handleTopicsGrantPermissions(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - role, err := requiredParam[string](request.Params.Arguments, "role") + role, err := common.RequiredParam[string](request.Params.Arguments, "role") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get role: %v", err)), nil } - actions, err := requiredParamArray[string](request.Params.Arguments, "actions") + actions, err := common.RequiredParamArray[string](request.Params.Arguments, "actions") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get actions: %v", err)), nil } @@ -566,12 +567,12 @@ func handleTopicsGrantPermissions(_ context.Context, request mcp.CallToolRequest // handleTopicsRevokePermissions revokes all permissions for a role on a topic func handleTopicsRevokePermissions(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - role, err := requiredParam[string](request.Params.Arguments, "role") + role, err := common.RequiredParam[string](request.Params.Arguments, "role") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get role: %v", err)), nil } @@ -620,7 +621,7 @@ func grantTopicPermission(admin interface{}, topicName utils.TopicName, role str // handleTopicsGetMessageTTL gets the message TTL for a topic func handleTopicsGetMessageTTL(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -654,12 +655,12 @@ func handleTopicsGetMessageTTL(_ context.Context, request mcp.CallToolRequest) ( // handleTopicsSetMessageTTL sets the message TTL for a topic func handleTopicsSetMessageTTL(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - ttl, err := requiredParam[float64](request.Params.Arguments, "ttl") + ttl, err := common.RequiredParam[float64](request.Params.Arguments, "ttl") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get TTL: %v", err)), nil } @@ -700,7 +701,7 @@ func handleTopicsSetMessageTTL(_ context.Context, request mcp.CallToolRequest) ( // handleTopicsRemoveMessageTTL removes the message TTL for a topic func handleTopicsRemoveMessageTTL(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -729,7 +730,7 @@ func handleTopicsRemoveMessageTTL(_ context.Context, request mcp.CallToolRequest // handleTopicsGetMaxProducers gets the maximum number of producers allowed for a topic func handleTopicsGetMaxProducers(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -764,12 +765,12 @@ func handleTopicsGetMaxProducers(_ context.Context, request mcp.CallToolRequest) // handleTopicsSetMaxProducers sets the maximum number of producers allowed for a topic func handleTopicsSetMaxProducers(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - maxProducers, err := requiredParam[float64](request.Params.Arguments, "maxProducers") + maxProducers, err := common.RequiredParam[float64](request.Params.Arguments, "maxProducers") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get maxProducers: %v", err)), nil } @@ -810,7 +811,7 @@ func handleTopicsSetMaxProducers(_ context.Context, request mcp.CallToolRequest) // handleTopicsRemoveMaxProducers removes the maximum producers limit for a topic func handleTopicsRemoveMaxProducers(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -839,7 +840,7 @@ func handleTopicsRemoveMaxProducers(_ context.Context, request mcp.CallToolReque // handleTopicsGetMaxConsumers gets the maximum number of consumers allowed for a topic func handleTopicsGetMaxConsumers(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -874,12 +875,12 @@ func handleTopicsGetMaxConsumers(_ context.Context, request mcp.CallToolRequest) // handleTopicsSetMaxConsumers sets the maximum number of consumers allowed for a topic func handleTopicsSetMaxConsumers(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - maxConsumers, err := requiredParam[float64](request.Params.Arguments, "maxConsumers") + maxConsumers, err := common.RequiredParam[float64](request.Params.Arguments, "maxConsumers") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get maxConsumers: %v", err)), nil } @@ -920,7 +921,7 @@ func handleTopicsSetMaxConsumers(_ context.Context, request mcp.CallToolRequest) // handleTopicsRemoveMaxConsumers removes the maximum consumers limit for a topic func handleTopicsRemoveMaxConsumers(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -950,7 +951,7 @@ func handleTopicsRemoveMaxConsumers(_ context.Context, request mcp.CallToolReque // gets the maximum number of unacknowledged messages allowed for a consumer on a topic func handleTopicsGetMaxUnackMessagesPerConsumer(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -988,12 +989,12 @@ func handleTopicsGetMaxUnackMessagesPerConsumer(_ context.Context, request mcp.C // sets the maximum number of unacknowledged messages allowed for a consumer on a topic func handleTopicsSetMaxUnackMessagesPerConsumer(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - maxUnack, err := requiredParam[float64](request.Params.Arguments, "maxUnackMessagesPerConsumer") + maxUnack, err := common.RequiredParam[float64](request.Params.Arguments, "maxUnackMessagesPerConsumer") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get maxUnackMessagesPerConsumer: %v", err)), nil } @@ -1035,7 +1036,7 @@ func handleTopicsSetMaxUnackMessagesPerConsumer(_ context.Context, request mcp.C // removes the maximum unacknowledged messages per consumer limit for a topic func handleTopicsRemoveMaxUnackMessagesPerConsumer(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1070,7 +1071,7 @@ func handleTopicsRemoveMaxUnackMessagesPerConsumer(_ context.Context, request mc // gets the maximum number of unacknowledged messages allowed for a subscription on a topic func handleTopicsGetMaxUnackMessagesPerSubscription(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1110,12 +1111,12 @@ func handleTopicsGetMaxUnackMessagesPerSubscription(_ context.Context, request m // sets the maximum number of unacknowledged messages allowed for a subscription on a topic func handleTopicsSetMaxUnackMessagesPerSubscription(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - maxUnack, err := requiredParam[float64](request.Params.Arguments, "maxUnackMessagesPerSubscription") + maxUnack, err := common.RequiredParam[float64](request.Params.Arguments, "maxUnackMessagesPerSubscription") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get maxUnackMessagesPerSubscription: %v", err)), nil } @@ -1161,7 +1162,7 @@ func handleTopicsSetMaxUnackMessagesPerSubscription(_ context.Context, request m // removes the maximum unacknowledged messages per subscription limit for a topic func handleTopicsRemoveMaxUnackMessagesPerSubscription(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1195,7 +1196,7 @@ func handleTopicsRemoveMaxUnackMessagesPerSubscription(_ context.Context, reques // handleTopicsGetPersistence gets the persistence policy for a topic func handleTopicsGetPersistence(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1230,22 +1231,22 @@ func handleTopicsGetPersistence(_ context.Context, request mcp.CallToolRequest) // handleTopicsSetPersistence sets the persistence policy for a topic func handleTopicsSetPersistence(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - ensembleSize, err := requiredParam[float64](request.Params.Arguments, "ensembleSize") + ensembleSize, err := common.RequiredParam[float64](request.Params.Arguments, "ensembleSize") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get ensembleSize: %v", err)), nil } - writeQuorum, err := requiredParam[float64](request.Params.Arguments, "writeQuorum") + writeQuorum, err := common.RequiredParam[float64](request.Params.Arguments, "writeQuorum") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get writeQuorum: %v", err)), nil } - ackQuorum, err := requiredParam[float64](request.Params.Arguments, "ackQuorum") + ackQuorum, err := common.RequiredParam[float64](request.Params.Arguments, "ackQuorum") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get ackQuorum: %v", err)), nil } @@ -1298,7 +1299,7 @@ func handleTopicsSetPersistence(_ context.Context, request mcp.CallToolRequest) // handleTopicsRemovePersistence removes the persistence policy for a topic func handleTopicsRemovePersistence(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1327,7 +1328,7 @@ func handleTopicsRemovePersistence(_ context.Context, request mcp.CallToolReques // handleTopicsGetDelayedDelivery gets the delayed delivery policy for a topic func handleTopicsGetDelayedDelivery(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1362,12 +1363,12 @@ func handleTopicsGetDelayedDelivery(_ context.Context, request mcp.CallToolReque // handleTopicsSetDelayedDelivery sets the delayed delivery policy for a topic func handleTopicsSetDelayedDelivery(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - delayInMillis, err := requiredParam[float64](request.Params.Arguments, "delayInMillis") + delayInMillis, err := common.RequiredParam[float64](request.Params.Arguments, "delayInMillis") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get delayInMillis: %v", err)), nil } @@ -1379,7 +1380,7 @@ func handleTopicsSetDelayedDelivery(_ context.Context, request mcp.CallToolReque // Default tick time is 1 second (1000ms) tickTime := 1000.0 - tickTimeParam, hasTickTime := optionalParam[float64](request.Params.Arguments, "tickTime") + tickTimeParam, hasTickTime := common.OptionalParam[float64](request.Params.Arguments, "tickTime") if hasTickTime && tickTimeParam > 0 { tickTime = tickTimeParam } @@ -1420,7 +1421,7 @@ func handleTopicsSetDelayedDelivery(_ context.Context, request mcp.CallToolReque // handleTopicsRemoveDelayedDelivery removes the delayed delivery policy for a topic func handleTopicsRemoveDelayedDelivery(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1449,7 +1450,7 @@ func handleTopicsRemoveDelayedDelivery(_ context.Context, request mcp.CallToolRe // handleTopicsGetDispatchRate gets the message dispatch rate for a topic func handleTopicsGetDispatchRate(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1489,7 +1490,7 @@ func handleTopicsGetDispatchRate(_ context.Context, request mcp.CallToolRequest) // handleTopicsSetDispatchRate sets the message dispatch rate for a topic func handleTopicsSetDispatchRate(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1512,19 +1513,19 @@ func handleTopicsSetDispatchRate(_ context.Context, request mcp.CallToolRequest) ratePeriodInSecond := int64(1) // default 1 second // Get dispatch rate in messages if provided - msgRateParam, hasMsgRate := optionalParam[float64](request.Params.Arguments, "dispatchThrottlingRateInMsg") + msgRateParam, hasMsgRate := common.OptionalParam[float64](request.Params.Arguments, "dispatchThrottlingRateInMsg") if hasMsgRate { dispatchRateInMsg = int64(msgRateParam) } // Get dispatch rate in bytes if provided - byteRateParam, hasByteRate := optionalParam[float64](request.Params.Arguments, "dispatchThrottlingRateInByte") + byteRateParam, hasByteRate := common.OptionalParam[float64](request.Params.Arguments, "dispatchThrottlingRateInByte") if hasByteRate { dispatchRateInByte = int64(byteRateParam) } // Get rate period if provided - ratePeriodParam, hasRatePeriod := optionalParam[float64](request.Params.Arguments, "ratePeriodInSecond") + ratePeriodParam, hasRatePeriod := common.OptionalParam[float64](request.Params.Arguments, "ratePeriodInSecond") if hasRatePeriod && ratePeriodParam > 0 { ratePeriodInSecond = int64(ratePeriodParam) } else if hasRatePeriod && ratePeriodParam <= 0 { @@ -1571,7 +1572,7 @@ func handleTopicsSetDispatchRate(_ context.Context, request mcp.CallToolRequest) // handleTopicsRemoveDispatchRate removes the message dispatch rate for a topic func handleTopicsRemoveDispatchRate(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1600,7 +1601,7 @@ func handleTopicsRemoveDispatchRate(_ context.Context, request mcp.CallToolReque // handleTopicsGetDeduplicationStatus gets the deduplication status for a topic func handleTopicsGetDeduplicationStatus(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1634,12 +1635,12 @@ func handleTopicsGetDeduplicationStatus(_ context.Context, request mcp.CallToolR // handleTopicsSetDeduplicationStatus sets the deduplication status for a topic func handleTopicsSetDeduplicationStatus(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - enabled, err := requiredParam[bool](request.Params.Arguments, "enabled") + enabled, err := common.RequiredParam[bool](request.Params.Arguments, "enabled") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get enabled parameter: %v", err)), nil } @@ -1673,7 +1674,7 @@ func handleTopicsSetDeduplicationStatus(_ context.Context, request mcp.CallToolR // handleTopicsRemoveDeduplicationStatus removes the deduplication status for a topic func handleTopicsRemoveDeduplicationStatus(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1702,7 +1703,7 @@ func handleTopicsRemoveDeduplicationStatus(_ context.Context, request mcp.CallTo // handleTopicsGetRetention gets the retention policy for a topic func handleTopicsGetRetention(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1715,7 +1716,7 @@ func handleTopicsGetRetention(_ context.Context, request mcp.CallToolRequest) (* // Check if applied policies should be included applied := false - appliedParam, hasApplied := optionalParam[bool](request.Params.Arguments, "applied") + appliedParam, hasApplied := common.OptionalParam[bool](request.Params.Arguments, "applied") if hasApplied { applied = appliedParam } @@ -1768,17 +1769,17 @@ func handleTopicsGetRetention(_ context.Context, request mcp.CallToolRequest) (* // handleTopicsSetRetention sets the retention policy for a topic func handleTopicsSetRetention(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - retentionTimeInMinutes, err := requiredParam[float64](request.Params.Arguments, "retentionTimeInMinutes") + retentionTimeInMinutes, err := common.RequiredParam[float64](request.Params.Arguments, "retentionTimeInMinutes") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get retentionTimeInMinutes: %v", err)), nil } - retentionSizeInMB, err := requiredParam[float64](request.Params.Arguments, "retentionSizeInMB") + retentionSizeInMB, err := common.RequiredParam[float64](request.Params.Arguments, "retentionSizeInMB") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get retentionSizeInMB: %v", err)), nil } @@ -1835,7 +1836,7 @@ func handleTopicsSetRetention(_ context.Context, request mcp.CallToolRequest) (* // handleTopicsRemoveRetention removes the retention policy for a topic func handleTopicsRemoveRetention(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1864,7 +1865,7 @@ func handleTopicsRemoveRetention(_ context.Context, request mcp.CallToolRequest) // handleTopicsGetBacklogQuota gets the backlog quota policy for a topic func handleTopicsGetBacklogQuota(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -1877,7 +1878,7 @@ func handleTopicsGetBacklogQuota(_ context.Context, request mcp.CallToolRequest) // Check if applied policies should be included applied := false - appliedParam, hasApplied := optionalParam[bool](request.Params.Arguments, "applied") + appliedParam, hasApplied := common.OptionalParam[bool](request.Params.Arguments, "applied") if hasApplied { applied = appliedParam } @@ -1910,17 +1911,17 @@ func handleTopicsGetBacklogQuota(_ context.Context, request mcp.CallToolRequest) // handleTopicsSetBacklogQuota sets the backlog quota policy for a topic func handleTopicsSetBacklogQuota(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - limitSize, err := requiredParam[float64](request.Params.Arguments, "limitSize") + limitSize, err := common.RequiredParam[float64](request.Params.Arguments, "limitSize") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get limitSize: %v", err)), nil } - policy, err := requiredParam[string](request.Params.Arguments, "policy") + policy, err := common.RequiredParam[string](request.Params.Arguments, "policy") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get policy: %v", err)), nil } @@ -1941,7 +1942,7 @@ func handleTopicsSetBacklogQuota(_ context.Context, request mcp.CallToolRequest) limitTime := int64(-1) // unlimited by default // Get limit time if provided - limitTimeParam, hasLimitTime := optionalParam[float64](request.Params.Arguments, "limitTime") + limitTimeParam, hasLimitTime := common.OptionalParam[float64](request.Params.Arguments, "limitTime") if hasLimitTime { limitTime = int64(limitTimeParam) } @@ -1956,7 +1957,7 @@ func handleTopicsSetBacklogQuota(_ context.Context, request mcp.CallToolRequest) backlogQuotaType := utils.DestinationStorage // Get quota type if provided - quotaTypeStr, hasQuotaType := optionalParam[string](request.Params.Arguments, "type") + quotaTypeStr, hasQuotaType := common.OptionalParam[string](request.Params.Arguments, "type") if hasQuotaType { parsedType, err := utils.ParseBacklogQuotaType(quotaTypeStr) if err != nil { @@ -1981,7 +1982,7 @@ func handleTopicsSetBacklogQuota(_ context.Context, request mcp.CallToolRequest) // handleTopicsRemoveBacklogQuota removes the backlog quota policy from a topic func handleTopicsRemoveBacklogQuota(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -2002,7 +2003,7 @@ func handleTopicsRemoveBacklogQuota(_ context.Context, request mcp.CallToolReque backlogQuotaType := utils.DestinationStorage // Get quota type if provided - quotaTypeStr, hasQuotaType := optionalParam[string](request.Params.Arguments, "type") + quotaTypeStr, hasQuotaType := common.OptionalParam[string](request.Params.Arguments, "type") if hasQuotaType { parsedType, err := utils.ParseBacklogQuotaType(quotaTypeStr) if err != nil { @@ -2023,7 +2024,7 @@ func handleTopicsRemoveBacklogQuota(_ context.Context, request mcp.CallToolReque // handleTopicsGetCompactionThreshold gets the compaction threshold for a topic func handleTopicsGetCompactionThreshold(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -2036,7 +2037,7 @@ func handleTopicsGetCompactionThreshold(_ context.Context, request mcp.CallToolR // Check if applied policies should be included applied := false - appliedParam, hasApplied := optionalParam[bool](request.Params.Arguments, "applied") + appliedParam, hasApplied := common.OptionalParam[bool](request.Params.Arguments, "applied") if hasApplied { applied = appliedParam } @@ -2065,12 +2066,12 @@ func handleTopicsGetCompactionThreshold(_ context.Context, request mcp.CallToolR // handleTopicsSetCompactionThreshold sets the compaction threshold for a topic func handleTopicsSetCompactionThreshold(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - threshold, err := requiredParam[float64](request.Params.Arguments, "threshold") + threshold, err := common.RequiredParam[float64](request.Params.Arguments, "threshold") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get threshold: %v", err)), nil } @@ -2111,7 +2112,7 @@ func handleTopicsSetCompactionThreshold(_ context.Context, request mcp.CallToolR // handleTopicsRemoveCompactionThreshold removes the compaction threshold for a topic func handleTopicsRemoveCompactionThreshold(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -2141,7 +2142,7 @@ func handleTopicsRemoveCompactionThreshold(_ context.Context, request mcp.CallTo // handleTopicsGetInactiveTopic gets the inactive topic policies for a topic func handleTopicsGetInactiveTopic(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -2154,7 +2155,7 @@ func handleTopicsGetInactiveTopic(_ context.Context, request mcp.CallToolRequest // Check if applied policies should be included applied := false - appliedParam, hasApplied := optionalParam[bool](request.Params.Arguments, "applied") + appliedParam, hasApplied := common.OptionalParam[bool](request.Params.Arguments, "applied") if hasApplied { applied = appliedParam } @@ -2183,22 +2184,22 @@ func handleTopicsGetInactiveTopic(_ context.Context, request mcp.CallToolRequest // handleTopicsSetInactiveTopic sets the inactive topic policies for a topic func handleTopicsSetInactiveTopic(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - enableDelete, err := requiredParam[bool](request.Params.Arguments, "enableDeleteWhileInactive") + enableDelete, err := common.RequiredParam[bool](request.Params.Arguments, "enableDeleteWhileInactive") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get enableDeleteWhileInactive: %v", err)), nil } - maxInactiveDuration, err := requiredParam[float64](request.Params.Arguments, "maxInactiveDurationSeconds") + maxInactiveDuration, err := common.RequiredParam[float64](request.Params.Arguments, "maxInactiveDurationSeconds") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get maxInactiveDurationSeconds: %v", err)), nil } - deleteModeStr, err := requiredParam[string](request.Params.Arguments, "deleteMode") + deleteModeStr, err := common.RequiredParam[string](request.Params.Arguments, "deleteMode") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get deleteMode: %v", err)), nil } @@ -2246,7 +2247,7 @@ func handleTopicsSetInactiveTopic(_ context.Context, request mcp.CallToolRequest // handleTopicsRemoveInactiveTopic removes the inactive topic policies from a topic func handleTopicsRemoveInactiveTopic(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } @@ -2276,12 +2277,12 @@ func handleTopicsRemoveInactiveTopic(_ context.Context, request mcp.CallToolRequ // handleTopicGetPolicy handles getting policies for a topic using the unified tool func handleTopicGetPolicy(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - _, err := requiredParam[string](request.Params.Arguments, "topic") + _, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } - policyType, err := requiredParam[string](request.Params.Arguments, "policy") + policyType, err := common.RequiredParam[string](request.Params.Arguments, "policy") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get policy type: %v", err)), nil } @@ -2326,12 +2327,12 @@ func handleTopicGetPolicy(ctx context.Context, request mcp.CallToolRequest) (*mc // handleTopicSetPolicy handles setting policies for a topic using the unified tool func handleTopicSetPolicy(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - _, err := requiredParam[string](request.Params.Arguments, "topic") + _, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } - policyType, err := requiredParam[string](request.Params.Arguments, "policy") + policyType, err := common.RequiredParam[string](request.Params.Arguments, "policy") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get policy type: %v", err)), nil } @@ -2376,12 +2377,12 @@ func handleTopicSetPolicy(ctx context.Context, request mcp.CallToolRequest) (*mc // handleTopicRemovePolicy handles removing policies for a topic using the unified tool func handleTopicRemovePolicy(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - _, err := requiredParam[string](request.Params.Arguments, "topic") + _, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic name: %v", err)), nil } - policyType, err := requiredParam[string](request.Params.Arguments, "policy") + policyType, err := common.RequiredParam[string](request.Params.Arguments, "policy") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get policy type: %v", err)), nil } diff --git a/pkg/mcp/pulsar_admin_topic_tools.go b/pkg/mcp/pulsar_admin_topic_tools.go index 805972c..4b98d7d 100644 --- a/pkg/mcp/pulsar_admin_topic_tools.go +++ b/pkg/mcp/pulsar_admin_topic_tools.go @@ -28,6 +28,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/streamnative/pulsarctl/pkg/cmdutils" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -135,12 +136,12 @@ func PulsarAdminAddTopicTools(s *server.MCPServer, readOnly bool, features []str func handleTopicTool(readOnly bool) func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - resource, err := requiredParam[string](request.Params.Arguments, "resource") + resource, err := common.RequiredParam[string](request.Params.Arguments, "resource") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get resource: %v", err)), nil } - operation, err := requiredParam[string](request.Params.Arguments, "operation") + operation, err := common.RequiredParam[string](request.Params.Arguments, "operation") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get operation: %v", err)), nil } @@ -216,7 +217,7 @@ func handleTopicTool(readOnly bool) func(_ context.Context, request mcp.CallTool // handleTopicsList lists all existing topics under the specified namespace func handleTopicsList(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - namespace, err := requiredParam[string](request.Params.Arguments, "namespace") + namespace, err := common.RequiredParam[string](request.Params.Arguments, "namespace") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'namespace' for topics.list: %v", err)), nil } @@ -254,7 +255,7 @@ func handleTopicsList(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp. // handleTopicGet gets the metadata of an existing topic func handleTopicGet(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.get: %v", err)), nil } @@ -284,14 +285,14 @@ func handleTopicGet(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.Ca // handleTopicStats gets the stats for an existing topic func handleTopicStats(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.stats: %v", err)), nil } // Get optional parameters - partitioned, hasPartitioned := optionalParam[bool](request.Params.Arguments, "partitioned") - perPartition, hasPerPartition := optionalParam[bool](request.Params.Arguments, "per-partition") + partitioned, hasPartitioned := common.OptionalParam[bool](request.Params.Arguments, "partitioned") + perPartition, hasPerPartition := common.OptionalParam[bool](request.Params.Arguments, "per-partition") if !hasPartitioned { partitioned = false @@ -342,7 +343,7 @@ func handleTopicStats(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp. // handleTopicLookup looks up the owner broker of a topic func handleTopicLookup(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.lookup: %v", err)), nil } @@ -372,12 +373,12 @@ func handleTopicLookup(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp // handleTopicCreate creates a topic with the specified number of partitions func handleTopicCreate(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.create: %v", err)), nil } - partitions, err := requiredParam[float64](request.Params.Arguments, "partitions") + partitions, err := common.RequiredParam[float64](request.Params.Arguments, "partitions") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'partitions' for topic.create: %v", err)), nil } @@ -411,14 +412,14 @@ func handleTopicCreate(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp // handleTopicDelete deletes a topic func handleTopicDelete(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.delete: %v", err)), nil } // Get optional parameters - force, hasForce := optionalParam[bool](request.Params.Arguments, "force") - nonPartitioned, hasNonPartitioned := optionalParam[bool](request.Params.Arguments, "non-partitioned") + force, hasForce := common.OptionalParam[bool](request.Params.Arguments, "force") + nonPartitioned, hasNonPartitioned := common.OptionalParam[bool](request.Params.Arguments, "non-partitioned") if !hasForce { force = false @@ -457,7 +458,7 @@ func handleTopicDelete(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp // handleTopicUnload unloads a topic func handleTopicUnload(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.unload: %v", err)), nil } @@ -480,7 +481,7 @@ func handleTopicUnload(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp // handleTopicTerminate terminates a topic func handleTopicTerminate(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.terminate: %v", err)), nil } @@ -508,7 +509,7 @@ func handleTopicTerminate(admin cmdutils.Client, request mcp.CallToolRequest) (* // handleTopicCompact triggers compaction on a topic func handleTopicCompact(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.compact: %v", err)), nil } @@ -532,7 +533,7 @@ func handleTopicCompact(admin cmdutils.Client, request mcp.CallToolRequest) (*mc // handleTopicInternalStats gets the internal stats for a topic func handleTopicInternalStats(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.internal-stats: %v", err)), nil } @@ -561,7 +562,7 @@ func handleTopicInternalStats(admin cmdutils.Client, request mcp.CallToolRequest // handleTopicInternalInfo gets the internal info for a topic func handleTopicInternalInfo(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.internal-info: %v", err)), nil } @@ -590,7 +591,7 @@ func handleTopicInternalInfo(admin cmdutils.Client, request mcp.CallToolRequest) // handleTopicBundleRange gets the bundle range of a topic func handleTopicBundleRange(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.bundle-range: %v", err)), nil } @@ -613,7 +614,7 @@ func handleTopicBundleRange(admin cmdutils.Client, request mcp.CallToolRequest) // handleTopicLastMessageID gets the last message ID of a topic func handleTopicLastMessageID(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.last-message-id: %v", err)), nil } @@ -642,7 +643,7 @@ func handleTopicLastMessageID(admin cmdutils.Client, request mcp.CallToolRequest // handleTopicStatus gets the status of a topic func handleTopicStatus(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.status: %v", err)), nil } @@ -680,12 +681,12 @@ func handleTopicStatus(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp // handleTopicUpdate updates a topic configuration func handleTopicUpdate(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.update: %v", err)), nil } - partitions, err := requiredParam[float64](request.Params.Arguments, "partitions") + partitions, err := common.RequiredParam[float64](request.Params.Arguments, "partitions") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'partitions' for topic.update: %v", err)), nil } @@ -708,12 +709,12 @@ func handleTopicUpdate(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp // handleTopicOffload offloads data from a topic to long-term storage func handleTopicOffload(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.offload: %v", err)), nil } - messageIDStr, err := requiredParam[string](request.Params.Arguments, "messageId") + messageIDStr, err := common.RequiredParam[string](request.Params.Arguments, "messageId") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'messageId' for topic.offload: %v", err)), nil } @@ -751,7 +752,7 @@ func handleTopicOffload(admin cmdutils.Client, request mcp.CallToolRequest) (*mc // handleTopicOffloadStatus checks the status of data offloading for a topic func handleTopicOffloadStatus(admin cmdutils.Client, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get required parameters - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Missing required parameter 'topic' for topic.offload-status: %v", err)), nil } diff --git a/pkg/mcp/pulsar_client_consume_tools.go b/pkg/mcp/pulsar_client_consume_tools.go index ec71b23..b247034 100644 --- a/pkg/mcp/pulsar_client_consume_tools.go +++ b/pkg/mcp/pulsar_client_consume_tools.go @@ -28,6 +28,7 @@ import ( "github.com/apache/pulsar-client-go/pulsar" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" mcppulsar "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -76,49 +77,49 @@ func PulsarClientAddConsumerTools(s *server.MCPServer, _ bool, features []string func handleClientConsume(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract required parameters with validation - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } - subscriptionName, err := requiredParam[string](request.Params.Arguments, "subscription-name") + subscriptionName, err := common.RequiredParam[string](request.Params.Arguments, "subscription-name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get subscription name: %v", err)), nil } // Set default values and extract optional parameters subscriptionType := "exclusive" - if val, exists := optionalParam[string](request.Params.Arguments, "subscription-type"); exists && val != "" { + if val, exists := common.OptionalParam[string](request.Params.Arguments, "subscription-type"); exists && val != "" { subscriptionType = val } subscriptionMode := "durable" - if val, exists := optionalParam[string](request.Params.Arguments, "subscription-mode"); exists && val != "" { + if val, exists := common.OptionalParam[string](request.Params.Arguments, "subscription-mode"); exists && val != "" { subscriptionMode = val } initialPosition := "latest" - if val, exists := optionalParam[string](request.Params.Arguments, "initial-position"); exists && val != "" { + if val, exists := common.OptionalParam[string](request.Params.Arguments, "initial-position"); exists && val != "" { initialPosition = val } numMessages := 10 - if val, exists := optionalParam[float64](request.Params.Arguments, "num-messages"); exists { + if val, exists := common.OptionalParam[float64](request.Params.Arguments, "num-messages"); exists { numMessages = int(val) } timeout := 30 - if val, exists := optionalParam[float64](request.Params.Arguments, "timeout"); exists { + if val, exists := common.OptionalParam[float64](request.Params.Arguments, "timeout"); exists { timeout = int(val) } showProperties := false - if val, exists := optionalParam[bool](request.Params.Arguments, "show-properties"); exists { + if val, exists := common.OptionalParam[bool](request.Params.Arguments, "show-properties"); exists { showProperties = val } hidePayload := false - if val, exists := optionalParam[bool](request.Params.Arguments, "hide-payload"); exists { + if val, exists := common.OptionalParam[bool](request.Params.Arguments, "hide-payload"); exists { hidePayload = val } diff --git a/pkg/mcp/pulsar_client_produce_tools.go b/pkg/mcp/pulsar_client_produce_tools.go index a0be5c6..c064cae 100644 --- a/pkg/mcp/pulsar_client_produce_tools.go +++ b/pkg/mcp/pulsar_client_produce_tools.go @@ -28,6 +28,7 @@ import ( "github.com/apache/pulsar-client-go/pulsar" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" mcppulsar "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" ) @@ -87,14 +88,14 @@ func PulsarClientAddProducerTools(s *server.MCPServer, _ bool, features []string func handleClientProduce(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract required parameters with validation - topic, err := requiredParam[string](request.Params.Arguments, "topic") + topic, err := common.RequiredParam[string](request.Params.Arguments, "topic") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get topic: %v", err)), nil } // Set default values and extract optional parameters messages := []string{} - if val, exists := optionalParam[[]interface{}](request.Params.Arguments, "messages"); exists && len(val) > 0 { + if val, exists := common.OptionalParam[[]interface{}](request.Params.Arguments, "messages"); exists && len(val) > 0 { for _, m := range val { if strMsg, ok := m.(string); ok { messages = append(messages, strMsg) @@ -107,32 +108,32 @@ func handleClientProduce(ctx context.Context, request mcp.CallToolRequest) (*mcp } numProduce := 1 - if val, exists := optionalParam[float64](request.Params.Arguments, "num-produce"); exists { + if val, exists := common.OptionalParam[float64](request.Params.Arguments, "num-produce"); exists { numProduce = int(val) } rate := 0.0 - if val, exists := optionalParam[float64](request.Params.Arguments, "rate"); exists { + if val, exists := common.OptionalParam[float64](request.Params.Arguments, "rate"); exists { rate = val } disableBatching := false - if val, exists := optionalParam[bool](request.Params.Arguments, "disable-batching"); exists { + if val, exists := common.OptionalParam[bool](request.Params.Arguments, "disable-batching"); exists { disableBatching = val } chunkingAllowed := false - if val, exists := optionalParam[bool](request.Params.Arguments, "chunking"); exists { + if val, exists := common.OptionalParam[bool](request.Params.Arguments, "chunking"); exists { chunkingAllowed = val } separator := "" - if val, exists := optionalParam[string](request.Params.Arguments, "separator"); exists && val != "" { + if val, exists := common.OptionalParam[string](request.Params.Arguments, "separator"); exists && val != "" { separator = val } properties := []string{} - if val, exists := optionalParam[[]interface{}](request.Params.Arguments, "properties"); exists && len(val) > 0 { + if val, exists := common.OptionalParam[[]interface{}](request.Params.Arguments, "properties"); exists && len(val) > 0 { for _, p := range val { if strProp, ok := p.(string); ok { properties = append(properties, strProp) @@ -141,7 +142,7 @@ func handleClientProduce(ctx context.Context, request mcp.CallToolRequest) (*mcp } key := "" - if val, exists := optionalParam[string](request.Params.Arguments, "key"); exists { + if val, exists := common.OptionalParam[string](request.Params.Arguments, "key"); exists { key = val } diff --git a/pkg/mcp/pulsar_functions_as_tools.go b/pkg/mcp/pulsar_functions_as_tools.go new file mode 100644 index 0000000..87d6fa5 --- /dev/null +++ b/pkg/mcp/pulsar_functions_as_tools.go @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package mcp + +import ( + "log" + "os" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/pftools" +) + +var ( + functionManagers = make(map[string]*pftools.PulsarFunctionManager) + functionManagersLock sync.RWMutex +) + +func StopAllPulsarFunctionManagers() { + functionManagersLock.Lock() + defer functionManagersLock.Unlock() + + for id, manager := range functionManagers { + log.Printf("Stopping Pulsar Function manager: %s", id) + manager.Stop() + delete(functionManagers, id) + } + + if len(functionManagers) > 0 { + time.Sleep(500 * time.Millisecond) + } + + log.Println("All Pulsar Function managers stopped") +} + +func PulsarFunctionManagedMcpTools(s *server.MCPServer, readOnly bool, features []string) { + if !slices.Contains(features, string(FeatureAll)) && + !slices.Contains(features, string(FeatureFunctionsAsTools)) && + !slices.Contains(features, string(FeatureStreamNativeCloud)) { + return + } + + options := pftools.DefaultManagerOptions() + + if pollIntervalStr := os.Getenv("FUNCTIONS_AS_TOOLS_POLL_INTERVAL"); pollIntervalStr != "" { + if seconds, err := strconv.Atoi(pollIntervalStr); err == nil && seconds > 0 { + options.PollInterval = time.Duration(seconds) * time.Second + log.Printf("Setting Pulsar Functions poll interval to %v", options.PollInterval) + } + } + + if timeoutStr := os.Getenv("FUNCTIONS_AS_TOOLS_TIMEOUT"); timeoutStr != "" { + if seconds, err := strconv.Atoi(timeoutStr); err == nil && seconds > 0 { + options.DefaultTimeout = time.Duration(seconds) * time.Second + log.Printf("Setting Pulsar Functions default timeout to %v", options.DefaultTimeout) + } + } + + if failureThresholdStr := os.Getenv("FUNCTIONS_AS_TOOLS_FAILURE_THRESHOLD"); failureThresholdStr != "" { + if threshold, err := strconv.Atoi(failureThresholdStr); err == nil && threshold > 0 { + options.FailureThreshold = threshold + log.Printf("Setting Pulsar Functions failure threshold to %d", options.FailureThreshold) + } + } + + if resetTimeoutStr := os.Getenv("FUNCTIONS_AS_TOOLS_RESET_TIMEOUT"); resetTimeoutStr != "" { + if seconds, err := strconv.Atoi(resetTimeoutStr); err == nil && seconds > 0 { + options.ResetTimeout = time.Duration(seconds) * time.Second + log.Printf("Setting Pulsar Functions reset timeout to %v", options.ResetTimeout) + } + } + + if tenantNamespacesStr := os.Getenv("FUNCTIONS_AS_TOOLS_TENANT_NAMESPACES"); tenantNamespacesStr != "" { + options.TenantNamespaces = strings.Split(tenantNamespacesStr, ",") + log.Printf("Setting Pulsar Functions tenant namespaces to %v", options.TenantNamespaces) + } + + if strictExportStr := os.Getenv("FUNCTIONS_AS_TOOLS_STRICT_EXPORT"); strictExportStr != "" { + options.StrictExport = strictExportStr == "true" + log.Printf("Setting Pulsar Functions strict export to %v", options.StrictExport) + } + + manager, err := pftools.NewPulsarFunctionManager(s, readOnly, options) + if err != nil { + log.Printf("Failed to create Pulsar Function manager: %v", err) + return + } + + manager.Start() + + managerID := "FUNCTIONS_AS_TOOLS_manager_" + strconv.FormatInt(time.Now().UnixNano(), 10) + functionManagersLock.Lock() + functionManagers[managerID] = manager + functionManagersLock.Unlock() + + log.Printf("Registered Pulsar Function Manager with ID: %s", managerID) + log.Printf("Pulsar Functions as MCP Tools service started. Functions will be dynamically converted to MCP tools.") +} diff --git a/pkg/mcp/streamnative_resources_log_tools.go b/pkg/mcp/streamnative_resources_log_tools.go index 3146681..1611748 100644 --- a/pkg/mcp/streamnative_resources_log_tools.go +++ b/pkg/mcp/streamnative_resources_log_tools.go @@ -31,6 +31,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/config" ) @@ -112,54 +113,54 @@ type LogContent struct { } func handleStreamNativeResourcesLog(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - snConfig := getOptions(ctx) + snConfig := common.GetOptions(ctx) instance, cluster, organization := GetMcpContext() if instance == "" || cluster == "" || organization == "" { return mcp.NewToolResultError("No context is set, please use `sncloud_context_use_cluster` to set the context first."), nil } // Extract required parameters with validation - component, err := requiredParam[string](request.Params.Arguments, "component") + component, err := common.RequiredParam[string](request.Params.Arguments, "component") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get component: %v", err)), nil } - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get name: %v", err)), nil } - tenant, hasTenant := optionalParam[string](request.Params.Arguments, "tenant") + tenant, hasTenant := common.OptionalParam[string](request.Params.Arguments, "tenant") if !hasTenant { tenant = "public" } - namespace, hasNamespace := optionalParam[string](request.Params.Arguments, "namespace") + namespace, hasNamespace := common.OptionalParam[string](request.Params.Arguments, "namespace") if !hasNamespace { namespace = "default" } - size, hasSize := optionalParam[string](request.Params.Arguments, "size") + size, hasSize := common.OptionalParam[string](request.Params.Arguments, "size") if !hasSize { size = "20" } - replicaID, hasreplicaID := optionalParam[int](request.Params.Arguments, "replica_id") + replicaID, hasreplicaID := common.OptionalParam[int](request.Params.Arguments, "replica_id") if !hasreplicaID { replicaID = -1 } - timestampStr, hasTimestamp := optionalParam[string](request.Params.Arguments, "timestamp") + timestampStr, hasTimestamp := common.OptionalParam[string](request.Params.Arguments, "timestamp") if !hasTimestamp { timestampStr = "" } - sinceStr, hasSince := optionalParam[string](request.Params.Arguments, "since") + sinceStr, hasSince := common.OptionalParam[string](request.Params.Arguments, "since") if !hasSince { sinceStr = "" } - previousContainer, hasPreviousContainer := optionalParam[bool](request.Params.Arguments, "previous_container") + previousContainer, hasPreviousContainer := common.OptionalParam[bool](request.Params.Arguments, "previous_container") if !hasPreviousContainer { previousContainer = false } diff --git a/pkg/mcp/streamnative_resources_tools.go b/pkg/mcp/streamnative_resources_tools.go index 9f73f58..7d09842 100644 --- a/pkg/mcp/streamnative_resources_tools.go +++ b/pkg/mcp/streamnative_resources_tools.go @@ -11,6 +11,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/streamnative-mcp-server/pkg/common" "github.com/streamnative/streamnative-mcp-server/pkg/config" sncloud "github.com/streamnative/streamnative-mcp-server/sdk/sdk-apiserver" ) @@ -73,20 +74,20 @@ type Metadata struct { // handleStreamNativeResourcesApply handles the streaming_cloud_resources_apply tool func handleStreamNativeResourcesApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get necessary parameters - snConfig := getOptions(ctx) + snConfig := common.GetOptions(ctx) organization := snConfig.Organization if organization == "" { return mcp.NewToolResultError("No organization is set. Please set the organization using the appropriate context tool."), nil } // Get YAML content - jsonContent, err := requiredParam[string](request.Params.Arguments, "json_content") + jsonContent, err := common.RequiredParam[string](request.Params.Arguments, "json_content") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get json_content: %v", err)), nil } // Get dry_run flag - dryRun, hasDryRun := optionalParam[bool](request.Params.Arguments, "dry_run") + dryRun, hasDryRun := common.OptionalParam[bool](request.Params.Arguments, "dry_run") if !hasDryRun { dryRun = false } @@ -323,15 +324,15 @@ func applyPulsarCluster(ctx context.Context, apiClient *sncloud.APIClient, jsonC // handleStreamNativeResourcesDelete handles the streaming_cloud_resources_delete tool func handleStreamNativeResourcesDelete(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - snConfig := getOptions(ctx) + snConfig := common.GetOptions(ctx) organization := snConfig.Organization - name, err := requiredParam[string](request.Params.Arguments, "name") + name, err := common.RequiredParam[string](request.Params.Arguments, "name") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get name: %v", err)), nil } - resourceType, err := requiredParam[string](request.Params.Arguments, "type") + resourceType, err := common.RequiredParam[string](request.Params.Arguments, "type") if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get type: %v", err)), nil } diff --git a/pkg/pftools/circuit_breaker.go b/pkg/pftools/circuit_breaker.go new file mode 100644 index 0000000..c2542ff --- /dev/null +++ b/pkg/pftools/circuit_breaker.go @@ -0,0 +1,163 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package pftools + +import ( + "fmt" + "time" +) + +// NewCircuitBreaker creates a new circuit breaker +func NewCircuitBreaker(threshold int, resetTimeout time.Duration) *CircuitBreaker { + return &CircuitBreaker{ + failureCount: 0, + failureThreshold: threshold, + resetTimeout: resetTimeout, + state: StateClosed, + } +} + +// RecordSuccess records a successful operation +func (cb *CircuitBreaker) RecordSuccess() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + // Reset the failure count + cb.failureCount = 0 + + // If we're half-open, close the circuit + if cb.state == StateHalfOpen { + cb.state = StateClosed + } +} + +// RecordFailure records a failed operation +func (cb *CircuitBreaker) RecordFailure() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + // Record the failure + cb.lastFailure = time.Now() + + // If we're already open, do nothing + if cb.state == StateOpen { + return + } + + // If we're half-open, open the circuit immediately + if cb.state == StateHalfOpen { + cb.state = StateOpen + return + } + + // Increment the failure count + cb.failureCount++ + + // If we've exceeded the threshold, open the circuit + if cb.failureCount >= cb.failureThreshold { + cb.state = StateOpen + } +} + +// AllowRequest determines if a request should be allowed +func (cb *CircuitBreaker) AllowRequest() bool { + cb.mutex.RLock() + defer cb.mutex.RUnlock() + + switch cb.state { + case StateClosed: + // Always allow if closed + return true + case StateOpen: + // If open, check if timeout has expired + if time.Since(cb.lastFailure) > cb.resetTimeout { + // Reset to half-open in a separate goroutine to avoid deadlock + go func() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + cb.state = StateHalfOpen + }() + // Allow this request as a test + return true + } + // Still open and timeout hasn't expired + return false + case StateHalfOpen: + // Allow one request in half-open state to test + return true + default: + return false + } +} + +// GetState returns the current state of the circuit breaker +func (cb *CircuitBreaker) GetState() CircuitState { + cb.mutex.RLock() + defer cb.mutex.RUnlock() + return cb.state +} + +// ForceOpen forces the circuit breaker to open +func (cb *CircuitBreaker) ForceOpen() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + cb.state = StateOpen + cb.lastFailure = time.Now() +} + +// ForceClose forces the circuit breaker to close +func (cb *CircuitBreaker) ForceClose() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + cb.state = StateClosed + cb.failureCount = 0 +} + +// Reset resets the circuit breaker to closed state +func (cb *CircuitBreaker) Reset() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + cb.state = StateClosed + cb.failureCount = 0 +} + +// GetStateString returns a string representation of the circuit breaker state +func (cb *CircuitBreaker) GetStateString() string { + cb.mutex.RLock() + defer cb.mutex.RUnlock() + + switch cb.state { + case StateOpen: + return "OPEN" + case StateHalfOpen: + return "HALF-OPEN" + case StateClosed: + return "CLOSED" + default: + return "UNKNOWN" + } +} + +// String returns a string representation of the circuit breaker +func (cb *CircuitBreaker) String() string { + cb.mutex.RLock() + defer cb.mutex.RUnlock() + + return fmt.Sprintf("CircuitBreaker{state=%s, failures=%d/%d, lastFailure=%v}", + cb.GetStateString(), cb.failureCount, cb.failureThreshold, cb.lastFailure) +} diff --git a/pkg/pftools/errors.go b/pkg/pftools/errors.go new file mode 100644 index 0000000..7fa0f72 --- /dev/null +++ b/pkg/pftools/errors.go @@ -0,0 +1,8 @@ +package pftools + +import "errors" + +var ( + ErrFunctionNotFound = errors.New("function not found") + ErrNotOurMessage = errors.New("not our message") +) diff --git a/pkg/pftools/invocation.go b/pkg/pftools/invocation.go new file mode 100644 index 0000000..fd8542e --- /dev/null +++ b/pkg/pftools/invocation.go @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package pftools + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/apache/pulsar-client-go/pulsar" + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/streamnative/streamnative-mcp-server/pkg/schema" +) + +// FunctionInvoker handles function invocation and result tracking +type FunctionInvoker struct { + client pulsar.Client + resultChannels map[string]chan FunctionResult + mutex sync.RWMutex + manager *PulsarFunctionManager +} + +// FunctionResult represents the result of a function invocation +type FunctionResult struct { + Data string + Error error +} + +// NewFunctionInvoker creates a new FunctionInvoker +func NewFunctionInvoker(manager *PulsarFunctionManager) *FunctionInvoker { + return &FunctionInvoker{ + client: manager.pulsarClient, + resultChannels: make(map[string]chan FunctionResult), + mutex: sync.RWMutex{}, + manager: manager, + } +} + +// InvokeFunctionAndWait sends a message to the function and waits for the result +func (fi *FunctionInvoker) InvokeFunctionAndWait(ctx context.Context, fnTool *FunctionTool, params map[string]interface{}) (*mcp.CallToolResult, error) { + schemaConverter, err := schema.ConverterFactory(fnTool.OutputSchema.Type) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to get schema converter: %v", err)), nil + } + + payload, err := schemaConverter.SerializeMCPRequestToPulsarPayload(params, fnTool.OutputSchema.PulsarSchemaInfo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to serialize payload: %v", err)), nil + } + + // Create a result channel for this request + resultChan := make(chan FunctionResult, 1) + + // Send message to input topic + msgID, err := fi.sendMessage(ctx, fnTool.InputTopic, payload) + if err != nil || msgID == "" { + return mcp.NewToolResultError(fmt.Sprintf("Failed to send message: %v", err)), nil + } + + fi.registerResultChannel(msgID, resultChan) + defer fi.unregisterResultChannel(msgID) + + // Set up consumer for output topic + err = fi.setupConsumer(ctx, fnTool.InputTopic, fnTool.OutputTopic, msgID, fnTool.OutputSchema) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to set up consumer: %v", err)), nil + } + + // Wait for result or timeout + select { + case result := <-resultChan: + if result.Error != nil { + return mcp.NewToolResultError(fmt.Sprintf("Function execution failed: %v", result.Error)), nil + } + + return mcp.NewToolResultText(result.Data), nil + case <-ctx.Done(): + return mcp.NewToolResultError(fmt.Sprintf("Function invocation timed out after %v", ctx.Value("timeout"))), nil + } +} + +// setupConsumer creates a consumer for the output topic +func (fi *FunctionInvoker) setupConsumer(ctx context.Context, inputTopic, outputTopic, messageID string, schema *SchemaInfo) error { + consumerOptions := pulsar.ConsumerOptions{ + Topic: outputTopic, + SubscriptionName: fmt.Sprintf("mcp-tool-consumer-%s", messageID), + Type: pulsar.Exclusive, + SubscriptionInitialPosition: pulsar.SubscriptionPositionEarliest, + } + + // Create the consumer + consumer, err := fi.client.Subscribe(consumerOptions) + if err != nil { + return fmt.Errorf("failed to create consumer: %w", err) + } + + // Start goroutine to receive messages + go func() { + // Ensure we close the consumer when done + defer func() { + consumer.Close() + }() + + // Get messages with a timeout + for { + select { + case <-ctx.Done(): + // Context canceled, exit + return + default: + // Set a short timeout for Receive to make it responsive to context cancellation + receiveCtx, cancel := context.WithTimeout(ctx, 1*time.Second) + msg, err := consumer.Receive(receiveCtx) + defer cancel() + + if err != nil { + if err == context.DeadlineExceeded || err == context.Canceled { + // Timeout or cancellation, but keep trying unless the parent context is done + continue + } + + continue + } + + // Process the message + err = fi.processMessage(inputTopic, msg, messageID, schema) + if err != nil { + if err == ErrNotOurMessage { + _ = consumer.Ack(msg) + continue + } + continue + } + + // Acknowledge the message + _ = consumer.Ack(msg) + + // Stop after processing one message + return + } + } + }() + + return nil +} + +// sendMessage sends a message to the input topic +func (fi *FunctionInvoker) sendMessage(ctx context.Context, inputTopic string, payload []byte) (string, error) { + producer, err := fi.manager.GetProducer(inputTopic) + if err != nil { + return "", fmt.Errorf("failed to get producer for topic %s: %w", inputTopic, err) + } + + // Send the message with properties + msgID, err := producer.Send(ctx, &pulsar.ProducerMessage{ + Payload: payload, + }) + + if err != nil { + return "", fmt.Errorf("failed to send message: %w", err) + } + + return msgID.String(), nil +} + +// processMessage processes a message received from the output topic +func (fi *FunctionInvoker) processMessage(inputTopic string, msg pulsar.Message, messageID string, schema *SchemaInfo) error { + // Check if the message has our correlation ID + correlationIDbytes, err := base64.StdEncoding.DecodeString(msg.Properties()["__pfn_input_msg_id__"]) + if err != nil { + return fmt.Errorf("failed to decode correlation ID: %w", err) + } + correlationID, err := pulsar.DeserializeMessageID(correlationIDbytes) + if err != nil { + return fmt.Errorf("failed to deserialize correlation ID: %w", err) + } + correlationInputTopic := msg.Properties()["__pfn_input_topic__"] + + if !isCorrelationInputTopic(correlationInputTopic, inputTopic) { + // Not our message, ignore + return ErrNotOurMessage + } + if correlationID.String() != messageID { + // Not our message, ignore + return ErrNotOurMessage + } + + // Get the result channel + fi.mutex.RLock() + resultChan, exists := fi.resultChannels[messageID] + fi.mutex.RUnlock() + + if !exists { + return fmt.Errorf("result channel not found for message ID: %s", messageID) + } + + switch schema.Type { + case "STRING": + result := string(msg.Payload()) + // Send the result to the channel + resultChan <- FunctionResult{ + Data: result, + Error: nil, + } + case "JSON": + var result map[string]interface{} + err := json.Unmarshal(msg.Payload(), &result) + if err != nil { + return fmt.Errorf("failed to unmarshal message payload: %w", err) + } + resultString, err := json.Marshal(result) + if err != nil { + return fmt.Errorf("failed to marshal result to JSON: %w", err) + } + resultChan <- FunctionResult{ + Data: string(resultString), + Error: nil, + } + default: + return fmt.Errorf("unsupported schema type: %s", schema.Type) + } + + return nil +} + +// registerResultChannel registers a result channel for a message ID +func (fi *FunctionInvoker) registerResultChannel(messageID string, resultChan chan FunctionResult) { + fi.mutex.Lock() + defer fi.mutex.Unlock() + fi.resultChannels[messageID] = resultChan +} + +// unregisterResultChannel unregisters a result channel for a message ID +func (fi *FunctionInvoker) unregisterResultChannel(messageID string) { + fi.mutex.Lock() + defer fi.mutex.Unlock() + delete(fi.resultChannels, messageID) +} + +func isCorrelationInputTopic(correlationInputTopic string, inputTopic string) bool { + // remove the partition index from the input topic + if strings.Contains(correlationInputTopic, cliutils.PARTITIONEDTOPICSUFFIX) { + correlationInputTopic = strings.Split(correlationInputTopic, cliutils.PARTITIONEDTOPICSUFFIX)[0] + } + + // remove the partition index from the input topic + if strings.Contains(inputTopic, cliutils.PARTITIONEDTOPICSUFFIX) { + inputTopic = strings.Split(inputTopic, cliutils.PARTITIONEDTOPICSUFFIX)[0] + } + + correlationInputTopicName, err := cliutils.GetTopicName(correlationInputTopic) + if err != nil { + return false + } + inputTopicName, err := cliutils.GetTopicName(inputTopic) + if err != nil { + return false + } + + return correlationInputTopicName.String() == inputTopicName.String() +} diff --git a/pkg/pftools/manager.go b/pkg/pftools/manager.go new file mode 100644 index 0000000..d69470b --- /dev/null +++ b/pkg/pftools/manager.go @@ -0,0 +1,528 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package pftools + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + "sync" + "time" + + "github.com/apache/pulsar-client-go/pulsar" + "github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config" + "github.com/apache/pulsar-client-go/pulsaradmin/pkg/rest" + "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/google/go-cmp/cmp" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/pulsarctl/pkg/cmdutils" + pulsarutils "github.com/streamnative/streamnative-mcp-server/pkg/pulsar" + "github.com/streamnative/streamnative-mcp-server/pkg/schema" +) + +const ( + CustomRuntimeOptionsEnvMcpToolNameKey = "MCP_TOOL_NAME" + CustomRuntimeOptionsEnvMcpToolDescriptionKey = "MCP_TOOL_DESCRIPTION" +) + +var DefaultStringSchemaInfo = &SchemaInfo{ + Type: "STRING", + Definition: map[string]interface{}{ + "type": "string", + }, + PulsarSchemaInfo: &cliutils.SchemaInfo{ + Type: "STRING", + }, +} + +// NewPulsarFunctionManager creates a new PulsarFunctionManager +func NewPulsarFunctionManager(mcpServer *server.MCPServer, readOnly bool, options *ManagerOptions) (*PulsarFunctionManager, error) { + // Get Pulsar client and admin client + pulsarClient, err := pulsarutils.GetPulsarClient() + if err != nil { + return nil, fmt.Errorf("failed to get Pulsar client: %w", err) + } + + adminClient := cmdutils.NewPulsarClientWithAPIVersion(config.V3) + v2adminClient := cmdutils.NewPulsarClientWithAPIVersion(config.V2) + if options == nil { + options = DefaultManagerOptions() + } + + // Create the manager + manager := &PulsarFunctionManager{ + adminClient: adminClient, + v2adminClient: v2adminClient, + pulsarClient: pulsarClient, + fnToToolMap: make(map[string]*FunctionTool), + mutex: sync.RWMutex{}, + producerCache: make(map[string]pulsar.Producer), + producerMutex: sync.RWMutex{}, + pollInterval: options.PollInterval, + stopCh: make(chan struct{}), + callInProgressMap: make(map[string]context.CancelFunc), + mcpServer: mcpServer, + readOnly: readOnly, + defaultTimeout: options.DefaultTimeout, + circuitBreakers: make(map[string]*CircuitBreaker), + tenantNamespaces: options.TenantNamespaces, + strictExport: options.StrictExport, + } + + return manager, nil +} + +// Start starts polling for functions +func (m *PulsarFunctionManager) Start() { + go m.pollFunctions() +} + +// Stop stops polling for functions +func (m *PulsarFunctionManager) Stop() { + close(m.stopCh) + + m.producerMutex.Lock() + defer m.producerMutex.Unlock() + for topic, producer := range m.producerCache { + log.Printf("Closing producer for topic: %s", topic) + producer.Close() + } + m.producerCache = make(map[string]pulsar.Producer) + log.Println("All cached producers closed and cache cleared.") +} + +// pollFunctions polls for functions periodically +func (m *PulsarFunctionManager) pollFunctions() { + ticker := time.NewTicker(m.pollInterval) + defer ticker.Stop() + + // Initial update + m.updateFunctions() + + for { + select { + case <-ticker.C: + m.updateFunctions() + case <-m.stopCh: + return + } + } +} + +// updateFunctions updates the function tool mappings +func (m *PulsarFunctionManager) updateFunctions() { + // Get all functions + functions, err := m.getFunctionsList() + if err != nil { + log.Printf("Failed to get functions list: %v", err) + return + } + + // Track which functions we've seen + seenFunctions := make(map[string]bool) + + // Add or update functions + for _, fn := range functions { + fullName := getFunctionFullName(fn.Tenant, fn.Namespace, fn.Name) + seenFunctions[fullName] = true + + // Check if we already have this function + m.mutex.RLock() + _, exists := m.fnToToolMap[fullName] + m.mutex.RUnlock() + + changed := false + if exists { + // Check if the function has changed + existingFn, exists := m.fnToToolMap[fullName] + if exists { + if !cmp.Equal(*existingFn.Function, *fn) { + changed = true + } + if !existingFn.SchemaFetchSuccess { + changed = true + } + } + if !changed { + continue + } + } + + // Convert function to tool + fnTool, err := m.convertFunctionToTool(fn) + if err != nil || !fnTool.SchemaFetchSuccess { + if err != nil { + log.Printf("Failed to convert function %s to tool: %v", fullName, err) + } else { + log.Printf("Failed to fetch schema for function %s, retry later...", fullName) + } + continue + } + + if changed { + m.mcpServer.DeleteTools(fnTool.Tool.Name) + } + m.mcpServer.AddTool(fnTool.Tool, m.handleToolCall(fnTool)) + + // Add function to map + m.mutex.Lock() + m.fnToToolMap[fullName] = fnTool + m.mutex.Unlock() + + if changed { + log.Printf("Updated function %s as MCP tool [%s]", fullName, fnTool.Tool.Name) + } else { + log.Printf("Added function %s as MCP tool [%s]", fullName, fnTool.Tool.Name) + } + } + + // Remove deleted functions + m.mutex.Lock() + for fullName, fnTool := range m.fnToToolMap { + if !seenFunctions[fullName] { + m.mcpServer.DeleteTools(fnTool.Tool.Name) + delete(m.fnToToolMap, fullName) + log.Printf("Removed function %s from MCP tools [%s]", fullName, fnTool.Tool.Name) + } + } + m.mutex.Unlock() +} + +// getFunctionsList retrieves all functions from the specified tenants/namespaces +func (m *PulsarFunctionManager) getFunctionsList() ([]*utils.FunctionConfig, error) { + var allFunctions []*utils.FunctionConfig + var runningFunctions []*utils.FunctionConfig + + if len(m.tenantNamespaces) == 0 { + // This is StreamNative supported way to get all functions when using Function Mesh + functions, err := m.getFunctionsInNamespace("tenant@all", "namespace@all") + if err != nil { + return nil, fmt.Errorf("failed to get functions in all namespaces: %w", err) + } + + allFunctions = append(allFunctions, functions...) + } else { + // Get functions from specified tenant/namespaces + for _, tn := range m.tenantNamespaces { + parts := strings.Split(tn, "/") + if len(parts) != 2 { + log.Printf("Invalid tenant/namespace format: %s", tn) + continue + } + + tenant := parts[0] + namespace := parts[1] + + functions, err := m.getFunctionsInNamespace(tenant, namespace) + if err != nil { + log.Printf("Failed to get functions in namespace %s/%s: %v", tenant, namespace, err) + continue + } + + allFunctions = append(allFunctions, functions...) + } + } + + for _, fn := range allFunctions { + if m.strictExport && + !strings.Contains(fn.CustomRuntimeOptions, CustomRuntimeOptionsEnvMcpToolNameKey) && + !strings.Contains(fn.CustomRuntimeOptions, CustomRuntimeOptionsEnvMcpToolDescriptionKey) { + continue + } + status, err := m.adminClient.Functions().GetFunctionStatus(fn.Tenant, fn.Namespace, fn.Name) + if err != nil { + continue + } + if status.NumRunning <= 0 { + continue + } + running := false + for _, instance := range status.Instances { + if instance.Status.Err != "" { + continue + } + if instance.Status.Running { + running = true + break + } + } + if !running { + continue + } + runningFunctions = append(runningFunctions, fn) + } + + return runningFunctions, nil +} + +// getFunctionsInNamespace retrieves all functions in a namespace +func (m *PulsarFunctionManager) getFunctionsInNamespace(tenant, namespace string) ([]*utils.FunctionConfig, error) { + var functions []*utils.FunctionConfig + + // Get function names + functionNames, err := m.adminClient.Functions().GetFunctions(tenant, namespace) + if err != nil { + return nil, fmt.Errorf("failed to get function names: %w", err) + } + + // Get details for each function + for _, name := range functionNames { + parts := strings.Split(name, "/") + if len(parts) != 3 { + log.Printf("Invalid function name format: %s", name) + continue + } + + function, err := m.adminClient.Functions().GetFunction(parts[0], parts[1], parts[2]) + if err != nil { + log.Printf("Failed to get function details for %s/%s/%s: %v", parts[0], parts[1], parts[2], err) + continue + } + + functions = append(functions, &function) + } + + return functions, nil +} + +// convertFunctionToTool converts a Pulsar Function to an MCP Tool +func (m *PulsarFunctionManager) convertFunctionToTool(fn *utils.FunctionConfig) (*FunctionTool, error) { + schemaFetchSuccess := true + // Determine input and output topics + if len(fn.InputSpecs) == 0 { + return nil, fmt.Errorf("function has no input topics") + } + + var inputTopic string + // Get the first input topic + for topic := range fn.InputSpecs { + inputTopic = topic + break + } + if inputTopic == "" { + return nil, fmt.Errorf("function has no input topics") + } + + // Get schema for input topic + inputSchema, err := GetSchemaFromTopic(m.v2adminClient, inputTopic) + if err != nil { + // Continue with a default schema + inputSchema = DefaultStringSchemaInfo + if restError, ok := err.(rest.Error); ok { + if restError.Code != 404 { + log.Printf("Failed to get schema for input topic %s: %v", inputTopic, err) + schemaFetchSuccess = false + } + } + } + + // Get output topic and schema + outputTopic := fn.Output + var outputSchema *SchemaInfo + if outputTopic != "" { + outputSchema, err = GetSchemaFromTopic(m.v2adminClient, outputTopic) + if err != nil { + // Continue with a default schema + outputSchema = DefaultStringSchemaInfo + if restError, ok := err.(rest.Error); ok { + if restError.Code != 404 { + log.Printf("Failed to get schema for output topic %s: %v", outputTopic, err) + schemaFetchSuccess = false + } + } + } + } + + toolName := retrieveToolName(fn) + // Replace non-alphanumeric characters + toolName = strings.ReplaceAll(toolName, "-", "_") + toolName = strings.ReplaceAll(toolName, ".", "_") + + // Create description + description := retrieveToolDescription(fn) + + schemaConverter, err := schema.ConverterFactory(inputSchema.Type) + if err != nil { + return nil, fmt.Errorf("failed to create schema converter: %w", err) + } + + toolInputSchemaProperties, err := schemaConverter.ToMCPToolInputSchemaProperties(inputSchema.PulsarSchemaInfo) + if err != nil { + return nil, fmt.Errorf("failed to convert input schema to MCP tool input schema properties: %w", err) + } + + toolInputSchemaProperties = append(toolInputSchemaProperties, mcp.WithDescription(description)) + + // Create the tool + tool := mcp.NewTool(toolName, + toolInputSchemaProperties..., + ) + + // Create circuit breaker for this function + circuitBreaker := NewCircuitBreaker(5, 60*time.Second) + + // Store in map + m.mutex.Lock() + m.circuitBreakers[toolName] = circuitBreaker + m.mutex.Unlock() + + return &FunctionTool{ + Name: toolName, + Function: fn, + InputSchema: inputSchema, + OutputSchema: outputSchema, + InputTopic: inputTopic, + OutputTopic: outputTopic, + Tool: tool, + SchemaFetchSuccess: schemaFetchSuccess, + }, nil +} + +// handleToolCall returns a handler function for a specific function tool +func (m *PulsarFunctionManager) handleToolCall(fnTool *FunctionTool) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Get the circuit breaker + m.mutex.RLock() + cb, exists := m.circuitBreakers[fnTool.Name] + m.mutex.RUnlock() + + if !exists { + cb = NewCircuitBreaker(5, 60*time.Second) + m.mutex.Lock() + m.circuitBreakers[fnTool.Name] = cb + m.mutex.Unlock() + } + + // Check if the circuit breaker allows the request + if !cb.AllowRequest() { + return mcp.NewToolResultError(fmt.Sprintf("Circuit breaker is open for function %s. Too many failures, please try again later.", fnTool.Name)), nil + } + + // Create function invoker + invoker := NewFunctionInvoker(m) + + // Create context with timeout + timeoutCtx, cancel := context.WithTimeout(ctx, m.defaultTimeout) + defer cancel() + + // Register call + m.mutex.Lock() + m.callInProgressMap[fnTool.Name] = cancel + m.mutex.Unlock() + defer func() { + m.mutex.Lock() + delete(m.callInProgressMap, fnTool.Name) + m.mutex.Unlock() + }() + + // Invoke function and wait for result + result, err := invoker.InvokeFunctionAndWait(timeoutCtx, fnTool, request.Params.Arguments) + + // Record success or failure + if err != nil { + cb.RecordFailure() + } else { + cb.RecordSuccess() + } + + return result, err + } +} + +// getFunctionFullName returns the full name of a function +func getFunctionFullName(tenant, namespace, name string) string { + return fmt.Sprintf("%s/%s/%s", tenant, namespace, name) +} + +// retrieveToolName retrieves the tool name from a function +func retrieveToolName(fn *utils.FunctionConfig) string { + if fn == nil { + return "" + } + fallbackName := fmt.Sprintf("pulsar_function_%s_%s_%s", fn.Tenant, fn.Namespace, fn.Name) + if fn.CustomRuntimeOptions != "" { + option := make(map[string]interface{}) + if err := json.Unmarshal([]byte(fn.CustomRuntimeOptions), &option); err != nil { + return fallbackName + } + if envs, ok := option["env"]; ok { + if envsMap, ok := envs.(map[string]interface{}); ok { + if name, ok := envsMap[CustomRuntimeOptionsEnvMcpToolNameKey]; ok { + return name.(string) + } + } + } + } + return fallbackName +} + +// retrieveToolDescription retrieves the tool description from a function +func retrieveToolDescription(fn *utils.FunctionConfig) string { + if fn == nil { + return "" + } + fallbackDescription := fmt.Sprintf("Linked to Pulsar Function: %s/%s/%s", fn.Tenant, fn.Namespace, fn.Name) + if fn.CustomRuntimeOptions != "" { + option := make(map[string]interface{}) + if err := json.Unmarshal([]byte(fn.CustomRuntimeOptions), &option); err != nil { + return fallbackDescription + } + if envs, ok := option["env"]; ok { + if envsMap, ok := envs.(map[string]interface{}); ok { + if description, ok := envsMap[CustomRuntimeOptionsEnvMcpToolDescriptionKey]; ok { + return description.(string) + " " + fallbackDescription + } + } + } + } + return fallbackDescription +} + +// GetProducer retrieves a producer from the cache or creates a new one if not found. +func (m *PulsarFunctionManager) GetProducer(topic string) (pulsar.Producer, error) { + m.producerMutex.RLock() + producer, found := m.producerCache[topic] + m.producerMutex.RUnlock() + + if found { + return producer, nil + } + + m.producerMutex.Lock() + defer m.producerMutex.Unlock() + + producer, found = m.producerCache[topic] + if found { + return producer, nil + } + + newProducer, err := m.pulsarClient.CreateProducer(pulsar.ProducerOptions{ + Topic: topic, + }) + if err != nil { + return nil, fmt.Errorf("failed to create producer for topic %s: %w", topic, err) + } + + m.producerCache[topic] = newProducer + log.Printf("Created and cached producer for topic: %s", topic) + return newProducer, nil +} diff --git a/pkg/pftools/schema.go b/pkg/pftools/schema.go new file mode 100644 index 0000000..2c8c725 --- /dev/null +++ b/pkg/pftools/schema.go @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package pftools + +import ( + "encoding/json" + "fmt" + + "github.com/apache/pulsar-client-go/pulsar" + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/streamnative/pulsarctl/pkg/cmdutils" +) + +var DefaultStringSchema = &mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "payload": map[string]interface{}{ + "type": "string", + "description": "The payload of the message, in plain text format", + }, + }, +} + +// GetSchemaFromTopic retrieves schema information from a topic +func GetSchemaFromTopic(admin cmdutils.Client, topic string) (*SchemaInfo, error) { + if admin == nil { + return nil, fmt.Errorf("failed to get schema from topic '%s': mcp server is not initialized", topic) + } + topicName, err := cliutils.GetTopicName(topic) + if err != nil { + return nil, fmt.Errorf("failed to get topic name from topic '%s': %w", topic, err) + } + + // Get schema info from topic + si, err := admin.Schemas().GetSchemaInfo(topicName.String()) + if err != nil { + return nil, fmt.Errorf("failed to get schema for topic '%s': %w", topicName.String(), err) + } + + if si == nil { + return nil, fmt.Errorf("no schema found for topic '%s'", topic) + } + + // Parse schema definition + var definition map[string]interface{} + if si.Schema != nil { + if err := json.Unmarshal(si.Schema, &definition); err != nil { + // If it's not a valid JSON, just create a string type schema + definition = map[string]interface{}{ + "type": "string", + } + } + } else { + // Default to string type if no schema is provided + definition = map[string]interface{}{ + "type": "string", + } + } + + return &SchemaInfo{ + Type: string(si.Type), + Definition: definition, + PulsarSchemaInfo: si, + }, nil +} + +// ConvertSchemaToToolInput converts a schema to MCP tool input schema +func ConvertSchemaToToolInput(schemaInfo *SchemaInfo) (*mcp.ToolInputSchema, error) { + if schemaInfo == nil { + // Default to object with any fields if no schema is provided + return DefaultStringSchema, nil + } + + // Handle different schema types + switch schemaInfo.Type { + case "JSON": + return convertComplexSchemaToToolInput(schemaInfo) + case "AVRO", "PROTOBUF", "PROTOBUF_NATIVE": + return nil, fmt.Errorf("AVRO, PROTOBUF and PROTOBUF_NATIVE schema is not supported") + default: + return DefaultStringSchema, nil + } +} + +// convertComplexSchemaToToolInput handles conversion of complex schema types +func convertComplexSchemaToToolInput(schemaInfo *SchemaInfo) (*mcp.ToolInputSchema, error) { + if schemaInfo.Definition == nil { + return DefaultStringSchema, nil + } + + fields, hasFields := schemaInfo.Definition["fields"].([]any) + if !hasFields { + return nil, fmt.Errorf("failed to get fields from schema definition") + } + + definitionString, err := json.Marshal(fields) + if err != nil { + return nil, fmt.Errorf("failed to marshal schema definition: %w", err) + } + + // For JSON schemas, use the definition directly + return &mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "payload": map[string]interface{}{ + "type": "string", + "description": "The payload of the message, in JSON String format, the schema of the payload in AVRO format is: " + string(definitionString), + }, + }, + }, nil +} + +func GetPulsarTypeSchema(schemaInfo *SchemaInfo) (pulsar.Schema, error) { + if schemaInfo == nil || schemaInfo.Definition == nil { + return pulsar.NewStringSchema(nil), nil + } + + switch schemaInfo.Type { + case "JSON": + schemaData, err := json.Marshal(schemaInfo.Definition) + if err != nil { + return nil, fmt.Errorf("failed to marshal schema definition: %w", err) + } + return pulsar.NewJSONSchema(string(schemaData), nil), nil + case "AVRO", "PROTOBUF", "PROTOBUF_NATIVE": + return nil, fmt.Errorf("AVRO, PROTOBUF and PROTOBUF_NATIVE schema is not supported") + default: + return pulsar.NewStringSchema(nil), nil + } +} diff --git a/pkg/pftools/types.go b/pkg/pftools/types.go new file mode 100644 index 0000000..d856a75 --- /dev/null +++ b/pkg/pftools/types.go @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package pftools + +import ( + "context" + "sync" + "time" + + "github.com/apache/pulsar-client-go/pulsar" + "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/streamnative/pulsarctl/pkg/cmdutils" +) + +// PulsarFunctionManager manages the lifecycle of Pulsar Functions as MCP tools +type PulsarFunctionManager struct { + adminClient cmdutils.Client + v2adminClient cmdutils.Client + pulsarClient pulsar.Client + fnToToolMap map[string]*FunctionTool + mutex sync.RWMutex + producerCache map[string]pulsar.Producer + producerMutex sync.RWMutex + pollInterval time.Duration + stopCh chan struct{} + callInProgressMap map[string]context.CancelFunc + mcpServer *server.MCPServer + readOnly bool + defaultTimeout time.Duration + circuitBreakers map[string]*CircuitBreaker + tenantNamespaces []string + strictExport bool +} + +type FunctionTool struct { + Name string + Function *utils.FunctionConfig + InputSchema *SchemaInfo + OutputSchema *SchemaInfo + InputTopic string + OutputTopic string + Tool mcp.Tool + SchemaFetchSuccess bool +} + +type SchemaInfo struct { + Type string + Definition map[string]interface{} + PulsarSchemaInfo *utils.SchemaInfo +} + +type CircuitBreaker struct { + failureCount int + failureThreshold int + resetTimeout time.Duration + lastFailure time.Time + state CircuitState + mutex sync.RWMutex +} + +type CircuitState int + +const ( + StateOpen CircuitState = iota + StateHalfOpen + StateClosed +) + +type ManagerOptions struct { + PollInterval time.Duration + DefaultTimeout time.Duration + FailureThreshold int + ResetTimeout time.Duration + TenantNamespaces []string + StrictExport bool +} + +func DefaultManagerOptions() *ManagerOptions { + return &ManagerOptions{ + PollInterval: 30 * time.Second, + DefaultTimeout: 10 * time.Second, + FailureThreshold: 5, + ResetTimeout: 60 * time.Second, + TenantNamespaces: []string{}, + StrictExport: false, + } +} diff --git a/pkg/schema/avro.go b/pkg/schema/avro.go new file mode 100644 index 0000000..f407639 --- /dev/null +++ b/pkg/schema/avro.go @@ -0,0 +1,38 @@ +package schema + +import ( + "fmt" + // "reflect" // No longer needed here + + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" +) + +type AvroConverter struct { + BaseConverter +} + +func NewAvroConverter() *AvroConverter { + return &AvroConverter{} +} + +func (c *AvroConverter) ToMCPToolInputSchemaProperties(schemaInfo *cliutils.SchemaInfo) ([]mcp.ToolOption, error) { + if schemaInfo.Type != "AVRO" { + return nil, fmt.Errorf("expected AVRO schema, got %s", schemaInfo.Type) + } + return processAvroSchemaStringToMCPToolInput(string(schemaInfo.Schema)) +} + +func (c *AvroConverter) SerializeMCPRequestToPulsarPayload(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) ([]byte, error) { + if err := c.ValidateArguments(arguments, targetPulsarSchemaInfo); err != nil { + return nil, fmt.Errorf("arguments validation failed: %w", err) + } + return serializeArgumentsToAvroBinary(arguments, string(targetPulsarSchemaInfo.Schema)) +} + +func (c *AvroConverter) ValidateArguments(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) error { + if targetPulsarSchemaInfo.Type != "AVRO" { + return fmt.Errorf("expected AVRO schema for validation, got %s", targetPulsarSchemaInfo.Type) + } + return validateArgumentsAgainstAvroSchemaString(arguments, string(targetPulsarSchemaInfo.Schema)) +} diff --git a/pkg/schema/avro_core.go b/pkg/schema/avro_core.go new file mode 100644 index 0000000..8a8dc2c --- /dev/null +++ b/pkg/schema/avro_core.go @@ -0,0 +1,731 @@ +package schema + +import ( + "encoding/base64" + "fmt" + "reflect" + "strings" + + "github.com/hamba/avro/v2" + "github.com/mark3labs/mcp-go/mcp" +) + +// processAvroSchemaStringToMCPToolInput takes an AVRO schema string, parses it, +// and converts it to MCP tool input schema properties. +func processAvroSchemaStringToMCPToolInput(avroSchemaString string) ([]mcp.ToolOption, error) { + schema, err := avro.Parse(avroSchemaString) + if err != nil { + return nil, fmt.Errorf("failed to parse AVRO schema: %w", err) + } + + recordSchema, ok := schema.(*avro.RecordSchema) + if !ok { + // If it's not a record, perhaps it's a simpler type that can't be directly mapped to tool options, + // or we need a different handling strategy. For now, strict record check. + return nil, fmt.Errorf("expected AVRO record schema at the top level, got %s", reflect.TypeOf(schema).String()) + } + + var opts []mcp.ToolOption + for _, field := range recordSchema.Fields() { + fieldOption, err := avroFieldToMcpOption(field) + if err != nil { + return nil, fmt.Errorf("failed to convert field '%s': %w", field.Name(), err) + } + opts = append(opts, fieldOption) + } + return opts, nil +} + +// avroFieldToMcpOption converts a single AVRO field to an mcp.ToolOption. +func avroFieldToMcpOption(field *avro.Field) (mcp.ToolOption, error) { + fieldType := field.Type() + fieldName := field.Name() + + var description string + if field.Doc() != "" { + description = field.Doc() + } else { + description = fmt.Sprintf("%s (type: %s)", fieldName, strings.ReplaceAll(fieldType.String(), "\"", "")) // Default description + } + + isRequired := true + var underlyingTypeForDefault avro.Schema = fieldType // Used to check default value against non-union type + + if unionSchema, ok := fieldType.(*avro.UnionSchema); ok { + isNullAble := false + var nonNullTypes []avro.Schema + for _, t := range unionSchema.Types() { + if t.Type() == avro.Null { + isNullAble = true + } else { + nonNullTypes = append(nonNullTypes, t) + } + } + isRequired = !isNullAble + + // If it's a nullable union with one other type (e.g., ["null", "string"]), + // treat it as that other type for default value and MCP type mapping. + //nolint:gocritic // This is a valid use of len(nonNullTypes) == 1 + if isNullAble && len(nonNullTypes) == 1 { + underlyingTypeForDefault = nonNullTypes[0] + } else if len(nonNullTypes) == 1 { + // Not nullable, but still a union with one type (should ideally not happen, but handle) + underlyingTypeForDefault = nonNullTypes[0] + } else if len(nonNullTypes) > 1 { + // Complex union (e.g., ["string", "int"]), for MCP, describe as string and mention union nature. + // Default values for complex unions are tricky with current MCP options. + // MCP schema might need to be a string with a description of the union. + props := []mcp.PropertyOption{mcp.Description(description + " (union type: " + strings.ReplaceAll(fieldType.String(), "\"", "") + ")")} + if isRequired { + props = append(props, mcp.Required()) + } + // Default value for complex union is not straightforward to map to a single MCP type's default. + // We will skip setting mcp.Default... for complex unions for now. + return mcp.WithString(fieldName, props...), nil + } + // If only "null" type was in union, or empty nonNullTypes (invalid schema), this will be caught by later type switch. + } + + var prop []mcp.PropertyOption + if isRequired { + prop = append(prop, mcp.Required()) + } + prop = append(prop, mcp.Description(description)) + + var opt mcp.ToolOption + + // Use underlyingTypeForDefault for determining MCP type and handling default values + // This handles cases like ["null", "string"] by treating it as "string" for MCP mapping. + effectiveType := underlyingTypeForDefault.Type() + + switch effectiveType { + case avro.String: + if field.HasDefault() { + if defaultVal, ok := field.Default().(string); ok { + prop = append(prop, mcp.DefaultString(defaultVal)) + } + } + opt = mcp.WithString(fieldName, prop...) + case avro.Int, avro.Long: // MCP 'number' can represent both + if field.HasDefault() { + // Avro library parses numeric defaults as float64 for int/long/float/double from JSON representation + if defaultVal, ok := field.Default().(float64); ok { + prop = append(prop, mcp.DefaultNumber(defaultVal)) + } else if defaultIntVal, ok := field.Default().(int); ok { // direct int + prop = append(prop, mcp.DefaultNumber(float64(defaultIntVal))) + } else if defaultInt32Val, ok := field.Default().(int32); ok { + prop = append(prop, mcp.DefaultNumber(float64(defaultInt32Val))) + } else if defaultInt64Val, ok := field.Default().(int64); ok { + prop = append(prop, mcp.DefaultNumber(float64(defaultInt64Val))) + } + } + opt = mcp.WithNumber(fieldName, prop...) + case avro.Float, avro.Double: // MCP 'number' can represent both + if field.HasDefault() { + if defaultVal, ok := field.Default().(float64); ok { + prop = append(prop, mcp.DefaultNumber(defaultVal)) + } + } + opt = mcp.WithNumber(fieldName, prop...) + case avro.Boolean: + if field.HasDefault() { + if defaultVal, ok := field.Default().(bool); ok { + prop = append(prop, mcp.DefaultBool(defaultVal)) + } + } + opt = mcp.WithBoolean(fieldName, prop...) + case avro.Bytes, avro.Fixed: + if field.HasDefault() { + if defaultVal, ok := field.Default().(string); ok { + prop = append(prop, mcp.DefaultString(defaultVal)) + } else if defaultBytes, ok := field.Default().([]byte); ok { + prop = append(prop, mcp.DefaultString(string(defaultBytes))) // Or base64 + } + } + // For Fixed type, add size information to description + if fixedSchema, ok := underlyingTypeForDefault.(*avro.FixedSchema); ok { + description += fmt.Sprintf(" (fixed size: %d bytes)", fixedSchema.Size()) + prop[0] = mcp.Description(description) // Update description in prop + } + opt = mcp.WithString(fieldName, prop...) // Always use WithString for Bytes/Fixed in MCP options + case avro.Array: + arraySchema, _ := underlyingTypeForDefault.(*avro.ArraySchema) // Safe due to switch + itemsDef, err := avroSchemaDefinitionToMcpProperties("item", arraySchema.Items(), "Array items") + if err != nil { + return nil, fmt.Errorf("failed to convert array items for field '%s': %w", fieldName, err) + } + prop = append(prop, mcp.Items(itemsDef)) + // Default for array is not directly supported by mcp.DefaultArray like mcp.DefaultString etc. + // It would require converting []any to a string representation or specific handling. + opt = mcp.WithArray(fieldName, prop...) + case avro.Map: + mapSchema, _ := underlyingTypeForDefault.(*avro.MapSchema) // Safe due to switch + // For MCP, map values are usually represented as an object where keys are arbitrary strings + // and all values conform to a single schema. + // mcp.Properties expects a map[string]any defining named properties. + // This is a slight mismatch. MCP's WithObject with mcp.Properties(map[string]any{"*": valuesDef}) + // is one way, or a more flexible mcp.WithMap that takes a value schema. + // For now, let's assume map values translate to a generic object property definition. + valuesDef, err := avroSchemaDefinitionToMcpProperties("value", mapSchema.Values(), "Map values") + if err != nil { + return nil, fmt.Errorf("failed to convert map values for field '%s': %w", fieldName, err) + } + // This isn't a perfect fit for mcp.Properties which expects fixed keys. + // A better MCP representation for a map might be WithObject where AdditionalProperties is set. + // For now, we describe it as an object and the value schema applies to all its properties. + // A more accurate MCP representation might be needed if maps are used extensively. + // Let's use a single property "*" to denote the schema for all values. + prop = append(prop, mcp.Properties(map[string]any{"*": valuesDef})) + opt = mcp.WithObject(fieldName, prop...) // Representing map as a generic object. + case avro.Record: + recordSchema, _ := underlyingTypeForDefault.(*avro.RecordSchema) // Safe due to switch + subProps := make(map[string]any) + for _, subField := range recordSchema.Fields() { + // Recursively call avroFieldToMcpOption to get the ToolOption, then extract its schema part? + // No, avroSchemaDefinitionToMcpProperties is for defining schema of items/values, not named fields. + // We need to build the properties map for mcp.WithObject. + // Each subField needs its schema definition. + var subFieldDescription string + if subField.Doc() != "" { + subFieldDescription = subField.Doc() + } else { + subFieldDescription = fmt.Sprintf("%s (type: %s)", subField.Name(), strings.ReplaceAll(subField.Type().String(), "\"", "")) // Default description + } + subFieldDef, err := avroSchemaDefinitionToMcpProperties(subField.Name(), subField.Type(), subFieldDescription) + if err != nil { + return nil, fmt.Errorf("failed to convert sub-field '%s' of record '%s': %w", subField.Name(), fieldName, err) + } + subProps[subField.Name()] = subFieldDef + } + prop = append(prop, mcp.Properties(subProps)) + opt = mcp.WithObject(fieldName, prop...) + case avro.Enum: + enumSchema, _ := underlyingTypeForDefault.(*avro.EnumSchema) // Safe due to switch + prop = append(prop, mcp.Enum(enumSchema.Symbols()...)) + if field.HasDefault() { + if defaultVal, ok := field.Default().(string); ok { // Enum default is string + prop = append(prop, mcp.DefaultString(defaultVal)) + } + } + opt = mcp.WithString(fieldName, prop...) // Enum is represented as string in MCP + case avro.Null: + // This case should ideally be handled by the union logic making the field not required. + // If a field is solely "null", it's an odd schema. For MCP, maybe a string with note. + // If isRequired is true here (meaning it wasn't a union with null), it's a non-optional null field. + // This is unusual. Let's represent as a non-required string. + if isRequired { // Should not happen if only type is null and handled by union logic + prop = []mcp.PropertyOption{mcp.Description(description + " (null type)")} // remove mcp.Required + } else { + prop = append(prop, mcp.Description(description+" (null type)")) + } + opt = mcp.WithString(fieldName, prop...) // Or handle as a special ignorable field? + default: + // For unknown or unsupported AVRO types, represent as a string in MCP with a description. + var defaultCaseProps []mcp.PropertyOption + defaultCaseProps = append(defaultCaseProps, mcp.Description(description+" (unsupported AVRO type: "+string(effectiveType)+")")) + if isRequired { + defaultCaseProps = append(defaultCaseProps, mcp.Required()) + } + opt = mcp.WithString(fieldName, defaultCaseProps...) + } + return opt, nil +} + +// avroSchemaDefinitionToMcpProperties converts an AVRO schema definition (like for array items or map values) +// into a map[string]any structure suitable for mcp.PropertyOption's Items or Properties. +func avroSchemaDefinitionToMcpProperties(name string, avroDef avro.Schema, description string) (map[string]any, error) { + props := make(map[string]any) + if description == "" { + props["description"] = fmt.Sprintf("Schema for %s", name) + } else { + props["description"] = description + } + + // Handle unions for nested types as well + var effectiveSchema = avroDef + if unionSchema, ok := avroDef.(*avro.UnionSchema); ok { + var nonNullTypes []avro.Schema + for _, t := range unionSchema.Types() { + if t.Type() != avro.Null { + nonNullTypes = append(nonNullTypes, t) + } + } + //nolint:gocritic // This is a valid use of len(nonNullTypes) == 1 + if len(nonNullTypes) == 1 { + effectiveSchema = nonNullTypes[0] + props["description"] = props["description"].(string) + " (nullable)" + } else if len(nonNullTypes) > 1 { + props["type"] = "string" // Represent complex union as string + props["description"] = props["description"].(string) + " (union type: " + strings.ReplaceAll(avroDef.String(), "\"", "") + ")" + return props, nil + } else { // Only null in union or empty union (invalid) + props["type"] = "string" // Fallback for null type + props["description"] = props["description"].(string) + " (effectively null type)" + return props, nil + } + } + + switch effectiveSchema.Type() { + case avro.String: + props["type"] = "string" + case avro.Int, avro.Long: + props["type"] = "number" + case avro.Float, avro.Double: + props["type"] = "number" + case avro.Boolean: + props["type"] = "boolean" + case avro.Bytes, avro.Fixed: // Fixed size bytes + props["type"] = "string" // Bytes/Fixed represented as string in MCP JSON schema + case avro.Array: + arraySchema, _ := effectiveSchema.(*avro.ArraySchema) + itemsProps, err := avroSchemaDefinitionToMcpProperties("item", arraySchema.Items(), "Array items") + if err != nil { + return nil, err + } + props["type"] = "array" + props["items"] = itemsProps + case avro.Map: + mapSchema, _ := effectiveSchema.(*avro.MapSchema) + // MCP object properties are named. Avro map keys are strings, values are of a single schema. + // We represent this as an object where all properties conform to the map's value schema. + // The key "*" can signify this pattern. + valuesProps, err := avroSchemaDefinitionToMcpProperties("value", mapSchema.Values(), "Map values schema") + if err != nil { + return nil, err + } + props["type"] = "object" + // To represent an Avro map (string keys, common value schema) in JSON schema properties: + // we can use "additionalProperties" with the schema of map values. + // Or, for mcp.Properties, we might define a placeholder like "*". + // For now, let's return a structure that indicates it's an object, + // and the `valuesProps` describes the schema for any property within this object. + // This is a common pattern for map-like structures in JSON Schema if not using additionalProperties. + props["properties"] = map[string]any{"*": valuesProps} // Indicating all values have this schema. + case avro.Record: + recordSchema, _ := effectiveSchema.(*avro.RecordSchema) + subProps := make(map[string]any) + for _, field := range recordSchema.Fields() { + var fieldDescription string + if field.Doc() != "" { + fieldDescription = field.Doc() + } else { + fieldDescription = fmt.Sprintf("%s (type: %s)", field.Name(), strings.ReplaceAll(field.Type().String(), "\"", "")) // Default description + } + fieldProp, err := avroSchemaDefinitionToMcpProperties(field.Name(), field.Type(), fieldDescription) + if err != nil { + return nil, err + } + subProps[field.Name()] = fieldProp + } + props["type"] = "object" + props["properties"] = subProps + case avro.Enum: + enumSchema, _ := effectiveSchema.(*avro.EnumSchema) + props["type"] = "string" + props["enum"] = enumSchema.Symbols() + case avro.Null: // Should be handled by union logic primarily + props["type"] = "string" // Fallback for a standalone null type. + props["description"] = props["description"].(string) + " (null type)" + default: + props["type"] = "string" // Fallback for unknown types + props["description"] = props["description"].(string) + " (unknown AVRO type: " + string(effectiveSchema.Type()) + ")" + } + return props, nil +} + +// validateArgumentsAgainstAvroSchemaString validates arguments against an AVRO schema string. +func validateArgumentsAgainstAvroSchemaString(arguments map[string]any, avroSchemaString string) error { + schema, err := avro.Parse(avroSchemaString) + if err != nil { + return fmt.Errorf("failed to parse AVRO schema for validation: %w", err) + } + + // Expecting a record schema at the top level for arguments map + recordSchema, ok := schema.(*avro.RecordSchema) + if !ok { + // If the schema is not a record, but arguments are a map, it's a mismatch unless the schema is a map itself. + // However, tool inputs are typically records/objects. + // If schema is a single type (e.g. string), arguments shouldn't be a map. This needs clarification on Pulsar schema use. + // For now, assume top-level schema for arguments is a record. + return fmt.Errorf("expected AVRO record schema for validating arguments map, got %s", reflect.TypeOf(schema).String()) + } + + // Check for missing required fields + for _, field := range recordSchema.Fields() { + fieldName := field.Name() + // Determine if the field is effectively required for validation purposes + // (not nullable in a union, or a non-union type without a default that makes it implicitly optional) + // This `isReq` is used to check if a field *must* be present in the arguments map if it *doesn't* have a default. + isReq := true // Assume required unless part of a nullable union + if unionSchemaVal, ok := field.Type().(*avro.UnionSchema); ok { + isNullableInUnion := false + for _, t := range unionSchemaVal.Types() { + if t.Type() == avro.Null { + isNullableInUnion = true + break + } + } + isReq = !isNullableInUnion + } + + // Check if the field is present in the arguments map + value, valueOk := arguments[fieldName] + + // If field is not in arguments map + if !valueOk { + // If it's considered required (isReq is true) AND it does not have a default value, + // then it's an error for it to be missing from arguments. + if isReq && !field.HasDefault() { + return fmt.Errorf("required field '%s' is missing and has no default value", fieldName) + } + // If not required (isReq is false), or if it has a default value, it's okay for it to be missing. + // The Avro library itself will handle applying the default during actual serialization/deserialization. + // Our validator's job here is primarily to ensure that if values ARE provided, they are correct, + // and that truly mandatory fields (required and no default) are present. + continue // Move to the next field in the schema + } + + // If field is present in arguments, validate its value against its schema type + if err := validateValueAgainstAvroType(value, field.Type(), fieldName); err != nil { + return err + } + } + + // After validating all fields defined in the schema, check for any extra fields in the arguments. + for argName := range arguments { + foundInSchema := false + for _, field := range recordSchema.Fields() { + if field.Name() == argName { + foundInSchema = true + break + } + } + if !foundInSchema { + return fmt.Errorf("unknown field '%s' provided in arguments", argName) + } + } + + return nil +} + +// validateValueAgainstAvroType validates a single value against a given AVRO schema type. +// path is for constructing helpful error messages. +func validateValueAgainstAvroType(value any, avroDef avro.Schema, path string) error { + if value == nil { + // If value is nil, check if avroDef allows null + if avroDef.Type() == avro.Null { + return nil // Explicitly null type allows nil + } + if unionSchema, ok := avroDef.(*avro.UnionSchema); ok { + for _, t := range unionSchema.Types() { + if t.Type() == avro.Null { + return nil // Union includes null type + } + } + } + return fmt.Errorf("field '%s' is null, but schema type '%s' does not permit null", path, avroDef.Type()) + } + + // If avroDef is a union, try to validate against each type in the union. + if unionSchema, ok := avroDef.(*avro.UnionSchema); ok { + var lastErr error + for _, schemaTypeInUnion := range unionSchema.Types() { + // Skip null type here as we've handled nil value above. If value is not nil, null type won't match. + if schemaTypeInUnion.Type() == avro.Null { + continue + } + err := validateValueAgainstAvroType(value, schemaTypeInUnion, path) + if err == nil { + return nil // Valid against one of the types in the union + } + lastErr = err // Keep the last error for context if none match + } + if lastErr != nil { + return fmt.Errorf("field '%s' (value: %v, type: %T) does not match any type in union schema '%s': last error: %w", path, value, value, unionSchema.String(), lastErr) + } + // If union was only ["null"] and value is not nil, this will be an error. + return fmt.Errorf("field '%s' (value: %v) of type %T does not match union schema '%s' (no non-null types matched or union is only null)", path, value, value, unionSchema.String()) + } + + // Non-union type validation + switch avroDef.Type() { + case avro.String: + if _, ok := value.(string); !ok { + return fmt.Errorf("field '%s': expected string, got %T (value: %v)", path, value, value) + } + case avro.Int: + switch value.(type) { + case int, int8, int16, int32, int64, float32, float64: + if fVal, ok := value.(float64); ok && fVal != float64(int64(fVal)) { + return fmt.Errorf("field '%s': expected int, got float64 with fractional part (value: %v)", path, value) + } + if fVal, ok := value.(float32); ok && fVal != float32(int32(fVal)) { + return fmt.Errorf("field '%s': expected int, got float32 with fractional part (value: %v)", path, value) + } + return nil + default: + return fmt.Errorf("field '%s': expected int, got %T (value: %v)", path, value, value) + } + case avro.Long: + switch value.(type) { + case int, int8, int16, int32, int64, float32, float64: + if fVal, ok := value.(float64); ok && fVal != float64(int64(fVal)) { + return fmt.Errorf("field '%s': expected long, got float64 with fractional part (value: %v)", path, value) + } + if fVal, ok := value.(float32); ok && fVal != float32(int64(fVal)) { // float32 to int64 comparison can be tricky with precision + return fmt.Errorf("field '%s': expected long, got float32 with fractional part (value: %v)", path, value) + } + return nil + default: + return fmt.Errorf("field '%s': expected long, got %T (value: %v)", path, value, value) + } + case avro.Float: + switch value.(type) { + case float32, float64, int, int8, int16, int32, int64: + return nil + default: + return fmt.Errorf("field '%s': expected float, got %T (value: %v)", path, value, value) + } + case avro.Double: + switch value.(type) { + case float32, float64, int, int8, int16, int32, int64: + return nil + default: + return fmt.Errorf("field '%s': expected double, got %T (value: %v)", path, value, value) + } + case avro.Boolean: + if _, ok := value.(bool); !ok { + return fmt.Errorf("field '%s': expected boolean, got %T (value: %v)", path, value, value) + } + + case avro.Bytes: + if _, okStr := value.(string); okStr { + return nil // Allow string for bytes/fixed as per previous logic + } + if _, okBytes := value.([]byte); okBytes { + return nil // Also allow []byte directly + } + return fmt.Errorf("field '%s': expected string or []byte for bytes, got %T (value: %v)", path, value, value) + case avro.Fixed: + if _, ok := value.(uint64); ok { + return nil // Allow uint64 for fixed as per previous logic + } + return fmt.Errorf("field '%s': expected uint64 for fixed, got %T (value: %v)", path, value, value) + case avro.Array: + arrSchema, _ := avroDef.(*avro.ArraySchema) + sliceVal, ok := value.([]any) // JSON unmarshals to []any + if !ok { + // Check if it's a typed slice, e.g. []string, []map[string]any, etc. + // This requires more reflection if we want to support e.g. []string directly. + // For map[string]any from JSON, []any is standard. + return fmt.Errorf("field '%s': expected array (slice of any), got %T (value: %v)", path, value, value) + } + for i, item := range sliceVal { + if err := validateValueAgainstAvroType(item, arrSchema.Items(), fmt.Sprintf("%s[%d]", path, i)); err != nil { + return err + } + } + case avro.Map: + mapSchema, _ := avroDef.(*avro.MapSchema) + mapVal, ok := value.(map[string]any) // JSON unmarshals to map[string]any + if !ok { + return fmt.Errorf("field '%s': expected map (map[string]any), got %T (value: %v)", path, value, value) + } + for k, v := range mapVal { + if err := validateValueAgainstAvroType(v, mapSchema.Values(), fmt.Sprintf("%s.%s", path, k)); err != nil { + return err + } + } + case avro.Record: + recSchema, _ := avroDef.(*avro.RecordSchema) + mapVal, ok := value.(map[string]any) // JSON unmarshals to map[string]any + if !ok { + return fmt.Errorf("field '%s': expected object (map[string]any) for record, got %T (value: %v)", path, value, value) + } + // Check required fields within the record + for _, f := range recSchema.Fields() { + isFieldRequired := true + if unionF, okF := f.Type().(*avro.UnionSchema); okF { + isNullableF := false + for _, t := range unionF.Types() { + if t.Type() == avro.Null { + isNullableF = true + break + } + } + if isNullableF { + isFieldRequired = false + } + } + if _, exists := mapVal[f.Name()]; !exists && isFieldRequired { + return fmt.Errorf("field '%s.%s' is required but missing", path, f.Name()) + } + } + // Validate present fields + for k, v := range mapVal { + var recField *avro.Field + for _, f := range recSchema.Fields() { + if f.Name() == k { + recField = f + break + } + } + if recField == nil { + return fmt.Errorf("field '%s.%s' is not defined in record schema", path, k) + } + if err := validateValueAgainstAvroType(v, recField.Type(), fmt.Sprintf("%s.%s", path, k)); err != nil { + return err + } + } + case avro.Enum: + enumSchema, _ := avroDef.(*avro.EnumSchema) + strVal, ok := value.(string) + if !ok { + return fmt.Errorf("field '%s': expected string for enum, got %T (value: %v)", path, value, value) + } + isValidSymbol := false + for _, s := range enumSchema.Symbols() { + if s == strVal { + isValidSymbol = true + break + } + } + if !isValidSymbol { + return fmt.Errorf("field '%s': value '%s' is not a valid symbol for enum %s. Valid symbols: %v", path, strVal, enumSchema.FullName(), enumSchema.Symbols()) + } + case avro.Null: + if value == nil { + // If value is nil, check if avroDef allows null + if avroDef.Type() == avro.Null { + return nil // Explicitly null type allows nil + } + if unionSchema, ok := avroDef.(*avro.UnionSchema); ok { + for _, t := range unionSchema.Types() { + if t.Type() == avro.Null { + return nil // Union includes null type + } + } + } + return fmt.Errorf("field '%s' is null, but schema type '%s' does not permit null", path, avroDef.Type()) + } + // If value is not nil, it's an error. Nil value handled at the start of the function. + // This means value is non-nil here. + return fmt.Errorf("field '%s': schema type is explicitly 'null' but received non-nil value %T (value: %v)", path, value, value) + + default: + return fmt.Errorf("field '%s': unsupported AVRO type '%s' for validation", path, avroDef.Type()) + } + return nil // Should be unreachable if all cases are handled or return, but as a fallback +} + +// serializeArgumentsToAvroBinary validates arguments against an AVRO schema string +// and then serializes them to AVRO binary format. +func serializeArgumentsToAvroBinary(arguments map[string]any, avroSchemaString string) ([]byte, error) { + // First, validate arguments. + // The validation logic already parses the schema string. + if err := validateArgumentsAgainstAvroSchemaString(arguments, avroSchemaString); err != nil { + return nil, fmt.Errorf("arguments validation failed before AVRO serialization: %w", err) + } + + // Parse schema again for marshaling (or pass parsed schema from validation if we refactor to return it) + schema, err := avro.Parse(avroSchemaString) + if err != nil { + // This error should ideally not happen if validation passed, as it also parses. + return nil, fmt.Errorf("failed to parse AVRO schema for serialization (should have been caught by validation): %w", err) + } + + // Before marshalling, we might need to coerce some types, e.g., string to []byte for "bytes" type if convention is base64 string. + coercedArgs := make(map[string]any, len(arguments)) + for k, v := range arguments { + coercedArgs[k] = v // Copy existing + } + + recordSchema, ok := schema.(*avro.RecordSchema) + if !ok { + // This should ideally not happen if validation passed, but as a safeguard: + return nil, fmt.Errorf("parsed schema is not a record schema, cannot prepare arguments for serialization") + } + + for _, field := range recordSchema.Fields() { + fieldName := field.Name() + val, argExists := arguments[fieldName] + if !argExists { + continue // If arg doesn't exist, skip (defaults or optional handled by avro lib or previous validation) + } + + fieldType := field.Type().Type() // Get the base type, handles unions by checking underlying + if unionSchema, isUnion := field.Type().(*avro.UnionSchema); isUnion { + // If it's a union, we need to find the actual non-null type for coercion logic + // This part can be complex if multiple non-null types are in union with bytes/fixed. + // For simplicity, assuming if 'bytes' or 'fixed' is a possibility, we check for string coercion. + // A more robust solution would inspect the actual type of 'val' against union possibilities. + for _, unionMemberType := range unionSchema.Types() { + if unionMemberType.Type() == avro.Bytes || unionMemberType.Type() == avro.Fixed { + fieldType = unionMemberType.Type() + break + } + } + } + + if fieldType == avro.Bytes { + if strVal, isStr := val.(string); isStr { + // Attempt to decode if it's a string, assuming base64 for bytes encoded as string + decodedBytes, err := base64.StdEncoding.DecodeString(strVal) + if err == nil { + coercedArgs[fieldName] = decodedBytes + } else { + coercedArgs[fieldName] = []byte(strVal) + } + } + } else if fieldType == avro.Fixed { + if strVal, isStr := val.(string); isStr { + // For fixed, if it's a string, it must be base64 decodable to the correct length array + fixedSchema, _ := field.Type().(*avro.FixedSchema) // Or resolve from union if necessary + if actualUnionFieldSchema, okUnion := field.Type().(*avro.UnionSchema); okUnion { + for _, ut := range actualUnionFieldSchema.Types() { + if fs, okUFS := ut.(*avro.FixedSchema); okUFS { + fixedSchema = fs + break + } + } + } + if fixedSchema != nil { + decodedBytes, err := base64.StdEncoding.DecodeString(strVal) + if err == nil { + if len(decodedBytes) == fixedSchema.Size() { + // Convert []byte to [N]byte array for fixed type + fixedArray := reflect.New(reflect.ArrayOf(fixedSchema.Size(), reflect.TypeOf(byte(0)))).Elem() + reflect.Copy(fixedArray, reflect.ValueOf(decodedBytes)) + coercedArgs[fieldName] = fixedArray.Interface() + } else { + // Length mismatch after decoding + return nil, fmt.Errorf("field '%s' (fixed[%d]): base64 decoded string has length %d, expected %d", fieldName, fixedSchema.Size(), len(decodedBytes), fixedSchema.Size()) + } + } // else: base64 decoding error, let avro.Marshal handle or error out + } + } else if byteSlice, isSlice := val.([]byte); isSlice { + // If it's already a []byte, check if it's for a Fixed type and needs conversion to [N]byte + fixedSchema, _ := field.Type().(*avro.FixedSchema) + if actualUnionFieldSchema, okUnion := field.Type().(*avro.UnionSchema); okUnion { + for _, ut := range actualUnionFieldSchema.Types() { + if fs, okUFS := ut.(*avro.FixedSchema); okUFS { + fixedSchema = fs + break + } + } + } + if fixedSchema != nil && len(byteSlice) == fixedSchema.Size() { + fixedArray := reflect.New(reflect.ArrayOf(fixedSchema.Size(), reflect.TypeOf(byte(0)))).Elem() + reflect.Copy(fixedArray, reflect.ValueOf(byteSlice)) + coercedArgs[fieldName] = fixedArray.Interface() + } else if fixedSchema != nil && len(byteSlice) != fixedSchema.Size() { + return nil, fmt.Errorf("field '%s' (fixed[%d]): provided []byte has length %d, expected %d", fieldName, fixedSchema.Size(), len(byteSlice), fixedSchema.Size()) + } // else it's not for a fixed schema or length mismatch, or it's for 'bytes' type, keep as []byte + + } + } + } + + // Marshal the potentially coerced arguments + return avro.Marshal(schema, coercedArgs) +} diff --git a/pkg/schema/avro_core_test.go b/pkg/schema/avro_core_test.go new file mode 100644 index 0000000..58362a4 --- /dev/null +++ b/pkg/schema/avro_core_test.go @@ -0,0 +1,833 @@ +package schema + +import ( + "testing" + + "github.com/hamba/avro/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mark3labs/mcp-go/mcp" +) + +// AVRO Schema Definitions for Testing + +const simpleRecordSchema = `{ + "type": "record", + "name": "SimpleRecord", + "fields": [ + {"name": "id", "type": "long"}, + {"name": "name", "type": "string"} + ] +}` + +const schemaWithAllPrimitives = `{ + "type": "record", + "name": "AllPrimitives", + "fields": [ + {"name": "boolField", "type": "boolean"}, + {"name": "intField", "type": "int"}, + {"name": "longField", "type": "long"}, + {"name": "floatField", "type": "float"}, + {"name": "doubleField", "type": "double"}, + {"name": "bytesField", "type": "bytes"}, + {"name": "stringField", "type": "string"} + ] +}` + +const schemaWithOptionalField = `{ + "type": "record", + "name": "OptionalFieldRecord", + "fields": [ + {"name": "requiredField", "type": "string"}, + {"name": "optionalField", "type": ["null", "string"], "default": null} + ] +}` + +const schemaWithDefaultValue = `{ + "type": "record", + "name": "DefaultValueRecord", + "fields": [ + {"name": "name", "type": "string", "default": "DefaultName"}, + {"name": "age", "type": "int", "default": 30} + ] +}` + +const nestedRecordSchema = `{ + "type": "record", + "name": "OuterRecord", + "fields": [ + {"name": "id", "type": "string"}, + { + "name": "inner", + "type": { + "type": "record", + "name": "InnerRecord", + "fields": [ + {"name": "value", "type": "int"}, + {"name": "description", "type": ["null", "string"], "default": null} + ] + } + } + ] +}` + +const arraySchemaPrimitive = `{ + "type": "record", + "name": "ArrayPrimitiveRecord", + "fields": [ + {"name": "stringArray", "type": {"type": "array", "items": "string"}}, + {"name": "optionalIntArray", "type": ["null", {"type": "array", "items": "int"}], "default": null} + ] +}` + +const arraySchemaRecord = `{ + "type": "record", + "name": "ArrayRecordContainer", + "fields": [ + { + "name": "records", + "type": {"type": "array", "items": { + "type": "record", + "name": "ContainedRecord", + "fields": [ + {"name": "key", "type": "string"}, + {"name": "val", "type": "long"} + ] + }} + } + ] +}` + +const mapSchemaPrimitive = `{ + "type": "record", + "name": "MapPrimitiveRecord", + "fields": [ + {"name": "stringMap", "type": {"type": "map", "values": "string"}}, + {"name": "optionalIntMap", "type": ["null", {"type": "map", "values": "int"}], "default": null} + ] +}` + +const mapSchemaRecord = `{ + "type": "record", + "name": "MapRecordContainer", + "fields": [ + { + "name": "recordsMap", + "type": {"type": "map", "values": { + "type": "record", + "name": "MappedRecord", + "fields": [ + {"name": "id", "type": "int"}, + {"name": "status", "type": "string"} + ] + }} + } + ] +}` + +const enumSchema = `{ + "type": "record", + "name": "EnumRecord", + "fields": [ + { + "name": "suit", + "type": { "type": "enum", "name": "Suit", "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] } + } + ] +}` + +const unionSchemaSimple = `{ + "type": "record", + "name": "UnionRecord", + "fields": [ + {"name": "stringOrInt", "type": ["string", "int"]}, + {"name": "nullableStringOrInt", "type": ["null", "string", "int"], "default": null} + ] +}` + +// Helper function to get expected AVRO binary for test comparisons +func getExpectedAvroBinary(t *testing.T, schemaStr string, data map[string]any) []byte { + avroSchema, err := avro.Parse(schemaStr) + require.NoError(t, err, "Failed to parse AVRO schema in helper") + + // The hamba/avro/v2 Marshal function can take a map[string]any directly + // if its structure is compatible with the schema. + bin, err := avro.Marshal(avroSchema, data) + require.NoError(t, err, "Failed to marshal to AVRO binary in helper") + return bin +} + +func TestProcessAvroSchemaStringToMCPToolInput(t *testing.T) { + tests := []struct { + name string + schemaStr string + expectedOptions []mcp.ToolOption + expectError bool + expectedErrorMsg string + }{ + { + name: "Simple Record", + schemaStr: simpleRecordSchema, + expectedOptions: []mcp.ToolOption{ + mcp.WithNumber("id", mcp.Description("id (type: long)"), mcp.Required()), + mcp.WithString("name", mcp.Description("name (type: string)"), mcp.Required()), + }, + expectError: false, + }, + { + name: "Schema With All Primitives", + schemaStr: schemaWithAllPrimitives, + expectedOptions: []mcp.ToolOption{ + mcp.WithBoolean("boolField", mcp.Description("boolField (type: boolean)"), mcp.Required()), + mcp.WithNumber("intField", mcp.Description("intField (type: int)"), mcp.Required()), + mcp.WithNumber("longField", mcp.Description("longField (type: long)"), mcp.Required()), + mcp.WithNumber("floatField", mcp.Description("floatField (type: float)"), mcp.Required()), + mcp.WithNumber("doubleField", mcp.Description("doubleField (type: double)"), mcp.Required()), + mcp.WithString("bytesField", mcp.Description("bytesField (type: bytes)"), mcp.Required()), + mcp.WithString("stringField", mcp.Description("stringField (type: string)"), mcp.Required()), + }, + expectError: false, + }, + { + name: "Invalid AVRO schema string", + schemaStr: `{"type": "invalid"`, + expectedOptions: nil, + expectError: true, + expectedErrorMsg: "failed to parse AVRO schema", + }, + { + name: "Top-level non-record (string)", + schemaStr: `"string"`, + expectedOptions: nil, + expectError: true, + expectedErrorMsg: "expected AVRO record schema at the top level, got *avro.PrimitiveSchema", + }, + { + name: "Schema With Optional Field (string)", + schemaStr: schemaWithOptionalField, + expectedOptions: []mcp.ToolOption{ + mcp.WithString("requiredField", mcp.Description("requiredField (type: string)"), mcp.Required()), + mcp.WithString("optionalField", mcp.Description("optionalField (type: [null,string])")), + }, + expectError: false, + }, + { + name: "Schema With Default Value (string and int)", + schemaStr: schemaWithDefaultValue, + expectedOptions: []mcp.ToolOption{ + mcp.WithString("name", mcp.Description("name (type: string)"), mcp.Required(), mcp.DefaultString("DefaultName")), + mcp.WithNumber("age", mcp.Description("age (type: int)"), mcp.Required(), mcp.DefaultNumber(30)), + }, + expectError: false, + }, + { + name: "Nested Record", + schemaStr: nestedRecordSchema, + expectedOptions: []mcp.ToolOption{ + mcp.WithString("id", mcp.Description("id (type: string)"), mcp.Required()), + mcp.WithObject("inner", + mcp.Description("inner (type: {name:InnerRecord,type:record,fields:[{name:value,type:int},{name:description,type:[null,string]}]})"), + mcp.Required(), + mcp.Properties(map[string]any{ + "value": map[string]any{ + "type": "number", + "description": "value (type: int)", + }, + "description": map[string]any{ + "type": "string", + "description": "description (type: [null,string]) (nullable)", + }, + }), + ), + }, + expectError: false, + }, + { + name: "Array of Primitives (stringArray, optionalIntArray)", + schemaStr: arraySchemaPrimitive, + expectedOptions: []mcp.ToolOption{ + mcp.WithArray("stringArray", + mcp.Description("stringArray (type: {type:array,items:string})"), + mcp.Required(), + mcp.Items(map[string]any{ + "type": "string", + "description": "Array items", + }), + ), + mcp.WithArray("optionalIntArray", + mcp.Description("optionalIntArray (type: [null,{type:array,items:int}])"), + mcp.Items(map[string]any{ + "type": "number", + "description": "Array items", + }), + ), + }, + expectError: false, + }, + { + name: "Array of Records", + schemaStr: arraySchemaRecord, + expectedOptions: []mcp.ToolOption{ + mcp.WithArray("records", + mcp.Description("records (type: {type:array,items:{name:ContainedRecord,type:record,fields:[{name:key,type:string},{name:val,type:long}]}})"), + mcp.Required(), + mcp.Items(map[string]any{ + "type": "object", + "description": "Array items", + "properties": map[string]any{ + "key": map[string]any{ + "type": "string", + "description": "key (type: string)", + }, + "val": map[string]any{ + "type": "number", + "description": "val (type: long)", + }, + }, + }), + ), + }, + expectError: false, + }, + { + name: "Map of Primitives (stringMap, optionalIntMap)", + schemaStr: mapSchemaPrimitive, + expectedOptions: []mcp.ToolOption{ + mcp.WithObject("stringMap", // Avro map becomes MCP object + mcp.Description("stringMap (type: {type:map,values:string})"), + mcp.Required(), + mcp.Properties(map[string]any{ // Based on avroFieldToMcpOption logic for map + "*": map[string]any{ + "type": "string", + "description": "Map values", + }, + }), + ), + mcp.WithObject("optionalIntMap", + mcp.Description("optionalIntMap (type: [null,{type:map,values:int}])"), + // Not required due to ["null", map] union + mcp.Properties(map[string]any{ + "*": map[string]any{ + "type": "number", + "description": "Map values", + }, + }), + // Avro default: null for the union handled by not being required. + ), + }, + expectError: false, + }, + { + name: "Map of Records", + schemaStr: mapSchemaRecord, + expectedOptions: []mcp.ToolOption{ + mcp.WithObject("recordsMap", + mcp.Description("recordsMap (type: {type:map,values:{name:MappedRecord,type:record,fields:[{name:id,type:int},{name:status,type:string}]}})"), + mcp.Required(), + mcp.Properties(map[string]any{ + "*": map[string]any{ + "type": "object", + "description": "Map values", + "properties": map[string]any{ + "id": map[string]any{ + "type": "number", + "description": "id (type: int)", + }, + "status": map[string]any{ + "type": "string", + "description": "status (type: string)", + }, + }, + }, + }), + ), + }, + expectError: false, + }, + { + name: "Enum Field", + schemaStr: enumSchema, + expectedOptions: []mcp.ToolOption{ + mcp.WithString("suit", + mcp.Description("suit (type: {name:Suit,type:enum,symbols:[SPADES,HEARTS,DIAMONDS,CLUBS]})"), + mcp.Required(), + mcp.Enum("SPADES", "HEARTS", "DIAMONDS", "CLUBS"), + ), + }, + expectError: false, + }, + { + name: "Simple Union Field (stringOrInt)", + schemaStr: unionSchemaSimple, + expectedOptions: []mcp.ToolOption{ + // Based on avroFieldToMcpOption, a complex union ["string", "int"] becomes mcp.WithString + // with a description indicating it's a union. It's marked as required by default. + mcp.WithString("stringOrInt", + mcp.Description("stringOrInt (type: [string,int]) (union type: [string,int])"), + mcp.Required(), + ), + // For ["null", "string", "int"], it's also a complex union but not required. + mcp.WithString("nullableStringOrInt", + mcp.Description("nullableStringOrInt (type: [null,string,int]) (union type: [null,string,int])"), + // Not mcp.Required() due to presence of "null" in union. + // Default is Avro null, so no mcp.DefaultString is added. + ), + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts, err := processAvroSchemaStringToMCPToolInput(tt.schemaStr) + + if tt.expectError { + assert.Error(t, err) + if tt.expectedErrorMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrorMsg) + } + assert.Nil(t, opts) + } else { + assert.NoError(t, err) + require.Equal(t, len(tt.expectedOptions), len(opts), "Number of options should match") + var actualTool, expectedTool mcp.Tool + actualTool = mcp.NewTool("test", opts...) + expectedTool = mcp.NewTool("test", tt.expectedOptions...) + actualToolInputSchemaJSON, _ := actualTool.InputSchema.MarshalJSON() + expectedToolInputSchemaJSON, _ := expectedTool.InputSchema.MarshalJSON() + assert.Equal(t, string(expectedToolInputSchemaJSON), string(actualToolInputSchemaJSON), "ToolOption did not produce the same property configuration. Expected: %+v, Got: %+v", expectedTool, actualTool) + } + }) + } +} + +// TestValidateArgumentsAgainstAvroSchemaString tests the validateArgumentsAgainstAvroSchemaString function. +func TestValidateArgumentsAgainstAvroSchemaString(t *testing.T) { + tests := []struct { + name string + schemaStr string + args map[string]any + wantErr bool + errText string // Optional: specific error text to check for + }{ + { + name: "Valid: Simple Record - all fields present", + schemaStr: simpleRecordSchema, + args: map[string]any{ + "id": int64(123), + "name": "test name", + }, + wantErr: false, + }, + { + name: "Valid: Schema With All Primitives - all fields present", + schemaStr: schemaWithAllPrimitives, + args: map[string]any{ + "boolField": true, + "intField": int32(100), + "longField": int64(2000), + "floatField": float32(3.14), + "doubleField": float64(6.28), + "bytesField": []byte("bytesdata"), + "stringField": "stringdata", + }, + wantErr: false, + }, + { + name: "Invalid: Simple Record - missing required field 'name'", + schemaStr: simpleRecordSchema, + args: map[string]any{ + "id": int64(123), + }, + wantErr: true, + errText: "required field 'name' is missing and has no default value", + }, + { + name: "Invalid: Simple Record - wrong type for 'id' (string instead of long)", + schemaStr: simpleRecordSchema, + args: map[string]any{ + "id": "not-a-long", + "name": "A Name", + }, + wantErr: true, + errText: "field 'id': expected long, got string", + }, + { + name: "Invalid: Simple Record - extra field 'extra'", + schemaStr: simpleRecordSchema, + args: map[string]any{ + "id": int64(123), + "name": "A Name", + "extra": "value", + }, + wantErr: true, + errText: "unknown field 'extra' provided in arguments", + }, + { + name: "Valid: Optional Field - present", + schemaStr: schemaWithOptionalField, + args: map[string]any{ + "requiredField": "req", + "optionalField": "opt", + }, + wantErr: false, + }, + { + name: "Valid: Optional Field - absent (should use default null)", + schemaStr: schemaWithOptionalField, + args: map[string]any{ + "requiredField": "req", + }, + wantErr: false, // AVRO validation itself passes if default is null and field is omitted + }, + { + name: "Valid: Default Value - fields omitted", + schemaStr: schemaWithDefaultValue, + args: map[string]any{}, + wantErr: false, // Default values are used + }, + { + name: "Valid: Default Value - fields provided", + schemaStr: schemaWithDefaultValue, + args: map[string]any{ + "name": "ProvidedName", + "age": int32(40), + }, + wantErr: false, + }, + { + name: "Valid: Nested Record", + schemaStr: nestedRecordSchema, + args: map[string]any{ + "id": "outerID", + "inner": map[string]any{ + "value": int32(101), + "description": "inner desc", + }, + }, + wantErr: false, + }, + { + name: "Invalid: Nested Record - missing field in inner record", + schemaStr: nestedRecordSchema, + args: map[string]any{ + "id": "outerID", + "inner": map[string]any{"name": "Inner Name"}, // Missing inner.value + }, + wantErr: true, + errText: "field 'inner.value' is required but missing", + }, + { + name: "Valid: Array of Primitives", + schemaStr: arraySchemaPrimitive, + args: map[string]any{ + "stringArray": []any{"a", "b", "c"}, + "optionalIntArray": []any{int32(1), int32(2)}, + }, + wantErr: false, + }, + { + name: "Invalid: Array of Primitives - wrong item type", + schemaStr: arraySchemaPrimitive, + args: map[string]any{ + "stringArray": []any{"hello", 1, "world"}, // int is not string + }, + wantErr: true, + errText: "field 'stringArray[1]': expected string, got int", + }, + { + name: "Valid: Array of Records", + schemaStr: arraySchemaRecord, + args: map[string]any{ + "records": []any{ + map[string]any{"key": "k1", "val": int64(1)}, + map[string]any{"key": "k2", "val": int64(2)}, + }, + }, + wantErr: false, + }, + { + name: "Valid: Map of Primitives", + schemaStr: mapSchemaPrimitive, + args: map[string]any{ + "stringMap": map[string]any{"key1": "val1", "key2": "val2"}, + "optionalIntMap": map[string]any{"opt1": int32(10)}, + }, + wantErr: false, + }, + { + name: "Invalid: Map of Primitives - wrong value type", + schemaStr: mapSchemaPrimitive, + args: map[string]any{ + "stringMap": map[string]any{"key1": 123, "key2": "val2"}, // 123 is not string + }, + wantErr: true, + errText: "field 'stringMap.key1': expected string, got int", + }, + { + name: "Valid: Map of Records", + schemaStr: mapSchemaRecord, + args: map[string]any{ + "recordsMap": map[string]any{ + "recA": map[string]any{"id": int32(1), "status": "active"}, + "recB": map[string]any{"id": int32(2), "status": "inactive"}, + }, + }, + wantErr: false, + }, + { + name: "Valid: Enum", + schemaStr: enumSchema, + args: map[string]any{ + "suit": "SPADES", + }, + wantErr: false, + }, + { + name: "Invalid: Enum - invalid symbol", + schemaStr: enumSchema, + args: map[string]any{ + "suit": "INVALID_SUIT", + }, + wantErr: true, + errText: "value 'INVALID_SUIT' is not a valid symbol for enum Suit", + }, + { + name: "Valid: Union (stringOrInt) - string", + schemaStr: unionSchemaSimple, + args: map[string]any{ + "stringOrInt": "hello", + "nullableStringOrInt": "world", + }, + wantErr: false, + }, + { + name: "Valid: Union (stringOrInt) - int", + schemaStr: unionSchemaSimple, + args: map[string]any{ + "stringOrInt": int32(123), + "nullableStringOrInt": int32(456), + }, + wantErr: false, + }, + { + name: "Invalid: Union (stringOrInt) - boolean (not in union)", + schemaStr: unionSchemaSimple, + args: map[string]any{ + "stringOrInt": true, + }, + wantErr: true, + errText: "does not match any type in union schema", + }, + { + name: "Invalid: Schema string is empty", + schemaStr: "", + args: map[string]any{"foo": "bar"}, + wantErr: true, + errText: "failed to parse AVRO schema for validation: avro: unknown type: ", + }, + { + name: "Invalid: Schema string is not valid AVRO json", + schemaStr: "{invalid json", + args: map[string]any{"foo": "bar"}, + wantErr: true, + errText: "failed to parse AVRO schema", + }, + { + name: "Valid: schemaWithAllPrimitives - bytes field with string input (should be accepted as per current code)", + schemaStr: schemaWithAllPrimitives, + args: map[string]any{ + "boolField": true, + "intField": int32(100), + "longField": int64(2000), + "floatField": float32(3.14), + "doubleField": float64(6.28), + "bytesField": "stringtobytes", // This is the key part for this test + "stringField": "stringdata", + }, + wantErr: false, // Current validateValueAgainstAvroType for bytes accepts string + }, + { + name: "Invalid: schemaWithAllPrimitives - bytes field with int input", + schemaStr: schemaWithAllPrimitives, + args: map[string]any{ + "boolField": true, + "intField": int32(100), + "longField": int64(2000), + "floatField": float32(3.14), + "doubleField": float64(6.28), + "bytesField": 123, // int is not convertible to bytes in this path + "stringField": "stringdata", + }, + wantErr: true, + errText: "field 'bytesField': expected string or []byte for bytes, got int", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateArgumentsAgainstAvroSchemaString(tt.args, tt.schemaStr) + if tt.wantErr { + assert.Error(t, err) + if tt.errText != "" { + assert.Contains(t, err.Error(), tt.errText) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestSerializeArgumentsToAvroBinary tests the serializeArgumentsToAvroBinary function. +func TestSerializeArgumentsToAvroBinary(t *testing.T) { + tests := []struct { + name string + schemaStr string + args map[string]any + expectError bool + errorContains string + expectedBinary []byte // Optional: if nil, don't check binary equality, just no error + }{ + { + name: "Valid: Simple Record", + schemaStr: simpleRecordSchema, + args: map[string]any{ + "id": int64(123), + "name": "test name", + }, + expectError: false, + }, + { + name: "Valid: Schema With All Primitives (matching getExpectedAvroBinary)", + schemaStr: schemaWithAllPrimitives, + args: map[string]any{ + "boolField": true, + "intField": int32(100), + "longField": int64(2000), + "floatField": float32(3.14), + "doubleField": float64(6.28), + "bytesField": []byte("bytesdata"), + "stringField": "stringdata", + }, + expectError: false, + }, + { + name: "Valid: Nested Record", + schemaStr: nestedRecordSchema, + args: map[string]any{ + "id": "outerID", + "inner": map[string]any{ + "value": int32(101), + "description": "inner desc", + }, + }, + expectError: false, + }, + { + name: "Valid: Array of Primitives", + schemaStr: arraySchemaPrimitive, + args: map[string]any{ + "stringArray": []any{"a", "b", "c"}, + "optionalIntArray": []any{int32(1), int32(2)}, + }, + expectError: false, + }, + { + name: "Valid: Map of Records", + schemaStr: mapSchemaRecord, + args: map[string]any{ + "recordsMap": map[string]any{ + "recA": map[string]any{"id": int32(1), "status": "active"}, + }, + }, + expectError: false, + }, + { + name: "Valid: Enum", + schemaStr: enumSchema, + args: map[string]any{ + "suit": "SPADES", + }, + expectError: false, + }, + { + name: "Valid: Union - int type", + schemaStr: unionSchemaSimple, + args: map[string]any{ + "stringOrInt": int32(123), + }, + expectError: false, + }, + { + name: "Invalid: Serialization fails due to validation (missing required field)", + schemaStr: simpleRecordSchema, + args: map[string]any{ + "id": int64(123), // "name" is missing + }, + expectError: true, + errorContains: "required field 'name' is missing and has no default value", + }, + { + name: "Invalid: Serialization fails due to validation (wrong type)", + schemaStr: simpleRecordSchema, + args: map[string]any{ + "id": "not-a-long", + "name": "Test Name", + }, + expectError: true, + errorContains: "field 'id': expected long, got string (value: not-a-long)", + }, + { + name: "Invalid: Empty schema string", + schemaStr: "", + args: map[string]any{"id": int64(1)}, + expectError: true, + errorContains: "failed to parse AVRO schema for validation: avro: unknown type: ", + }, + { + name: "Invalid: Malformed schema string", + schemaStr: `{"type": "record"`, + args: map[string]any{"id": 123}, + expectError: true, + errorContains: "failed to parse AVRO schema for validation: avro: unknown type: {\"type\": \"record\"", + }, + { + name: "Valid: schemaWithAllPrimitives - bytes field with string input for serialization", + schemaStr: schemaWithAllPrimitives, + args: map[string]any{ + "boolField": true, + "intField": int32(100), + "longField": int64(2000), + "floatField": float32(3.14), + "doubleField": float64(6.28), + "bytesField": []byte("stringtobytes"), + "stringField": "stringdata", + }, + expectError: false, // Should serialize correctly as current code handles string to []byte for bytes type + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actualBinary, err := serializeArgumentsToAvroBinary(tt.args, tt.schemaStr) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, actualBinary) + } else { + assert.NoError(t, err) + assert.NotNil(t, actualBinary) + + // To validate the binary output, we use the helper. + // This ensures that our serialization matches a known-good Avro library's output. + expectedBinary := getExpectedAvroBinary(t, tt.schemaStr, tt.args) + assert.Equal(t, expectedBinary, actualBinary, "Serialized binary output does not match expected binary.") + } + }) + } +} diff --git a/pkg/schema/avro_test.go b/pkg/schema/avro_test.go new file mode 100644 index 0000000..82d6248 --- /dev/null +++ b/pkg/schema/avro_test.go @@ -0,0 +1,375 @@ +package schema + +import ( + "testing" + + "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/stretchr/testify/assert" + + "github.com/mark3labs/mcp-go/mcp" +) + +// complexRecordSchemaString is used by TestAvroConverter_ValidateArguments +// and is not defined in avro_core_test.go, so we define it here. +const complexRecordSchemaString = `{ + "type": "record", + "name": "ComplexRecord", + "fields": [ + {"name": "fieldString", "type": "string"}, + {"name": "fieldInt", "type": "int"}, + {"name": "fieldLong", "type": "long"}, + {"name": "fieldDouble", "type": "double"}, + {"name": "fieldFloat", "type": "float"}, + {"name": "fieldBool", "type": "boolean"}, + {"name": "fieldBytes", "type": "bytes"}, + {"name": "fieldFixed", "type": {"type": "fixed", "name": "MyFixed", "size": 8}}, + {"name": "fieldEnum", "type": {"type": "enum", "name": "MyEnum", "symbols": ["A", "B", "C"]}}, + {"name": "fieldArray", "type": {"type": "array", "items": "int"}}, + {"name": "fieldMap", "type": {"type": "map", "values": "long"}}, + {"name": "fieldRecord", "type": { + "type": "record", "name": "SubRecord", "fields": [ + {"name": "subField", "type": "string"} + ]} + }, + {"name": "fieldUnion", "type": ["null", "int", "string"]} + ] +}` + +// newAvroSchemaInfo is a helper to create SchemaInfo for AVRO type with a given AVRO schema string. +func newAvroSchemaInfo(avroSchemaString string) *utils.SchemaInfo { + return &utils.SchemaInfo{ + Name: "test-avro-schema", + Type: "AVRO", // Pulsar schema type is string + Schema: []byte(avroSchemaString), + } +} + +// TestNewAvroConverter tests the NewAvroConverter constructor. +func TestNewAvroConverter(t *testing.T) { + converter := NewAvroConverter() + assert.NotNil(t, converter, "NewAvroConverter should not return nil") + // AvroConverter also relies on Avro structure primarily, similar to JSONConverter. +} + +// TestAvroConverter_ToMCPToolInputSchemaProperties tests ToMCPToolInputSchemaProperties for AvroConverter. +func TestAvroConverter_ToMCPToolInputSchemaProperties(t *testing.T) { + converter := NewAvroConverter() + + const localSimpleRecordSchemaForAvro = `{ + "type": "record", + "name": "SimpleRecordForAvro", + "fields": [ + {"name": "id", "type": "long"}, + {"name": "data", "type": "string"} + ] + }` + + const invalidAvroSchemaForAvro = `{"type": "invalidAvro}` + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + expectedOptions []mcp.ToolOption + expectError bool + errorContains string + }{ + { + name: "Valid AVRO schema", + schemaInfo: newAvroSchemaInfo(localSimpleRecordSchemaForAvro), + expectedOptions: []mcp.ToolOption{ + mcp.WithNumber("id", mcp.Description("id (type: long)"), mcp.Required()), + mcp.WithString("data", mcp.Description("data (type: string)"), mcp.Required()), + }, + expectError: false, + }, + { + name: "SchemaInfo type is not AVRO (e.g., JSON)", + schemaInfo: &utils.SchemaInfo{Type: "JSON", Schema: []byte(localSimpleRecordSchemaForAvro)}, + expectedOptions: nil, + expectError: true, + errorContains: "expected AVRO schema, got JSON", + }, + { + name: "Invalid underlying Avro schema string", + schemaInfo: newAvroSchemaInfo(invalidAvroSchemaForAvro), + expectedOptions: nil, + expectError: true, + errorContains: "unknown type: {\"type\": \"invalidAvro}", + }, + { + name: "Underlying Avro schema is nil", + schemaInfo: &utils.SchemaInfo{Type: "AVRO", Schema: nil}, + expectedOptions: nil, + expectError: true, + errorContains: "unknown type: ", + }, + { + name: "Underlying Avro schema is empty string", + schemaInfo: newAvroSchemaInfo(""), + expectedOptions: nil, + expectError: true, + errorContains: "unknown type: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts, err := converter.ToMCPToolInputSchemaProperties(tt.schemaInfo) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, opts) + } else { + assert.NoError(t, err) + var expectedTool, actualTool mcp.Tool + expectedTool = mcp.NewTool("test", tt.expectedOptions...) + actualTool = mcp.NewTool("test", opts...) + expectedToolSchemaJSON, _ := expectedTool.InputSchema.MarshalJSON() + actualToolSchemaJSON, _ := actualTool.InputSchema.MarshalJSON() + assert.Equal(t, expectedToolSchemaJSON, actualToolSchemaJSON, "Tool mismatch") + } + }) + } +} + +func TestAvroConverter_ValidateArguments(t *testing.T) { + tests := []struct { + name string + schemaInfo utils.SchemaInfo + args map[string]interface{} + wantErr bool + }{ + { + name: "valid arguments for simple record", + schemaInfo: utils.SchemaInfo{ + Name: "SimpleAvro", + Schema: []byte(simpleRecordSchema), + Type: "AVRO", + }, + args: map[string]interface{}{ + "id": int64(123), + "name": "TestName", + }, + wantErr: false, + }, + { + name: "invalid type for field in simple record", + schemaInfo: utils.SchemaInfo{ + Name: "SimpleAvroInvalidType", + Schema: []byte(simpleRecordSchema), + Type: "AVRO", + }, + args: map[string]interface{}{ + "id": "not_a_long", + "name": "TestName", + }, + wantErr: true, + }, + { + name: "missing required field in simple record", + schemaInfo: utils.SchemaInfo{ + Name: "SimpleAvroMissingField", + Schema: []byte(simpleRecordSchema), + Type: "AVRO", + }, + args: map[string]interface{}{ + "name": "TestName", + }, + wantErr: true, + }, + { + name: "valid arguments for complex record", + schemaInfo: utils.SchemaInfo{ + Name: "ComplexAvroValid", + Schema: []byte(complexRecordSchemaString), + Type: "AVRO", + }, + args: map[string]interface{}{ + "fieldString": "test string", + "fieldInt": int32(123), + "fieldLong": int64(456), + "fieldDouble": float64(12.34), + "fieldFloat": float32(56.78), + "fieldBool": true, + "fieldBytes": []byte("test bytes"), + "fieldFixed": uint64(0x1234567890123456), + "fieldEnum": "A", + "fieldArray": []interface{}{int32(1), int32(2)}, + "fieldMap": map[string]interface{}{"key1": int64(100)}, + "fieldRecord": map[string]interface{}{"subField": "sub value"}, + "fieldUnion": int32(99), + }, + wantErr: false, + }, + { + name: "invalid enum value for complex record", + schemaInfo: utils.SchemaInfo{ + Name: "ComplexAvroInvalidEnum", + Schema: []byte(complexRecordSchemaString), + Type: "AVRO", + }, + args: map[string]interface{}{ + "fieldString": "test string", + "fieldInt": int32(123), + "fieldLong": int64(456), + "fieldDouble": float64(12.34), + "fieldFloat": float32(56.78), + "fieldBool": true, + "fieldBytes": []byte("test bytes"), + "fieldFixed": uint64(0x1234567890123456), + "fieldEnum": "X", // Invalid enum + "fieldArray": []interface{}{int32(1), int32(2)}, + "fieldMap": map[string]interface{}{"key1": int64(100)}, + "fieldRecord": map[string]interface{}{"subField": "sub value"}, + "fieldUnion": int32(99), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conv := NewAvroConverter() + + err := conv.ValidateArguments(tt.args, &tt.schemaInfo) + if (err != nil) != tt.wantErr { + t.Errorf("AvroConverter.ValidateArguments() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAvroConverter_SerializeMCPRequestToPulsarPayload(t *testing.T) { + // Assumes simpleRecordSchema and complexRecordSchemaString are available + // and getExpectedAvroBinary is accessible from avro_core_test.go (same package) + + tests := []struct { + name string + schemaInfo utils.SchemaInfo + args map[string]interface{} + expectedPayload []byte // Can be nil if wantErr is true + wantErr bool + assertPayload bool // True if we want to compare payload, false if only error matters + errorContains string + }{ + { + name: "valid serialization for simple record", + schemaInfo: utils.SchemaInfo{ + Name: "SimpleAvroSerialize", + Schema: []byte(simpleRecordSchema), // Uses simpleRecordSchema from avro_core_test.go + Type: "AVRO", + }, + args: map[string]interface{}{ + "id": int64(42), // Note: field names in avro_core_test.go's simpleRecordSchema are lowercase + "name": "Test Payload", + }, + assertPayload: true, + wantErr: false, + }, + { + name: "serialization with mismatched argument type for simple record", + schemaInfo: utils.SchemaInfo{ + Name: "SimpleAvroSerializeInvalidArg", + Schema: []byte(simpleRecordSchema), + Type: "AVRO", + }, + args: map[string]interface{}{ + "id": "not_a_long", // Invalid type + "name": "Test Payload", + }, + assertPayload: false, + wantErr: true, + errorContains: "arguments validation failed", + }, + { + name: "valid serialization for complex record", + schemaInfo: utils.SchemaInfo{ + Name: "ComplexAvroSerialize", + Schema: []byte(complexRecordSchemaString), // Uses complexRecordSchemaString defined in avro_test.go + Type: "AVRO", + }, + args: map[string]interface{}{ + "fieldString": "hello avro", + "fieldInt": int32(101), + "fieldLong": int64(202), + "fieldDouble": float64(30.3), + "fieldFloat": float32(4.04), + "fieldBool": true, + "fieldBytes": []byte("avro bytes"), + "fieldFixed": uint64(0x1234567890123456), // Must be 16 bytes, will be corrected in loop + "fieldEnum": "B", + "fieldArray": []interface{}{int32(11), int32(22)}, + "fieldMap": map[string]interface{}{"mKey": int64(333)}, + "fieldRecord": map[string]interface{}{"subField": "sub data"}, + "fieldUnion": "union string val", + }, + assertPayload: true, // We'll generate expected payload for this + wantErr: false, + }, + { + name: "serialization with invalid schema type", + schemaInfo: utils.SchemaInfo{ + Name: "InvalidSchemaTypeSerialize", + Schema: []byte(simpleRecordSchema), + Type: "JSON", // Invalid type for AvroConverter + }, + args: map[string]interface{}{ + "id": int64(1), + "name": "Dummy", + }, + assertPayload: false, + wantErr: true, + errorContains: "arguments validation failed", // ValidateArguments checks schema type first + }, + } + + for i := range tests { + tt := &tests[i] // Use pointer to allow modification for expectedPayload + + // Pre-calculate expected payload for valid cases + if tt.assertPayload && !tt.wantErr { + var schemaToUse string + var argsToMarshal map[string]interface{} + + if tt.schemaInfo.Name == "SimpleAvroSerialize" { + schemaToUse = simpleRecordSchema + argsToMarshal = tt.args + } else if tt.schemaInfo.Name == "ComplexAvroSerialize" { + schemaToUse = complexRecordSchemaString + complexArgsCopy := make(map[string]interface{}) + for k, v := range tt.args { + complexArgsCopy[k] = v + } + complexArgsCopy["fieldFixed"] = uint64(0x0123456789abcdef) // 16 bytes + argsToMarshal = complexArgsCopy + tt.args = complexArgsCopy // Update tt.args with corrected fixed field + } + + if schemaToUse != "" { + tt.expectedPayload = getExpectedAvroBinary(t, schemaToUse, argsToMarshal) + } else { + t.Fatalf("Test setup error for %s: schemaToUse not set for payload generation", tt.name) + } + } + + t.Run(tt.name, func(t *testing.T) { + conv := NewAvroConverter() + payload, err := conv.SerializeMCPRequestToPulsarPayload(tt.args, &tt.schemaInfo) + + if tt.wantErr { + assert.Error(t, err, "Expected an error for test case: %s", tt.name) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains, "Error message mismatch for test case: %s", tt.name) + } + } else { + assert.NoError(t, err, "Did not expect an error for test case: %s", tt.name) + if tt.assertPayload { + assert.Equal(t, tt.expectedPayload, payload, "Payload mismatch for test case: %s", tt.name) + } + } + }) + } +} diff --git a/pkg/schema/boolean.go b/pkg/schema/boolean.go new file mode 100644 index 0000000..96d74a4 --- /dev/null +++ b/pkg/schema/boolean.go @@ -0,0 +1,59 @@ +package schema + +import ( + "fmt" + + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/streamnative/streamnative-mcp-server/pkg/common" +) + +// BooleanConverter handles the conversion for Pulsar BOOLEAN schemas. +type BooleanConverter struct { + BaseConverter +} + +// NewBooleanConverter creates a new instance of BooleanConverter. +func NewBooleanConverter() *BooleanConverter { + return &BooleanConverter{ + BaseConverter: BaseConverter{ + ParamName: ParamName, + }, + } +} + +func (c *BooleanConverter) ToMCPToolInputSchemaProperties(schemaInfo *cliutils.SchemaInfo) ([]mcp.ToolOption, error) { + if schemaInfo.Type != "BOOLEAN" { + return nil, fmt.Errorf("expected BOOLEAN schema, got %s", schemaInfo.Type) + } + + return []mcp.ToolOption{ + mcp.WithBoolean(c.ParamName, mcp.Description(fmt.Sprintf("The input schema is a %s schema", schemaInfo.Type)), mcp.Required()), + }, nil +} + +func (c *BooleanConverter) SerializeMCPRequestToPulsarPayload(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) ([]byte, error) { + if err := c.ValidateArguments(arguments, targetPulsarSchemaInfo); err != nil { + return nil, fmt.Errorf("arguments validation failed: %w", err) + } + + payload, err := common.RequiredParam[bool](arguments, c.ParamName) + if err != nil { + return nil, fmt.Errorf("failed to get payload: %w", err) + } + + return []byte(fmt.Sprintf("%t", payload)), nil +} + +func (c *BooleanConverter) ValidateArguments(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) error { + if targetPulsarSchemaInfo.Type != "BOOLEAN" { + return fmt.Errorf("expected BOOLEAN schema, got %s", targetPulsarSchemaInfo.Type) + } + + _, err := common.RequiredParam[bool](arguments, c.ParamName) + if err != nil { + return fmt.Errorf("failed to get payload: %w", err) + } + + return nil +} diff --git a/pkg/schema/boolean_test.go b/pkg/schema/boolean_test.go new file mode 100644 index 0000000..9123a00 --- /dev/null +++ b/pkg/schema/boolean_test.go @@ -0,0 +1,182 @@ +package schema + +import ( + "fmt" + "testing" + + "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" +) + +// Helper function to create SchemaInfo for boolean tests +func newBoolSchemaInfo(schemaType string) *utils.SchemaInfo { + return &utils.SchemaInfo{ + Type: schemaType, + Schema: []byte{}, + } +} + +func TestNewBooleanConverter(t *testing.T) { + converter := NewBooleanConverter() + assert.NotNil(t, converter) + assert.Equal(t, ParamName, converter.ParamName, "ParamName should be initialized to the package constant") +} + +func TestBooleanConverter_ToMCPToolInputSchemaProperties(t *testing.T) { + converter := NewBooleanConverter() + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + wantOpts []mcp.ToolOption + wantErr bool + }{ + { + name: "Valid BOOLEAN schema", + schemaInfo: newBoolSchemaInfo("BOOLEAN"), + wantOpts: []mcp.ToolOption{ + mcp.WithBoolean(ParamName, mcp.Description(fmt.Sprintf("The input schema is a %s schema", "BOOLEAN")), mcp.Required()), + }, + wantErr: false, + }, + { + name: "Invalid schema type (STRING)", + schemaInfo: newBoolSchemaInfo("STRING"), + wantOpts: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOpts, err := converter.ToMCPToolInputSchemaProperties(tt.schemaInfo) + if (err != nil) != tt.wantErr { + t.Errorf("ToMCPToolInputSchemaProperties() error = %v, wantErr %v", err, tt.wantErr) + return + } + var expectedTool, actualTool mcp.Tool + expectedTool = mcp.NewTool("test", tt.wantOpts...) + actualTool = mcp.NewTool("test", gotOpts...) + expectedToolSchemaJSON, _ := expectedTool.InputSchema.MarshalJSON() + actualToolSchemaJSON, _ := actualTool.InputSchema.MarshalJSON() + assert.Equal(t, expectedToolSchemaJSON, actualToolSchemaJSON) + if tt.wantErr && err != nil { + assert.Contains(t, err.Error(), "expected BOOLEAN schema, got") + } + }) + } +} + +func TestBooleanConverter_ValidateArguments(t *testing.T) { + converter := NewBooleanConverter() + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + args map[string]any + wantErr bool + errContain string // Substring to check in error message if wantErr is true + }{ + { + name: "Valid arguments for BOOLEAN schema", + schemaInfo: newBoolSchemaInfo("BOOLEAN"), + args: map[string]any{ParamName: true}, + wantErr: false, + }, + { + name: "Invalid schema type (STRING)", + schemaInfo: newBoolSchemaInfo("STRING"), + args: map[string]any{ParamName: true}, + wantErr: true, + errContain: "expected BOOLEAN schema, got STRING", + }, + { + name: "Missing payload argument", + schemaInfo: newBoolSchemaInfo("BOOLEAN"), + args: map[string]any{}, + wantErr: true, + errContain: "missing required parameter: payload", + }, + { + name: "Incorrect payload type (string instead of bool)", + schemaInfo: newBoolSchemaInfo("BOOLEAN"), + args: map[string]any{ParamName: "true"}, + wantErr: true, + errContain: "parameter payload is not of type bool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := converter.ValidateArguments(tt.args, tt.schemaInfo) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateArguments() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err != nil { + assert.Contains(t, err.Error(), tt.errContain) + } + }) + } +} + +func TestBooleanConverter_SerializeMCPRequestToPulsarPayload(t *testing.T) { + converter := NewBooleanConverter() + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + args map[string]any + want []byte + wantErr bool + errContain string + }{ + { + name: "Serialize true for BOOLEAN schema", + schemaInfo: newBoolSchemaInfo("BOOLEAN"), + args: map[string]any{ParamName: true}, + want: []byte("true"), + wantErr: false, + }, + { + name: "Serialize false for BOOLEAN schema", + schemaInfo: newBoolSchemaInfo("BOOLEAN"), + args: map[string]any{ParamName: false}, + want: []byte("false"), + wantErr: false, + }, + { + name: "Validation error (e.g., missing payload)", + schemaInfo: newBoolSchemaInfo("BOOLEAN"), + args: map[string]any{}, + want: nil, + wantErr: true, + errContain: "arguments validation failed", // Outer error message from SerializeMCPRequestToPulsarPayload + }, + { + name: "Validation error (e.g., wrong schema type during ValidateArguments)", + schemaInfo: newBoolSchemaInfo("STRING"), // Invalid schema type for this converter + args: map[string]any{ParamName: true}, + want: nil, + wantErr: true, + errContain: "arguments validation failed", // Outer error message from SerializeMCPRequestToPulsarPayload + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := converter.SerializeMCPRequestToPulsarPayload(tt.args, tt.schemaInfo) + if (err != nil) != tt.wantErr { + t.Errorf("SerializeMCPRequestToPulsarPayload() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + if tt.wantErr && err != nil { + assert.Contains(t, err.Error(), tt.errContain) + } + }) + } +} + +// Future test functions will be added here. diff --git a/pkg/schema/common.go b/pkg/schema/common.go new file mode 100644 index 0000000..4540fa0 --- /dev/null +++ b/pkg/schema/common.go @@ -0,0 +1,37 @@ +package schema + +import ( + "fmt" + + "github.com/apache/pulsar-client-go/pulsar" +) + +// GetSchemaType 返回Schema类型的字符串表示 +func GetSchemaType(schemaType pulsar.SchemaType) string { + switch schemaType { + case pulsar.AVRO: + return "AVRO" + case pulsar.JSON: + return "JSON" + case pulsar.STRING: + return "STRING" + case pulsar.INT8: + return "INT8" + case pulsar.INT16: + return "INT16" + case pulsar.INT32: + return "INT32" + case pulsar.INT64: + return "INT64" + case pulsar.FLOAT: + return "FLOAT" + case pulsar.DOUBLE: + return "DOUBLE" + case pulsar.BOOLEAN: + return "BOOLEAN" + case pulsar.BYTES: + return "BYTES" + default: + return fmt.Sprintf("Unknown(%d)", schemaType) + } +} diff --git a/pkg/schema/common_test.go b/pkg/schema/common_test.go new file mode 100644 index 0000000..8b923b8 --- /dev/null +++ b/pkg/schema/common_test.go @@ -0,0 +1,38 @@ +package schema + +import ( + "testing" + + "github.com/apache/pulsar-client-go/pulsar" +) + +// Future test functions will be added here. + +func TestGetSchemaType(t *testing.T) { + tests := []struct { + name string + schemaType pulsar.SchemaType + want string + }{ + {name: "AVRO", schemaType: pulsar.AVRO, want: "AVRO"}, + {name: "JSON", schemaType: pulsar.JSON, want: "JSON"}, + {name: "STRING", schemaType: pulsar.STRING, want: "STRING"}, + {name: "INT8", schemaType: pulsar.INT8, want: "INT8"}, + {name: "INT16", schemaType: pulsar.INT16, want: "INT16"}, + {name: "INT32", schemaType: pulsar.INT32, want: "INT32"}, + {name: "INT64", schemaType: pulsar.INT64, want: "INT64"}, + {name: "FLOAT", schemaType: pulsar.FLOAT, want: "FLOAT"}, + {name: "DOUBLE", schemaType: pulsar.DOUBLE, want: "DOUBLE"}, + {name: "BOOLEAN", schemaType: pulsar.BOOLEAN, want: "BOOLEAN"}, + {name: "BYTES", schemaType: pulsar.BYTES, want: "BYTES"}, + {name: "Unknown", schemaType: pulsar.SchemaType(999), want: "Unknown(999)"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetSchemaType(tt.schemaType); got != tt.want { + t.Errorf("GetSchemaType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/schema/converter.go b/pkg/schema/converter.go new file mode 100644 index 0000000..987dca5 --- /dev/null +++ b/pkg/schema/converter.go @@ -0,0 +1,41 @@ +package schema + +import ( + "fmt" + + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" +) + +const ( + ParamName = "payload" +) + +type Converter interface { + ToMCPToolInputSchemaProperties(pulsarSchemaInfo *cliutils.SchemaInfo) ([]mcp.ToolOption, error) + + SerializeMCPRequestToPulsarPayload(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) ([]byte, error) + + ValidateArguments(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) error +} + +func ConverterFactory(schemaType string) (Converter, error) { + switch schemaType { + case "AVRO": + return NewAvroConverter(), nil + case "JSON": + return NewJSONConverter(), nil + case "STRING", "BYTES": + return NewStringConverter(), nil + case "INT8", "INT16", "INT32", "INT64", "FLOAT", "DOUBLE": + return NewNumberConverter(), nil + case "BOOLEAN": + return NewBooleanConverter(), nil + default: + return nil, fmt.Errorf("unsupported schema type: %v", schemaType) + } +} + +type BaseConverter struct { + ParamName string +} diff --git a/pkg/schema/converter_test.go b/pkg/schema/converter_test.go new file mode 100644 index 0000000..8d657df --- /dev/null +++ b/pkg/schema/converter_test.go @@ -0,0 +1,50 @@ +package schema + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Future test functions will be added here. + +func TestConverterFactory(t *testing.T) { + tests := []struct { + name string + schemaType string + wantType reflect.Type + wantErr bool + }{ + {name: "AVRO", schemaType: "AVRO", wantType: reflect.TypeOf(&AvroConverter{}), wantErr: false}, + {name: "JSON", schemaType: "JSON", wantType: reflect.TypeOf(&JSONConverter{}), wantErr: false}, + {name: "STRING", schemaType: "STRING", wantType: reflect.TypeOf(&StringConverter{}), wantErr: false}, + {name: "BYTES", schemaType: "BYTES", wantType: reflect.TypeOf(&StringConverter{}), wantErr: false}, + {name: "INT8", schemaType: "INT8", wantType: reflect.TypeOf(&NumberConverter{}), wantErr: false}, + {name: "INT16", schemaType: "INT16", wantType: reflect.TypeOf(&NumberConverter{}), wantErr: false}, + {name: "INT32", schemaType: "INT32", wantType: reflect.TypeOf(&NumberConverter{}), wantErr: false}, + {name: "INT64", schemaType: "INT64", wantType: reflect.TypeOf(&NumberConverter{}), wantErr: false}, + {name: "FLOAT", schemaType: "FLOAT", wantType: reflect.TypeOf(&NumberConverter{}), wantErr: false}, + {name: "DOUBLE", schemaType: "DOUBLE", wantType: reflect.TypeOf(&NumberConverter{}), wantErr: false}, + {name: "BOOLEAN", schemaType: "BOOLEAN", wantType: reflect.TypeOf(&BooleanConverter{}), wantErr: false}, + {name: "avro_lowercase", schemaType: "avro", wantType: nil, wantErr: true}, + {name: "UNKNOWN_TYPE", schemaType: "UNKNOWN_TYPE", wantType: nil, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ConverterFactory(tt.schemaType) + if (err != nil) != tt.wantErr { + t.Errorf("ConverterFactory() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && reflect.TypeOf(got) != tt.wantType { + t.Errorf("ConverterFactory() got = %v, want %v", reflect.TypeOf(got), tt.wantType) + } + // For error cases, we might also want to assert the error message if it's specific. + if tt.wantErr && err != nil { + assert.Contains(t, err.Error(), "unsupported schema type") + } + }) + } +} diff --git a/pkg/schema/json.go b/pkg/schema/json.go new file mode 100644 index 0000000..4638bce --- /dev/null +++ b/pkg/schema/json.go @@ -0,0 +1,60 @@ +package schema + +import ( + "encoding/json" // Required for json.Marshal + "fmt" + + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" +) + +// JSONConverter handles the conversion for Pulsar JSON schemas. +// It relies on the underlying AVRO schema definition for structure and validation, +// but serializes to a standard JSON text payload. +type JSONConverter struct { + BaseConverter +} + +// NewJSONConverter creates a new instance of JSONConverter. +func NewJSONConverter() *JSONConverter { + return &JSONConverter{} +} + +// ToMCPToolInputSchemaProperties converts the Pulsar JSON SchemaInfo (which is AVRO based) +// to MCP tool input schema properties. +func (c *JSONConverter) ToMCPToolInputSchemaProperties(schemaInfo *cliutils.SchemaInfo) ([]mcp.ToolOption, error) { + if schemaInfo.Type != "JSON" { + // Assuming GetSchemaType will be available from somewhere in the package (e.g. converter.go) + return nil, fmt.Errorf("expected JSON schema, got %s", schemaInfo.Type) + } + // The schemaInfo.Schema for JSON type is the AVRO schema string definition. + // Delegate to the core AVRO processing function from avro_core.go. + return processAvroSchemaStringToMCPToolInput(string(schemaInfo.Schema)) +} + +// SerializeMCPRequestToPulsarPayload validates arguments against the underlying AVRO schema definition +// and then serializes them to a JSON text payload for Pulsar. +func (c *JSONConverter) SerializeMCPRequestToPulsarPayload(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) ([]byte, error) { + if err := c.ValidateArguments(arguments, targetPulsarSchemaInfo); err != nil { + return nil, fmt.Errorf("arguments validation failed: %w", err) + } + + // Serialize arguments to standard JSON []byte. + jsonData, err := json.Marshal(arguments) + if err != nil { + return nil, fmt.Errorf("failed to marshal arguments to JSON: %w", err) + } + + return jsonData, nil +} + +// ValidateArguments validates the given arguments against the Pulsar JSON SchemaInfo's +// underlying AVRO schema definition. +func (c *JSONConverter) ValidateArguments(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) error { + if targetPulsarSchemaInfo.Type != "JSON" { + return fmt.Errorf("expected JSON schema for validation, got %s", targetPulsarSchemaInfo.Type) + } + // The schemaInfo.Schema for JSON type is the AVRO schema string definition. + // Delegate to the core AVRO validation function from avro_core.go. + return validateArgumentsAgainstAvroSchemaString(arguments, string(targetPulsarSchemaInfo.Schema)) +} diff --git a/pkg/schema/json_test.go b/pkg/schema/json_test.go new file mode 100644 index 0000000..dc48caf --- /dev/null +++ b/pkg/schema/json_test.go @@ -0,0 +1,306 @@ +package schema + +import ( + "testing" + + "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/stretchr/testify/assert" + + "github.com/mark3labs/mcp-go/mcp" +) + +// newJSONSchemaInfo is a helper to create SchemaInfo for JSON type with a given AVRO schema string. +func newJSONSchemaInfo(avroSchemaString string) *utils.SchemaInfo { + return &utils.SchemaInfo{ + Name: "test-json-schema", + Type: "JSON", + Schema: []byte(avroSchemaString), + } +} + +// TestNewJSONConverter tests the NewJSONConverter constructor. +func TestNewJSONConverter(t *testing.T) { + converter := NewJSONConverter() + assert.NotNil(t, converter, "NewJSONConverter should not return nil") + // JSONConverter does not have a ParamName like simpler converters, it relies on Avro structure. +} + +// TestJSONConverter_ToMCPToolInputSchemaProperties tests ToMCPToolInputSchemaProperties for JSONConverter. +func TestJSONConverter_ToMCPToolInputSchemaProperties(t *testing.T) { + converter := NewJSONConverter() + + // Re-define or ensure accessibility of AVRO schema constants from avro_core_test.go + // For this example, let's use a simple one. Ideally, these would be shared or accessible. + const localSimpleRecordSchema = `{ + "type": "record", + "name": "SimpleRecordForJSON", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int"} + ] + }` + + const invalidAvroSchema = `{"type": "invalid}` + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + expectedOptions []mcp.ToolOption // Simplified expectation for brevity + expectError bool + errorContains string + }{ + { + name: "Valid JSON schema (based on simple Avro record)", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchema), + expectedOptions: []mcp.ToolOption{ // This structure depends on processAvroSchemaStringToMCPToolInput + mcp.WithString("name", mcp.Description("name (type: string)"), mcp.Required()), + mcp.WithNumber("age", mcp.Description("age (type: int)"), mcp.Required()), + }, + expectError: false, + }, + { + name: "SchemaInfo type is not JSON", + schemaInfo: &utils.SchemaInfo{Type: "AVRO", Schema: []byte(localSimpleRecordSchema)}, + expectedOptions: nil, + expectError: true, + errorContains: "expected JSON schema, got AVRO", + }, + { + name: "Invalid underlying Avro schema string", + schemaInfo: newJSONSchemaInfo(invalidAvroSchema), + expectedOptions: nil, + expectError: true, + errorContains: "unknown type: {\"type\": \"invalid}", // Outer error from JSONConverter + }, + { + name: "Underlying Avro schema is nil", + schemaInfo: &utils.SchemaInfo{Type: "JSON", Schema: nil}, + expectedOptions: nil, + expectError: true, + errorContains: "avro: unknown type: ", + }, + { + name: "Underlying Avro schema is empty string", + schemaInfo: newJSONSchemaInfo(""), + expectedOptions: nil, + expectError: true, + errorContains: "avro: unknown type: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts, err := converter.ToMCPToolInputSchemaProperties(tt.schemaInfo) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, opts) // Or check for empty slice if appropriate + } else { + assert.NoError(t, err) + var expectedTool = mcp.NewTool("test", tt.expectedOptions...) + var actualTool = mcp.NewTool("test", opts...) + expectedToolSchemaJSON, _ := expectedTool.InputSchema.MarshalJSON() + actualToolSchemaJSON, _ := actualTool.InputSchema.MarshalJSON() + assert.Equal(t, expectedToolSchemaJSON, actualToolSchemaJSON, "ToolOptions mismatch") + } + }) + } +} + +// TestJSONConverter_ValidateArguments tests ValidateArguments for JSONConverter. +func TestJSONConverter_ValidateArguments(t *testing.T) { + converter := NewJSONConverter() + + // Re-use localSimpleRecordSchema and invalidAvroSchema from previous test or ensure they are accessible. + const localSimpleRecordSchemaForValidation = `{ + "type": "record", + "name": "SimpleRecordForJSONValidation", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int"} + ] + }` + const invalidAvroSchemaForValidation = `{"type": "invalid}` + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + args map[string]any + expectError bool + errorContains string + }{ + { + name: "Valid arguments for JSON schema (simple record)", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchemaForValidation), + args: map[string]any{ + "name": "testUser", + "age": 30, + }, + expectError: false, + }, + { + name: "SchemaInfo type is not JSON", + schemaInfo: &utils.SchemaInfo{Type: "AVRO", Schema: []byte(localSimpleRecordSchemaForValidation)}, + args: map[string]any{"name": "testUser", "age": 30}, + expectError: true, + errorContains: "expected JSON schema for validation, got AVRO", + }, + { + name: "Invalid underlying Avro schema string", + schemaInfo: newJSONSchemaInfo(invalidAvroSchemaForValidation), + args: map[string]any{"name": "testUser", "age": 30}, + expectError: true, + errorContains: "unknown type: {\"type\": \"invalid}", // Outer error + }, + { + name: "Missing required field (age) in args for JSON schema", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchemaForValidation), + args: map[string]any{"name": "testUser"}, + expectError: true, + errorContains: "required field 'age' is missing and has no default value", // Error from validateArgumentsAgainstAvroSchemaString + }, + { + name: "Wrong type for field (age as string) in args for JSON schema", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchemaForValidation), + args: map[string]any{"name": "testUser", "age": "thirty"}, + expectError: true, + errorContains: "field 'age': expected int, got string (value: thirty)", // Error from validateArgumentsAgainstAvroSchemaString + }, + { + name: "Nil arguments map", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchemaForValidation), + args: nil, // validateArgumentsAgainstAvroSchemaString treats nil as empty map + expectError: true, + errorContains: "required field 'name' is missing and has no default value", + }, + { + name: "Underlying Avro schema is nil", + schemaInfo: &utils.SchemaInfo{Type: "JSON", Schema: nil}, + args: map[string]any{"name": "testUser", "age": 30}, + expectError: true, + errorContains: "failed to parse AVRO schema for validation: avro: unknown type: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := converter.ValidateArguments(tt.args, tt.schemaInfo) + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestJSONConverter_SerializeMCPRequestToPulsarPayload tests SerializeMCPRequestToPulsarPayload for JSONConverter. +func TestJSONConverter_SerializeMCPRequestToPulsarPayload(t *testing.T) { + converter := NewJSONConverter() + + const localSimpleRecordSchemaForSerialization = `{ + "type": "record", + "name": "SimpleRecordForJSONSerialization", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int"}, + {"name": "city", "type": ["null", "string"], "default": null} + ] + }` + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + args map[string]any + expectedJSON string // Expected JSON string output + expectError bool + errorContains string + }{ + { + name: "Valid arguments, serialize to JSON", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchemaForSerialization), + args: map[string]any{ + "name": "Alice", + "age": 30, + "city": "New York", + }, + // Note: JSON marshaling of maps doesn't guarantee key order. + // We will unmarshal and compare maps for robust checking if direct string comparison is flaky. + // For simplicity here, we assume a common order or use a more robust comparison later if needed. + expectedJSON: `{"age":30,"city":"New York","name":"Alice"}`, + expectError: false, + }, + { + name: "Valid arguments with optional field null, serialize to JSON", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchemaForSerialization), + args: map[string]any{ + "name": "Bob", + "age": 40, + "city": nil, // Explicit null for optional field + }, + expectedJSON: `{"age":40,"city":null,"name":"Bob"}`, + expectError: false, + }, + { + name: "Valid arguments with optional field omitted, serialize to JSON", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchemaForSerialization), + args: map[string]any{ + "name": "Charlie", + "age": 35, + // city is omitted, should be treated as null by Avro logic but json.Marshal will omit it if not in map + }, + // If Avro layer adds default null to args map before JSON marshal, then `"city":null` would be here. + // JSONConverter.SerializeMCPRequestToPulsarPayload directly marshals the provided args map. + expectedJSON: `{"age":35,"name":"Charlie"}`, + expectError: false, + }, + { + name: "Validation error (missing required field name)", + schemaInfo: newJSONSchemaInfo(localSimpleRecordSchemaForSerialization), + args: map[string]any{ + "age": 25, + }, + expectedJSON: "", + expectError: true, + errorContains: "arguments validation failed", // Outer error from SerializeMCPRequestToPulsarPayload + }, + { + name: "SchemaInfo type is not JSON", + schemaInfo: &utils.SchemaInfo{Type: "AVRO", Schema: []byte(localSimpleRecordSchemaForSerialization)}, + args: map[string]any{ + "name": "David", + "age": 28, + }, + expectedJSON: "", + expectError: true, + errorContains: "expected JSON schema for validation, got AVRO", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := converter.SerializeMCPRequestToPulsarPayload(tt.args, tt.schemaInfo) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, payload) + } else { + assert.NoError(t, err) + assert.NotNil(t, payload) + // For robust JSON comparison, unmarshal both expected and actual to maps and compare maps, + // or use a library that does canonical JSON comparison. + // For now, direct string comparison of compact JSON. + assert.JSONEq(t, tt.expectedJSON, string(payload), "Serialized JSON does not match expected JSON") + } + }) + } +} diff --git a/pkg/schema/number.go b/pkg/schema/number.go new file mode 100644 index 0000000..bbf89dd --- /dev/null +++ b/pkg/schema/number.go @@ -0,0 +1,104 @@ +package schema + +import ( + "fmt" + "math" + + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/streamnative/streamnative-mcp-server/pkg/common" +) + +// NumberConverter handles the conversion for Pulsar numeric schemas (INT8, INT16, INT32, INT64, FLOAT, DOUBLE). +type NumberConverter struct { + BaseConverter +} + +// NewNumberConverter creates a new instance of NumberConverter. +func NewNumberConverter() *NumberConverter { + return &NumberConverter{ + BaseConverter: BaseConverter{ + ParamName: ParamName, + }, + } +} + +func (c *NumberConverter) ToMCPToolInputSchemaProperties(schemaInfo *cliutils.SchemaInfo) ([]mcp.ToolOption, error) { + if schemaInfo.Type != "INT8" && schemaInfo.Type != "INT16" && schemaInfo.Type != "INT32" && schemaInfo.Type != "INT64" && schemaInfo.Type != "FLOAT" && schemaInfo.Type != "DOUBLE" { + return nil, fmt.Errorf("expected INT8, INT16, INT32, INT64, FLOAT, or DOUBLE schema, got %s", schemaInfo.Type) + } + + return []mcp.ToolOption{ + mcp.WithNumber(c.ParamName, mcp.Description(fmt.Sprintf("The input schema is a %s schema", schemaInfo.Type)), mcp.Required()), + }, nil +} + +func (c *NumberConverter) SerializeMCPRequestToPulsarPayload(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) ([]byte, error) { + if err := c.ValidateArguments(arguments, targetPulsarSchemaInfo); err != nil { + return nil, fmt.Errorf("arguments validation failed: %w", err) + } + + payload, err := common.RequiredParam[float64](arguments, c.ParamName) + if err != nil { + return nil, fmt.Errorf("failed to get payload: %w", err) + } + + switch targetPulsarSchemaInfo.Type { + case "INT8": + return []byte(fmt.Sprintf("%d", int8(payload))), nil + case "INT16": + return []byte(fmt.Sprintf("%d", int16(payload))), nil + case "INT32": + return []byte(fmt.Sprintf("%d", int32(payload))), nil + case "INT64": + return []byte(fmt.Sprintf("%d", int64(payload))), nil + case "FLOAT": + return []byte(fmt.Sprintf("%f", payload)), nil + case "DOUBLE": + return []byte(fmt.Sprintf("%f", payload)), nil + default: + return nil, fmt.Errorf("unsupported schema type: %s", targetPulsarSchemaInfo.Type) + } +} + +func (c *NumberConverter) ValidateArguments(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) error { + if targetPulsarSchemaInfo.Type != "INT8" && targetPulsarSchemaInfo.Type != "INT16" && targetPulsarSchemaInfo.Type != "INT32" && targetPulsarSchemaInfo.Type != "INT64" && targetPulsarSchemaInfo.Type != "FLOAT" && targetPulsarSchemaInfo.Type != "DOUBLE" { + return fmt.Errorf("expected INT8, INT16, INT32, INT64, FLOAT, or DOUBLE schema, got %s", targetPulsarSchemaInfo.Type) + } + + payload, err := common.RequiredParam[float64](arguments, c.ParamName) + if err != nil { + return fmt.Errorf("failed to get payload: %w", err) + } + + switch targetPulsarSchemaInfo.Type { + case "INT8": + if payload < math.MinInt8 || payload > math.MaxInt8 { + return fmt.Errorf("payload out of range for INT8") + } + case "INT16": + if payload < math.MinInt16 || payload > math.MaxInt16 { + return fmt.Errorf("payload out of range for INT16") + } + case "INT32": + if payload < math.MinInt32 || payload > math.MaxInt32 { + return fmt.Errorf("payload out of range for INT32") + } + case "INT64": + if payload < math.MinInt64 || payload > math.MaxInt64 { + return fmt.Errorf("payload out of range for INT64") + } + case "FLOAT": + if payload < math.SmallestNonzeroFloat32 || payload > math.MaxFloat32 { + return fmt.Errorf("payload out of range for FLOAT") + } + case "DOUBLE": + if payload < math.SmallestNonzeroFloat64 || payload > math.MaxFloat64 { + return fmt.Errorf("payload out of range for DOUBLE") + } + default: + return fmt.Errorf("unsupported schema type: %s", targetPulsarSchemaInfo.Type) + } + + return nil +} diff --git a/pkg/schema/number_test.go b/pkg/schema/number_test.go new file mode 100644 index 0000000..d8d5e43 --- /dev/null +++ b/pkg/schema/number_test.go @@ -0,0 +1,300 @@ +package schema + +import ( + "fmt" + "math" + "testing" + + "github.com/apache/pulsar-client-go/pulsar" + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" +) + +// Helper function to create SchemaInfo for number tests. +// Based on persistent linter errors, cliutils.SchemaInfo.Type is likely a string. +func newNumberSchemaInfo(schemaType pulsar.SchemaType) *cliutils.SchemaInfo { + return &cliutils.SchemaInfo{ + Type: GetSchemaType(schemaType), // Convert pulsar.SchemaType to string for cliutils.SchemaInfo.Type + Schema: []byte{}, + } +} + +func TestNewNumberConverter(t *testing.T) { + converter := NewNumberConverter() + assert.Equal(t, "payload", converter.ParamName) +} + +func TestNumberConverter_ToMCPToolInputSchemaProperties(t *testing.T) { + converter := NewNumberConverter() + tests := []struct { + name string + schemaInfo *cliutils.SchemaInfo // This is schema.SchemaInfo which has Type as string + expectedProps []mcp.ToolOption + expectError bool + expectedErrorMsg string + }{ + { + name: "INT8 type", + // newNumberSchemaInfo now correctly sets schemaInfo.Type as string (e.g., "INT8") + schemaInfo: newNumberSchemaInfo(pulsar.INT8), + expectedProps: []mcp.ToolOption{mcp.WithNumber("payload", mcp.Description("The input schema is a INT8 schema"), mcp.Required())}, + expectError: false, + }, + { + name: "INT16 type", + schemaInfo: newNumberSchemaInfo(pulsar.INT16), + expectedProps: []mcp.ToolOption{mcp.WithNumber("payload", mcp.Description("The input schema is a INT16 schema"), mcp.Required())}, + expectError: false, + }, + { + name: "INT32 type", + schemaInfo: newNumberSchemaInfo(pulsar.INT32), + expectedProps: []mcp.ToolOption{mcp.WithNumber("payload", mcp.Description("The input schema is a INT32 schema"), mcp.Required())}, + expectError: false, + }, + { + name: "INT64 type", + schemaInfo: newNumberSchemaInfo(pulsar.INT64), + expectedProps: []mcp.ToolOption{mcp.WithNumber("payload", mcp.Description("The input schema is a INT64 schema"), mcp.Required())}, + expectError: false, + }, + { + name: "FLOAT type", + schemaInfo: newNumberSchemaInfo(pulsar.FLOAT), + expectedProps: []mcp.ToolOption{mcp.WithNumber("payload", mcp.Description("The input schema is a FLOAT schema"), mcp.Required())}, + expectError: false, + }, + { + name: "DOUBLE type", + schemaInfo: newNumberSchemaInfo(pulsar.DOUBLE), + expectedProps: []mcp.ToolOption{mcp.WithNumber("payload", mcp.Description("The input schema is a DOUBLE schema"), mcp.Required())}, + expectError: false, + }, + { + name: "Unsupported type STRING", + schemaInfo: newNumberSchemaInfo(pulsar.STRING), + expectedProps: nil, + expectError: true, + expectedErrorMsg: "expected INT8, INT16, INT32, INT64, FLOAT, or DOUBLE schema, got STRING", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + props, err := converter.ToMCPToolInputSchemaProperties(tt.schemaInfo) + if tt.expectError { + assert.Error(t, err) + if tt.expectedErrorMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrorMsg) + } + assert.Nil(t, props) + } else { + assert.NoError(t, err) + var expectedTool = mcp.NewTool("test", tt.expectedProps...) + var actualTool = mcp.NewTool("test", props...) + expectedToolSchemaJSON, _ := expectedTool.InputSchema.MarshalJSON() + actualToolSchemaJSON, _ := actualTool.InputSchema.MarshalJSON() + assert.Equal(t, string(expectedToolSchemaJSON), string(actualToolSchemaJSON), "ToolOptions mismatch") + } + }) + } +} + +func TestNumberConverter_ValidateArguments(t *testing.T) { + converter := NewNumberConverter() + type testArgs struct { + name string + schemaInfo *cliutils.SchemaInfo + args map[string]any + expectError bool + expectedErrorMsg string + } + + tests := []testArgs{} + + numericTypes := []struct { + pulsarType pulsar.SchemaType + minVal float64 + maxVal float64 + }{ + {pulsar.INT8, -128, 127}, + {pulsar.INT16, -32768, 32767}, + {pulsar.INT32, -2147483648, 2147483647}, + {pulsar.INT64, -9007199254740991, 9007199254740991}, + {pulsar.FLOAT, math.SmallestNonzeroFloat32, math.MaxFloat32}, + {pulsar.DOUBLE, math.SmallestNonzeroFloat64, math.MaxFloat64}, + } + + for _, nt := range numericTypes { + schemaTypeName := GetSchemaType(nt.pulsarType) + tests = append(tests, []testArgs{ + { + name: fmt.Sprintf("%s - valid value", schemaTypeName), + schemaInfo: newNumberSchemaInfo(nt.pulsarType), + args: map[string]any{"payload": (nt.minVal + nt.maxVal) / 2}, + expectError: false, + }, + { + name: fmt.Sprintf("%s - min value", schemaTypeName), + schemaInfo: newNumberSchemaInfo(nt.pulsarType), + args: map[string]any{"payload": nt.minVal}, + expectError: false, + }, + { + name: fmt.Sprintf("%s - max value", schemaTypeName), + schemaInfo: newNumberSchemaInfo(nt.pulsarType), + args: map[string]any{"payload": nt.maxVal}, + expectError: false, + }, + }...) + if nt.pulsarType != pulsar.INT64 && nt.pulsarType != pulsar.FLOAT && nt.pulsarType != pulsar.DOUBLE { + tests = append(tests, []testArgs{ + { + name: fmt.Sprintf("%s - value below min", schemaTypeName), + schemaInfo: newNumberSchemaInfo(nt.pulsarType), + args: map[string]any{"payload": nt.minVal - 1}, + expectError: true, + expectedErrorMsg: fmt.Sprintf("payload out of range for %s", schemaTypeName), + }, + { + name: fmt.Sprintf("%s - value above max", schemaTypeName), + schemaInfo: newNumberSchemaInfo(nt.pulsarType), + args: map[string]any{"payload": nt.maxVal + 1}, + expectError: true, + expectedErrorMsg: fmt.Sprintf("payload out of range for %s", schemaTypeName), + }, + }...) + } + } + + // Common error cases for one type (e.g., INT32) - these should behave similarly for others + tests = append(tests, []testArgs{ + { + name: "INT32 - missing payload", + schemaInfo: newNumberSchemaInfo(pulsar.INT32), + args: map[string]any{}, + expectError: true, + expectedErrorMsg: "failed to get payload: missing required parameter: payload", + }, + { + name: "INT32 - wrong payload type (string)", + schemaInfo: newNumberSchemaInfo(pulsar.INT32), + args: map[string]any{"payload": "not a number"}, + expectError: true, + expectedErrorMsg: "failed to get payload: parameter payload is not of type float64", + }, + { + name: "INT32 - wrong schemaInfo.Type (e.g. STRING)", + schemaInfo: newNumberSchemaInfo(pulsar.STRING), + args: map[string]any{"payload": 123}, + expectError: true, + expectedErrorMsg: "expected INT8, INT16, INT32, INT64, FLOAT, or DOUBLE schema, got STRING", + }, + }...) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := converter.ValidateArguments(tt.args, tt.schemaInfo) + if tt.expectError { + assert.Error(t, err) + if tt.expectedErrorMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestNumberConverter_SerializeMCPRequestToPulsarPayload(t *testing.T) { + converter := NewNumberConverter() + type testArgs struct { + name string + schemaInfo *cliutils.SchemaInfo + args map[string]any + expectedPayload []byte + expectError bool + expectedErrorMsg string + } + + tests := []testArgs{} + + numericTypeCases := []struct { + pulsarType pulsar.SchemaType + validPayload any // Use any because direct float64 might lose precision for int64 for ex. + expectedBytes []byte + // Add specific out-of-range or invalid format cases if serialization handles them distinctly from validation + }{ + {pulsar.INT8, float64(127), []byte("127")}, + {pulsar.INT8, float64(-128), []byte("-128")}, + {pulsar.INT16, float64(32767), []byte("32767")}, + {pulsar.INT32, float64(2147483647), []byte("2147483647")}, + // For INT64, using a number representable by float64 without precision loss + {pulsar.INT64, float64(9007199254740991), []byte("9007199254740991")}, + {pulsar.FLOAT, float64(123.456), []byte(fmt.Sprintf("%f", float64(123.456)))}, // Note: float formatting can be tricky, use simple case + {pulsar.DOUBLE, float64(1.23456789e10), []byte(fmt.Sprintf("%f", float64(1.23456789e10)))}, // Again, simple case for float formatting + } + + for _, tc := range numericTypeCases { + schemaTypeName := GetSchemaType(tc.pulsarType) + tests = append(tests, testArgs{ + name: fmt.Sprintf("%s - valid serialization", schemaTypeName), + schemaInfo: newNumberSchemaInfo(tc.pulsarType), + args: map[string]any{"payload": tc.validPayload}, + expectedPayload: tc.expectedBytes, + expectError: false, + }) + } + + // Error cases (mostly delegation to ValidateArguments, so these check that path) + tests = append(tests, []testArgs{ + { + name: "Error - INT32 - missing payload", + schemaInfo: newNumberSchemaInfo(pulsar.INT32), + args: map[string]any{}, + expectError: true, + expectedErrorMsg: "arguments validation failed: failed to get payload: missing required parameter", + }, + { + name: "Error - INT32 - wrong payload type", + schemaInfo: newNumberSchemaInfo(pulsar.INT32), + args: map[string]any{"payload": "not a number"}, + expectError: true, + expectedErrorMsg: "arguments validation failed: failed to get payload: parameter payload is not of type float64", + }, + { + name: "Error - INT32 - value out of range", + schemaInfo: newNumberSchemaInfo(pulsar.INT32), + args: map[string]any{"payload": float64(21474836480)}, // Clearly out of INT32 range + expectError: true, + expectedErrorMsg: "arguments validation failed: payload out of range for INT32", + }, + { + name: "Error - Unsupported Schema Type (STRING)", + schemaInfo: newNumberSchemaInfo(pulsar.STRING), + args: map[string]any{"payload": 123}, + expectError: true, + expectedErrorMsg: "arguments validation failed: expected INT8, INT16, INT32, INT64, FLOAT, or DOUBLE schema, got STRING", + }, + }...) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := converter.SerializeMCPRequestToPulsarPayload(tt.args, tt.schemaInfo) + if tt.expectError { + assert.Error(t, err) + if tt.expectedErrorMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrorMsg) + } + assert.Nil(t, payload) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedPayload, payload) + } + }) + } +} + +// Future test functions will be added here. diff --git a/pkg/schema/string.go b/pkg/schema/string.go new file mode 100644 index 0000000..297a07f --- /dev/null +++ b/pkg/schema/string.go @@ -0,0 +1,63 @@ +package schema + +import ( + "fmt" + + cliutils "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/streamnative/streamnative-mcp-server/pkg/common" +) + +// StringConverter handles the conversion for Pulsar STRING schemas. +type StringConverter struct { + BaseConverter +} + +// NewStringConverter creates a new instance of StringConverter. +func NewStringConverter() *StringConverter { + return &StringConverter{ + BaseConverter: BaseConverter{ + ParamName: ParamName, + }, + } +} + +func (c *StringConverter) ToMCPToolInputSchemaProperties(schemaInfo *cliutils.SchemaInfo) ([]mcp.ToolOption, error) { + if schemaInfo.Type != "STRING" && schemaInfo.Type != "BYTES" { + return nil, fmt.Errorf("expected STRING or BYTES schema, got %s", schemaInfo.Type) + } + + return []mcp.ToolOption{ + mcp.WithString(c.ParamName, mcp.Description(fmt.Sprintf("The input schema is a %s schema", schemaInfo.Type)), mcp.Required()), + }, nil +} + +func (c *StringConverter) SerializeMCPRequestToPulsarPayload(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) ([]byte, error) { + if err := c.ValidateArguments(arguments, targetPulsarSchemaInfo); err != nil { + return nil, fmt.Errorf("arguments validation failed: %w", err) + } + + payload, err := common.RequiredParam[string](arguments, c.ParamName) + if err != nil { + return nil, fmt.Errorf("failed to get payload: %w", err) + } + + return []byte(payload), nil +} + +func (c *StringConverter) ValidateArguments(arguments map[string]any, targetPulsarSchemaInfo *cliutils.SchemaInfo) error { + if targetPulsarSchemaInfo.Type != "STRING" && targetPulsarSchemaInfo.Type != "BYTES" { + return fmt.Errorf("expected STRING or BYTES schema, got %s", targetPulsarSchemaInfo.Type) + } + + payload, err := common.RequiredParam[string](arguments, c.ParamName) + if err != nil { + return fmt.Errorf("failed to get payload: %w", err) + } + + if payload == "" { + return fmt.Errorf("payload cannot be empty") + } + + return nil +} diff --git a/pkg/schema/string_test.go b/pkg/schema/string_test.go new file mode 100644 index 0000000..78a66bb --- /dev/null +++ b/pkg/schema/string_test.go @@ -0,0 +1,204 @@ +package schema + +import ( + "fmt" + "testing" + + "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" +) + +// Helper function to create SchemaInfo for string tests +func newStringSchemaInfo(schemaType string) *utils.SchemaInfo { + return &utils.SchemaInfo{ + Type: schemaType, + Schema: []byte{}, + } +} + +func TestNewStringConverter(t *testing.T) { + converter := NewStringConverter() + assert.NotNil(t, converter) + assert.Equal(t, ParamName, converter.ParamName, "ParamName should be initialized to the package constant") +} + +func TestStringConverter_ToMCPToolInputSchemaProperties(t *testing.T) { + converter := NewStringConverter() + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + wantOpts []mcp.ToolOption + wantErr bool + errContain string + }{ + { + name: "Valid STRING schema", + schemaInfo: newStringSchemaInfo("STRING"), + wantOpts: []mcp.ToolOption{ + mcp.WithString(ParamName, mcp.Description(fmt.Sprintf("The input schema is a %s schema", "STRING")), mcp.Required()), + }, + wantErr: false, + }, + { + name: "Valid BYTES schema", + schemaInfo: newStringSchemaInfo("BYTES"), + wantOpts: []mcp.ToolOption{ + mcp.WithString(ParamName, mcp.Description(fmt.Sprintf("The input schema is a %s schema", "BYTES")), mcp.Required()), + }, + wantErr: false, + }, + { + name: "Invalid schema type (JSON)", + schemaInfo: newStringSchemaInfo("JSON"), + wantOpts: nil, + wantErr: true, + errContain: "expected STRING or BYTES schema, got JSON", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOpts, err := converter.ToMCPToolInputSchemaProperties(tt.schemaInfo) + if (err != nil) != tt.wantErr { + t.Errorf("ToMCPToolInputSchemaProperties() error = %v, wantErr %v", err, tt.wantErr) + return + } + var expectedTool = mcp.NewTool("test", gotOpts...) + var actualTool = mcp.NewTool("test", tt.wantOpts...) + expectedToolInputSchemaJSON, _ := expectedTool.InputSchema.MarshalJSON() + actualToolInputSchemaJSON, _ := actualTool.InputSchema.MarshalJSON() + assert.Equal(t, string(expectedToolInputSchemaJSON), string(actualToolInputSchemaJSON)) + if tt.wantErr && err != nil { + assert.Contains(t, err.Error(), tt.errContain) + } + }) + } +} + +func TestStringConverter_ValidateArguments(t *testing.T) { + converter := NewStringConverter() + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + args map[string]any + wantErr bool + errContain string + }{ + { + name: "Valid arguments for STRING schema", + schemaInfo: newStringSchemaInfo("STRING"), + args: map[string]any{ParamName: "hello world"}, + wantErr: false, + }, + { + name: "Valid arguments for BYTES schema", + schemaInfo: newStringSchemaInfo("BYTES"), + args: map[string]any{ParamName: "bytes content"}, + wantErr: false, + }, + { + name: "Invalid schema type (JSON)", + schemaInfo: newStringSchemaInfo("JSON"), + args: map[string]any{ParamName: "test"}, + wantErr: true, + errContain: "expected STRING or BYTES schema, got JSON", + }, + { + name: "Missing payload argument", + schemaInfo: newStringSchemaInfo("STRING"), + args: map[string]any{}, + wantErr: true, + errContain: "failed to get payload: missing required parameter: payload", + }, + { + name: "Incorrect payload type (int instead of string)", + schemaInfo: newStringSchemaInfo("STRING"), + args: map[string]any{ParamName: 123}, + wantErr: true, + errContain: "failed to get payload: parameter payload is not of type string", + }, + { + name: "Empty string payload", + schemaInfo: newStringSchemaInfo("STRING"), + args: map[string]any{ParamName: ""}, + wantErr: true, + errContain: "failed to get payload: missing required parameter: payload", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := converter.ValidateArguments(tt.args, tt.schemaInfo) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateArguments() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err != nil { + assert.Contains(t, err.Error(), tt.errContain) + } + }) + } +} + +func TestStringConverter_SerializeMCPRequestToPulsarPayload(t *testing.T) { + converter := NewStringConverter() + + tests := []struct { + name string + schemaInfo *utils.SchemaInfo + args map[string]any + want []byte + wantErr bool + errContain string + }{ + { + name: "Serialize 'hello' for STRING schema", + schemaInfo: newStringSchemaInfo("STRING"), + args: map[string]any{ParamName: "hello"}, + want: []byte("hello"), + wantErr: false, + }, + { + name: "Serialize 'bytes content' for BYTES schema", + schemaInfo: newStringSchemaInfo("BYTES"), + args: map[string]any{ParamName: "bytes content"}, + want: []byte("bytes content"), + wantErr: false, + }, + { + name: "Validation error (e.g., empty string)", + schemaInfo: newStringSchemaInfo("STRING"), + args: map[string]any{ParamName: ""}, + want: nil, + wantErr: true, + errContain: "arguments validation failed", + }, + { + name: "Validation error (e.g., missing payload)", + schemaInfo: newStringSchemaInfo("STRING"), + args: map[string]any{}, + want: nil, + wantErr: true, + errContain: "arguments validation failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := converter.SerializeMCPRequestToPulsarPayload(tt.args, tt.schemaInfo) + if (err != nil) != tt.wantErr { + t.Errorf("SerializeMCPRequestToPulsarPayload() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + if tt.wantErr && err != nil { + assert.Contains(t, err.Error(), tt.errContain) + } + }) + } +} + +// Future test functions will be added here. diff --git a/sdk/sdk-apiserver/go.mod b/sdk/sdk-apiserver/go.mod index 276dd21..c780ba7 100644 --- a/sdk/sdk-apiserver/go.mod +++ b/sdk/sdk-apiserver/go.mod @@ -1,8 +1,9 @@ module github.com/streamnative/streamnative-mcp-server/sdk/sdk-apiserver -go 1.13 +go 1.21 -require ( - github.com/google/go-cmp v0.6.0 // indirect - golang.org/x/oauth2 v0.25.0 -) +toolchain go1.24.3 + +require golang.org/x/oauth2 v0.25.0 + +require github.com/google/go-cmp v0.7.0 // indirect diff --git a/sdk/sdk-apiserver/go.sum b/sdk/sdk-apiserver/go.sum index f7db442..74a187d 100644 --- a/sdk/sdk-apiserver/go.sum +++ b/sdk/sdk-apiserver/go.sum @@ -1,2 +1,2 @@ -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=