|
| 1 | +package mcpbackend |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "database/sql" |
| 6 | + "encoding/json" |
| 7 | + "fmt" |
| 8 | + |
| 9 | + "github.com/sirupsen/logrus" |
| 10 | + "github.com/stackql/stackql/internal/stackql/handler" |
| 11 | + "github.com/stackql/stackql/pkg/mcp_server" |
| 12 | +) |
| 13 | + |
| 14 | +type stackqlMCPReverseProxyService struct { |
| 15 | + isReadOnly bool |
| 16 | + dsn string |
| 17 | + handlerCtx handler.HandlerContext |
| 18 | + logger *logrus.Logger |
| 19 | + db *sql.DB |
| 20 | + interrogator StackqlInterrogator |
| 21 | + renderer resultsRenderer |
| 22 | +} |
| 23 | + |
| 24 | +func NewStackqlMCPReverseProxyService( |
| 25 | + isReadOnly bool, |
| 26 | + dsn string, |
| 27 | + db *sql.DB, |
| 28 | + handlerCtx handler.HandlerContext, |
| 29 | + logger *logrus.Logger, |
| 30 | +) (mcp_server.Backend, error) { |
| 31 | + if logger == nil { |
| 32 | + logger = logrus.New() |
| 33 | + logger.SetLevel(logrus.InfoLevel) |
| 34 | + } |
| 35 | + if handlerCtx == nil { |
| 36 | + return nil, fmt.Errorf("handler context is nil") |
| 37 | + } |
| 38 | + return &stackqlMCPReverseProxyService{ |
| 39 | + dsn: dsn, |
| 40 | + isReadOnly: isReadOnly, |
| 41 | + interrogator: NewSimpleStackqlInterrogator(), |
| 42 | + logger: logger, |
| 43 | + handlerCtx: handlerCtx, |
| 44 | + db: db, |
| 45 | + renderer: NewResultsRenderer(), |
| 46 | + }, nil |
| 47 | +} |
| 48 | + |
| 49 | +func (b *stackqlMCPReverseProxyService) getDefaultFormat() string { |
| 50 | + return resultsFormatMarkdown |
| 51 | +} |
| 52 | + |
| 53 | +func (b *stackqlMCPReverseProxyService) Ping(ctx context.Context) error { |
| 54 | + return nil |
| 55 | +} |
| 56 | + |
| 57 | +func (b *stackqlMCPReverseProxyService) Close() error { |
| 58 | + return nil |
| 59 | +} |
| 60 | + |
| 61 | +// Server and environment info |
| 62 | +func (b *stackqlMCPReverseProxyService) ServerInfo(ctx context.Context, args any) (mcp_server.ServerInfoOutput, error) { |
| 63 | + return mcp_server.ServerInfoOutput{ |
| 64 | + Name: "Stackql MCP Reverse Proxy Service", |
| 65 | + Info: "This is the Stackql MCP Reverse Proxy Service.", |
| 66 | + IsReadOnly: b.isReadOnly, |
| 67 | + }, nil |
| 68 | +} |
| 69 | + |
| 70 | +// Current DB identity details |
| 71 | +func (b *stackqlMCPReverseProxyService) DBIdentity(ctx context.Context, args any) (map[string]any, error) { |
| 72 | + return map[string]any{ |
| 73 | + "identity": "stackql_mcp_reverse_proxy_service", |
| 74 | + }, nil |
| 75 | +} |
| 76 | + |
| 77 | +func (b *stackqlMCPReverseProxyService) Greet(ctx context.Context, args mcp_server.GreetInput) (string, error) { |
| 78 | + return "Hi " + args.Name, nil |
| 79 | +} |
| 80 | + |
| 81 | +//nolint:gocognit,funlen // acceptable |
| 82 | +func (b *stackqlMCPReverseProxyService) query(ctx context.Context, query string, rowLimit int) ([]map[string]any, error) { |
| 83 | + r, sqlErr := b.db.Query(query) |
| 84 | + if sqlErr != nil { |
| 85 | + return nil, sqlErr |
| 86 | + } |
| 87 | + rowsErr := r.Err() |
| 88 | + if rowsErr != nil { |
| 89 | + return nil, rowsErr |
| 90 | + } |
| 91 | + defer r.Close() //nolint:errcheck // TODO: investigate |
| 92 | + columnTypes, err := r.ColumnTypes() |
| 93 | + if err != nil { |
| 94 | + return nil, err |
| 95 | + } |
| 96 | + |
| 97 | + count := len(columnTypes) |
| 98 | + var finalRows []map[string]any |
| 99 | + |
| 100 | + rowCount := 0 |
| 101 | + for r.Next() { |
| 102 | + if rowLimit > 0 && rowCount >= rowLimit { |
| 103 | + break |
| 104 | + } |
| 105 | + rowCount++ |
| 106 | + scanArgs := make([]interface{}, count) |
| 107 | + |
| 108 | + for i, v := range columnTypes { |
| 109 | + switch v.DatabaseTypeName() { |
| 110 | + case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": |
| 111 | + scanArgs[i] = new(sql.NullString) |
| 112 | + break |
| 113 | + case "BOOL": |
| 114 | + scanArgs[i] = new(sql.NullBool) |
| 115 | + break |
| 116 | + case "INT4": |
| 117 | + scanArgs[i] = new(sql.NullInt64) |
| 118 | + break |
| 119 | + default: |
| 120 | + scanArgs[i] = new(sql.NullString) |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + scanErr := r.Scan(scanArgs...) |
| 125 | + |
| 126 | + if scanErr != nil { |
| 127 | + return nil, scanErr |
| 128 | + } |
| 129 | + |
| 130 | + masterData := map[string]any{} |
| 131 | + |
| 132 | + for i, v := range columnTypes { |
| 133 | + if z, ok := (scanArgs[i]).(*sql.NullBool); ok { |
| 134 | + masterData[v.Name()] = z.Bool |
| 135 | + continue |
| 136 | + } |
| 137 | + if z, ok := (scanArgs[i]).(*sql.NullString); ok { |
| 138 | + masterData[v.Name()] = z.String |
| 139 | + continue |
| 140 | + } |
| 141 | + if z, ok := (scanArgs[i]).(*sql.NullInt64); ok { |
| 142 | + masterData[v.Name()] = z.Int64 |
| 143 | + continue |
| 144 | + } |
| 145 | + if z, ok := (scanArgs[i]).(*sql.NullFloat64); ok { |
| 146 | + masterData[v.Name()] = z.Float64 |
| 147 | + continue |
| 148 | + } |
| 149 | + if z, ok := (scanArgs[i]).(*sql.NullInt32); ok { |
| 150 | + masterData[v.Name()] = z.Int32 |
| 151 | + continue |
| 152 | + } |
| 153 | + masterData[v.Name()] = scanArgs[i] |
| 154 | + } |
| 155 | + finalRows = append(finalRows, masterData) |
| 156 | + } |
| 157 | + return finalRows, nil |
| 158 | +} |
| 159 | + |
| 160 | +func (b *stackqlMCPReverseProxyService) renderQueryResults(query string, format string, rowLimit int) (string, error) { |
| 161 | + if format == "" { |
| 162 | + format = b.getDefaultFormat() |
| 163 | + } |
| 164 | + ctx := context.Background() |
| 165 | + rows, err := b.query(ctx, query, rowLimit) |
| 166 | + if err != nil { |
| 167 | + return "", err |
| 168 | + } |
| 169 | + switch format { |
| 170 | + case resultsFormatMarkdown: |
| 171 | + return b.renderer.RenderAsMarkdown(rows), nil |
| 172 | + case resultsFormatJSON: |
| 173 | + jsonStr, jsonErr := json.Marshal(rows) |
| 174 | + if jsonErr != nil { |
| 175 | + return "", jsonErr |
| 176 | + } |
| 177 | + return string(jsonStr), nil |
| 178 | + default: |
| 179 | + return "", fmt.Errorf("unknown format: %s", format) |
| 180 | + } |
| 181 | +} |
| 182 | + |
| 183 | +func (b *stackqlMCPReverseProxyService) RunQuery(ctx context.Context, args mcp_server.QueryInput) (string, error) { |
| 184 | + if args.Format == "" { |
| 185 | + args.Format = b.getDefaultFormat() |
| 186 | + } |
| 187 | + rows, err := b.query(ctx, args.SQL, args.RowLimit) |
| 188 | + if err != nil { |
| 189 | + return "", err |
| 190 | + } |
| 191 | + switch args.Format { |
| 192 | + case resultsFormatMarkdown: |
| 193 | + return b.renderer.RenderAsMarkdown(rows), nil |
| 194 | + case resultsFormatJSON: |
| 195 | + jsonStr, jsonErr := json.Marshal(rows) |
| 196 | + if jsonErr != nil { |
| 197 | + return "", jsonErr |
| 198 | + } |
| 199 | + return string(jsonStr), nil |
| 200 | + default: |
| 201 | + return "", fmt.Errorf("unknown format: %s", args.Format) |
| 202 | + } |
| 203 | +} |
| 204 | + |
| 205 | +func (b *stackqlMCPReverseProxyService) RunQueryJSON(ctx context.Context, input mcp_server.QueryJSONInput) ([]map[string]interface{}, error) { |
| 206 | + return b.query(ctx, input.SQL, input.RowLimit) |
| 207 | +} |
| 208 | + |
| 209 | +// func (b *stackqlMCPReverseProxyService) ListTableResources(ctx context.Context, hI mcp_server.HierarchyInput) ([]string, error) { |
| 210 | + |
| 211 | +// TODO: implement the remaining methods, using the db connection as sole sql data source |
| 212 | + |
| 213 | +// return []string{}, nil |
| 214 | +// } |
| 215 | + |
| 216 | +func (b *stackqlMCPReverseProxyService) ReadTableResource(ctx context.Context, hI mcp_server.HierarchyInput) ([]map[string]interface{}, error) { |
| 217 | + if hI.Provider == "" || hI.Service == "" || hI.Resource == "" { |
| 218 | + return nil, fmt.Errorf("provider, service, and resource must be specified") |
| 219 | + } |
| 220 | + query := fmt.Sprintf("SELECT * FROM %s.%s", hI.Service, hI.Resource) |
| 221 | + return b.query(ctx, query, hI.RowLimit) |
| 222 | +} |
| 223 | + |
| 224 | +func (b *stackqlMCPReverseProxyService) PromptWriteSafeSelectTool(ctx context.Context, args mcp_server.HierarchyInput) (string, error) { |
| 225 | + return mcp_server.ExplainerPromptWriteSafeSelectTool, nil |
| 226 | +} |
| 227 | + |
| 228 | +// func (b *stackqlMCPReverseProxyService) PromptExplainPlanTipsTool(ctx context.Context) (string, error) { |
| 229 | +// return "stub", nil |
| 230 | +// } |
| 231 | + |
| 232 | +func (b *stackqlMCPReverseProxyService) DescribeTable(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) { |
| 233 | + q, qErr := b.interrogator.GetDescribeTable(hI) |
| 234 | + if qErr != nil { |
| 235 | + return "", qErr |
| 236 | + } |
| 237 | + return b.renderQueryResults(q, hI.Format, hI.RowLimit) |
| 238 | +} |
| 239 | + |
| 240 | +func (b *stackqlMCPReverseProxyService) GetForeignKeys(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) { |
| 241 | + return b.interrogator.GetForeignKeys(hI) |
| 242 | +} |
| 243 | + |
| 244 | +func (b *stackqlMCPReverseProxyService) FindRelationships(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) { |
| 245 | + return b.interrogator.FindRelationships(hI) |
| 246 | +} |
| 247 | + |
| 248 | +func (b *stackqlMCPReverseProxyService) ListProviders(ctx context.Context) (string, error) { |
| 249 | + q, qErr := b.interrogator.GetShowProviders(mcp_server.HierarchyInput{}, "") |
| 250 | + if qErr != nil { |
| 251 | + return "", qErr |
| 252 | + } |
| 253 | + return b.renderQueryResults(q, "", unlimitedRowLimit) |
| 254 | +} |
| 255 | + |
| 256 | +func (b *stackqlMCPReverseProxyService) ListServices(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) { |
| 257 | + q, qErr := b.interrogator.GetShowServices(hI, "") |
| 258 | + if qErr != nil { |
| 259 | + return "", qErr |
| 260 | + } |
| 261 | + return b.renderQueryResults(q, hI.Format, hI.RowLimit) |
| 262 | +} |
| 263 | + |
| 264 | +func (b *stackqlMCPReverseProxyService) ListResources(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) { |
| 265 | + q, qErr := b.interrogator.GetShowResources(hI, "") |
| 266 | + if qErr != nil { |
| 267 | + return "", qErr |
| 268 | + } |
| 269 | + return b.renderQueryResults(q, hI.Format, hI.RowLimit) |
| 270 | +} |
| 271 | + |
| 272 | +func (b *stackqlMCPReverseProxyService) ListTablesJSON(ctx context.Context, input mcp_server.ListTablesInput) ([]map[string]interface{}, error) { |
| 273 | + hI := mcp_server.HierarchyInput{} |
| 274 | + likeStr := "" |
| 275 | + if input.Hierarchy != nil { |
| 276 | + hI = *input.Hierarchy |
| 277 | + } |
| 278 | + if input.NameLike != nil { |
| 279 | + likeStr = *input.NameLike |
| 280 | + } |
| 281 | + q, qErr := b.interrogator.GetShowResources(hI, likeStr) |
| 282 | + if qErr != nil { |
| 283 | + return nil, qErr |
| 284 | + } |
| 285 | + return b.query(ctx, q, input.RowLimit) |
| 286 | +} |
| 287 | + |
| 288 | +func (b *stackqlMCPReverseProxyService) ListTablesJSONPage(ctx context.Context, input mcp_server.ListTablesPageInput) (map[string]interface{}, error) { |
| 289 | + return map[string]interface{}{}, nil |
| 290 | +} |
| 291 | + |
| 292 | +func (b *stackqlMCPReverseProxyService) ListTables(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) { |
| 293 | + return b.ListResources(ctx, hI) |
| 294 | +} |
| 295 | + |
| 296 | +func (b *stackqlMCPReverseProxyService) ListMethods(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) { |
| 297 | + q, qErr := b.interrogator.GetShowMethods(hI) |
| 298 | + if qErr != nil { |
| 299 | + return "", qErr |
| 300 | + } |
| 301 | + return b.renderQueryResults(q, hI.Format, hI.RowLimit) |
| 302 | +} |
0 commit comments