Skip to content

Commit 5210730

Browse files
mcp-reverse-proxy
Summary: - MCP server pass through to postgres conection.
1 parent 2d08cb5 commit 5210730

File tree

5 files changed

+391
-34
lines changed

5 files changed

+391
-34
lines changed

internal/stackql/cmd/mcp.go

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,18 @@ import (
2121

2222
"github.com/spf13/cobra"
2323

24+
"github.com/stackql/any-sdk/pkg/db/db_util"
25+
"github.com/stackql/any-sdk/pkg/dto"
2426
"github.com/stackql/any-sdk/pkg/logging"
2527
"github.com/stackql/stackql/internal/stackql/acid/tsm_physio"
2628
"github.com/stackql/stackql/internal/stackql/entryutil"
2729
"github.com/stackql/stackql/internal/stackql/handler"
2830
"github.com/stackql/stackql/internal/stackql/iqlerror"
2931
"github.com/stackql/stackql/internal/stackql/mcpbackend"
3032
"github.com/stackql/stackql/pkg/mcp_server"
33+
34+
"github.com/jackc/pgx/v5"
35+
_ "github.com/jackc/pgx/v5" //nolint:revive // canonical driver pattern
3136
)
3237

3338
//nolint:gochecknoglobals // cobra pattern
@@ -68,27 +73,47 @@ func runMCPServer(handlerCtx handler.HandlerContext) {
6873
if config.Server.IsReadOnly != nil {
6974
isReadOnly = *config.Server.IsReadOnly
7075
}
71-
orchestrator, orchestratorErr := tsm_physio.NewOrchestrator(handlerCtx)
72-
iqlerror.PrintErrorAndExitOneIfError(orchestratorErr)
73-
iqlerror.PrintErrorAndExitOneIfNil(orchestrator, "orchestrator is unexpectedly nil")
74-
// handlerCtx.SetTSMOrchestrator(orchestrator)
75-
backend, backendErr := mcpbackend.NewStackqlMCPBackendService(
76-
isReadOnly,
77-
orchestrator,
78-
handlerCtx,
79-
logging.GetLogger(),
80-
)
81-
iqlerror.PrintErrorAndExitOneIfError(backendErr)
82-
iqlerror.PrintErrorAndExitOneIfNil(backend, "mcp backend is unexpectedly nil")
83-
76+
var backend mcp_server.Backend
77+
var backendErr error
78+
if mcpServerType == "reverse_proxy" {
79+
dsn := config.GetBackendConnectionString()
80+
conn, connErr := pgx.Connect(context.Background(), dsn)
81+
iqlerror.PrintErrorAndExitOneIfError(connErr)
82+
defer conn.Close(context.Background()) //nolint:errcheck // TODO: investigate
83+
// conn
84+
var cfg dto.SQLBackendCfg
85+
cfg.DSN = dsn
86+
cfg.InitMaxRetries = 5
87+
cfg.InitRetryInitialDelay = 2
88+
db, err := db_util.GetDB("pgx", "postgres", cfg)
89+
iqlerror.PrintErrorAndExitOneIfError(err)
90+
backend, backendErr = mcpbackend.NewStackqlMCPReverseProxyService(
91+
isReadOnly,
92+
dsn,
93+
db,
94+
handlerCtx,
95+
logging.GetLogger(),
96+
)
97+
iqlerror.PrintErrorAndExitOneIfError(backendErr)
98+
} else {
99+
orchestrator, orchestratorErr := tsm_physio.NewOrchestrator(handlerCtx)
100+
iqlerror.PrintErrorAndExitOneIfError(orchestratorErr)
101+
iqlerror.PrintErrorAndExitOneIfNil(orchestrator, "orchestrator is unexpectedly nil")
102+
// handlerCtx.SetTSMOrchestrator(orchestrator)
103+
backend, backendErr = mcpbackend.NewStackqlMCPBackendService(
104+
isReadOnly,
105+
orchestrator,
106+
handlerCtx,
107+
logging.GetLogger(),
108+
)
109+
iqlerror.PrintErrorAndExitOneIfError(backendErr)
110+
iqlerror.PrintErrorAndExitOneIfNil(backend, "mcp backend is unexpectedly nil")
111+
}
84112
server, serverErr := mcp_server.NewAgnosticBackendServer(
85113
backend,
86114
&config,
87115
logging.GetLogger(),
88116
)
89-
// server, serverErr := mcp_server.NewExampleHTTPBackendServer(
90-
// logging.GetLogger(),
91-
// )
92117
iqlerror.PrintErrorAndExitOneIfError(serverErr)
93118
server.Start(context.Background()) //nolint:errcheck // TODO: investigate
94119
}
Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
package mcpbackend
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"encoding/json"
7+
"fmt"
8+
9+
"github.com/sirupsen/logrus"
10+
"github.com/stackql/stackql/internal/stackql/handler"
11+
"github.com/stackql/stackql/pkg/mcp_server"
12+
)
13+
14+
type stackqlMCPReverseProxyService struct {
15+
isReadOnly bool
16+
dsn string
17+
handlerCtx handler.HandlerContext
18+
logger *logrus.Logger
19+
db *sql.DB
20+
interrogator StackqlInterrogator
21+
renderer resultsRenderer
22+
}
23+
24+
func NewStackqlMCPReverseProxyService(
25+
isReadOnly bool,
26+
dsn string,
27+
db *sql.DB,
28+
handlerCtx handler.HandlerContext,
29+
logger *logrus.Logger,
30+
) (mcp_server.Backend, error) {
31+
if logger == nil {
32+
logger = logrus.New()
33+
logger.SetLevel(logrus.InfoLevel)
34+
}
35+
if handlerCtx == nil {
36+
return nil, fmt.Errorf("handler context is nil")
37+
}
38+
return &stackqlMCPReverseProxyService{
39+
dsn: dsn,
40+
isReadOnly: isReadOnly,
41+
interrogator: NewSimpleStackqlInterrogator(),
42+
logger: logger,
43+
handlerCtx: handlerCtx,
44+
db: db,
45+
renderer: NewResultsRenderer(),
46+
}, nil
47+
}
48+
49+
func (b *stackqlMCPReverseProxyService) getDefaultFormat() string {
50+
return resultsFormatMarkdown
51+
}
52+
53+
func (b *stackqlMCPReverseProxyService) Ping(ctx context.Context) error {
54+
return nil
55+
}
56+
57+
func (b *stackqlMCPReverseProxyService) Close() error {
58+
return nil
59+
}
60+
61+
// Server and environment info
62+
func (b *stackqlMCPReverseProxyService) ServerInfo(ctx context.Context, args any) (mcp_server.ServerInfoOutput, error) {
63+
return mcp_server.ServerInfoOutput{
64+
Name: "Stackql MCP Reverse Proxy Service",
65+
Info: "This is the Stackql MCP Reverse Proxy Service.",
66+
IsReadOnly: b.isReadOnly,
67+
}, nil
68+
}
69+
70+
// Current DB identity details
71+
func (b *stackqlMCPReverseProxyService) DBIdentity(ctx context.Context, args any) (map[string]any, error) {
72+
return map[string]any{
73+
"identity": "stackql_mcp_reverse_proxy_service",
74+
}, nil
75+
}
76+
77+
func (b *stackqlMCPReverseProxyService) Greet(ctx context.Context, args mcp_server.GreetInput) (string, error) {
78+
return "Hi " + args.Name, nil
79+
}
80+
81+
//nolint:gocognit,funlen // acceptable
82+
func (b *stackqlMCPReverseProxyService) query(ctx context.Context, query string, rowLimit int) ([]map[string]any, error) {
83+
r, sqlErr := b.db.Query(query)
84+
if sqlErr != nil {
85+
return nil, sqlErr
86+
}
87+
rowsErr := r.Err()
88+
if rowsErr != nil {
89+
return nil, rowsErr
90+
}
91+
defer r.Close() //nolint:errcheck // TODO: investigate
92+
columnTypes, err := r.ColumnTypes()
93+
if err != nil {
94+
return nil, err
95+
}
96+
97+
count := len(columnTypes)
98+
var finalRows []map[string]any
99+
100+
rowCount := 0
101+
for r.Next() {
102+
if rowLimit > 0 && rowCount >= rowLimit {
103+
break
104+
}
105+
rowCount++
106+
scanArgs := make([]interface{}, count)
107+
108+
for i, v := range columnTypes {
109+
switch v.DatabaseTypeName() {
110+
case "VARCHAR", "TEXT", "UUID", "TIMESTAMP":
111+
scanArgs[i] = new(sql.NullString)
112+
break
113+
case "BOOL":
114+
scanArgs[i] = new(sql.NullBool)
115+
break
116+
case "INT4":
117+
scanArgs[i] = new(sql.NullInt64)
118+
break
119+
default:
120+
scanArgs[i] = new(sql.NullString)
121+
}
122+
}
123+
124+
scanErr := r.Scan(scanArgs...)
125+
126+
if scanErr != nil {
127+
return nil, scanErr
128+
}
129+
130+
masterData := map[string]any{}
131+
132+
for i, v := range columnTypes {
133+
if z, ok := (scanArgs[i]).(*sql.NullBool); ok {
134+
masterData[v.Name()] = z.Bool
135+
continue
136+
}
137+
if z, ok := (scanArgs[i]).(*sql.NullString); ok {
138+
masterData[v.Name()] = z.String
139+
continue
140+
}
141+
if z, ok := (scanArgs[i]).(*sql.NullInt64); ok {
142+
masterData[v.Name()] = z.Int64
143+
continue
144+
}
145+
if z, ok := (scanArgs[i]).(*sql.NullFloat64); ok {
146+
masterData[v.Name()] = z.Float64
147+
continue
148+
}
149+
if z, ok := (scanArgs[i]).(*sql.NullInt32); ok {
150+
masterData[v.Name()] = z.Int32
151+
continue
152+
}
153+
masterData[v.Name()] = scanArgs[i]
154+
}
155+
finalRows = append(finalRows, masterData)
156+
}
157+
return finalRows, nil
158+
}
159+
160+
func (b *stackqlMCPReverseProxyService) renderQueryResults(query string, format string, rowLimit int) (string, error) {
161+
if format == "" {
162+
format = b.getDefaultFormat()
163+
}
164+
ctx := context.Background()
165+
rows, err := b.query(ctx, query, rowLimit)
166+
if err != nil {
167+
return "", err
168+
}
169+
switch format {
170+
case resultsFormatMarkdown:
171+
return b.renderer.RenderAsMarkdown(rows), nil
172+
case resultsFormatJSON:
173+
jsonStr, jsonErr := json.Marshal(rows)
174+
if jsonErr != nil {
175+
return "", jsonErr
176+
}
177+
return string(jsonStr), nil
178+
default:
179+
return "", fmt.Errorf("unknown format: %s", format)
180+
}
181+
}
182+
183+
func (b *stackqlMCPReverseProxyService) RunQuery(ctx context.Context, args mcp_server.QueryInput) (string, error) {
184+
if args.Format == "" {
185+
args.Format = b.getDefaultFormat()
186+
}
187+
rows, err := b.query(ctx, args.SQL, args.RowLimit)
188+
if err != nil {
189+
return "", err
190+
}
191+
switch args.Format {
192+
case resultsFormatMarkdown:
193+
return b.renderer.RenderAsMarkdown(rows), nil
194+
case resultsFormatJSON:
195+
jsonStr, jsonErr := json.Marshal(rows)
196+
if jsonErr != nil {
197+
return "", jsonErr
198+
}
199+
return string(jsonStr), nil
200+
default:
201+
return "", fmt.Errorf("unknown format: %s", args.Format)
202+
}
203+
}
204+
205+
func (b *stackqlMCPReverseProxyService) RunQueryJSON(ctx context.Context, input mcp_server.QueryJSONInput) ([]map[string]interface{}, error) {
206+
return b.query(ctx, input.SQL, input.RowLimit)
207+
}
208+
209+
// func (b *stackqlMCPReverseProxyService) ListTableResources(ctx context.Context, hI mcp_server.HierarchyInput) ([]string, error) {
210+
211+
// TODO: implement the remaining methods, using the db connection as sole sql data source
212+
213+
// return []string{}, nil
214+
// }
215+
216+
func (b *stackqlMCPReverseProxyService) ReadTableResource(ctx context.Context, hI mcp_server.HierarchyInput) ([]map[string]interface{}, error) {
217+
if hI.Provider == "" || hI.Service == "" || hI.Resource == "" {
218+
return nil, fmt.Errorf("provider, service, and resource must be specified")
219+
}
220+
query := fmt.Sprintf("SELECT * FROM %s.%s", hI.Service, hI.Resource)
221+
return b.query(ctx, query, hI.RowLimit)
222+
}
223+
224+
func (b *stackqlMCPReverseProxyService) PromptWriteSafeSelectTool(ctx context.Context, args mcp_server.HierarchyInput) (string, error) {
225+
return mcp_server.ExplainerPromptWriteSafeSelectTool, nil
226+
}
227+
228+
// func (b *stackqlMCPReverseProxyService) PromptExplainPlanTipsTool(ctx context.Context) (string, error) {
229+
// return "stub", nil
230+
// }
231+
232+
func (b *stackqlMCPReverseProxyService) DescribeTable(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) {
233+
q, qErr := b.interrogator.GetDescribeTable(hI)
234+
if qErr != nil {
235+
return "", qErr
236+
}
237+
return b.renderQueryResults(q, hI.Format, hI.RowLimit)
238+
}
239+
240+
func (b *stackqlMCPReverseProxyService) GetForeignKeys(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) {
241+
return b.interrogator.GetForeignKeys(hI)
242+
}
243+
244+
func (b *stackqlMCPReverseProxyService) FindRelationships(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) {
245+
return b.interrogator.FindRelationships(hI)
246+
}
247+
248+
func (b *stackqlMCPReverseProxyService) ListProviders(ctx context.Context) (string, error) {
249+
q, qErr := b.interrogator.GetShowProviders(mcp_server.HierarchyInput{}, "")
250+
if qErr != nil {
251+
return "", qErr
252+
}
253+
return b.renderQueryResults(q, "", unlimitedRowLimit)
254+
}
255+
256+
func (b *stackqlMCPReverseProxyService) ListServices(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) {
257+
q, qErr := b.interrogator.GetShowServices(hI, "")
258+
if qErr != nil {
259+
return "", qErr
260+
}
261+
return b.renderQueryResults(q, hI.Format, hI.RowLimit)
262+
}
263+
264+
func (b *stackqlMCPReverseProxyService) ListResources(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) {
265+
q, qErr := b.interrogator.GetShowResources(hI, "")
266+
if qErr != nil {
267+
return "", qErr
268+
}
269+
return b.renderQueryResults(q, hI.Format, hI.RowLimit)
270+
}
271+
272+
func (b *stackqlMCPReverseProxyService) ListTablesJSON(ctx context.Context, input mcp_server.ListTablesInput) ([]map[string]interface{}, error) {
273+
hI := mcp_server.HierarchyInput{}
274+
likeStr := ""
275+
if input.Hierarchy != nil {
276+
hI = *input.Hierarchy
277+
}
278+
if input.NameLike != nil {
279+
likeStr = *input.NameLike
280+
}
281+
q, qErr := b.interrogator.GetShowResources(hI, likeStr)
282+
if qErr != nil {
283+
return nil, qErr
284+
}
285+
return b.query(ctx, q, input.RowLimit)
286+
}
287+
288+
func (b *stackqlMCPReverseProxyService) ListTablesJSONPage(ctx context.Context, input mcp_server.ListTablesPageInput) (map[string]interface{}, error) {
289+
return map[string]interface{}{}, nil
290+
}
291+
292+
func (b *stackqlMCPReverseProxyService) ListTables(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) {
293+
return b.ListResources(ctx, hI)
294+
}
295+
296+
func (b *stackqlMCPReverseProxyService) ListMethods(ctx context.Context, hI mcp_server.HierarchyInput) (string, error) {
297+
q, qErr := b.interrogator.GetShowMethods(hI)
298+
if qErr != nil {
299+
return "", qErr
300+
}
301+
return b.renderQueryResults(q, hI.Format, hI.RowLimit)
302+
}

0 commit comments

Comments
 (0)