Skip to content

Commit 6dec52a

Browse files
authored
Re-read user session and refresh token when using MCP (#356)
1 parent 2c78e58 commit 6dec52a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+165
-121
lines changed

internal/cmd/audittrail/list.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func searchAuditTrailEntries(ctx context.Context, input structs.SearchInput) (se
182182
} `graphql:"searchAuditTrailEntries(input: $input)"`
183183
}
184184

185-
if err := authenticated.Client.Query(
185+
if err := authenticated.Client().Query(
186186
ctx,
187187
&query,
188188
map[string]interface{}{"input": input},

internal/cmd/authenticated/client.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"net/http"
99
"os"
10+
"sync"
1011

1112
"github.com/urfave/cli/v3"
1213

@@ -33,11 +34,30 @@ var (
3334
)
3435

3536
// Client is the authenticated client that can be used by all CLI commands.
36-
var Client client.Client
37+
var (
38+
auth client.Client
39+
m sync.Mutex
40+
)
41+
42+
// Client returns the authenticated client.
43+
//
44+
// This is an unfortunate global which we have to lock for MCP.
45+
// TODO: Refactor this to not use a global.
46+
func Client() client.Client {
47+
m.Lock()
48+
defer m.Unlock()
49+
50+
return auth
51+
}
3752

3853
// Ensure is a way of ensuring that the Client exists, and it meant to be used
3954
// as a Before action for commands that need it.
55+
//
56+
// You can also use it diretly to refresh the client.
4057
func Ensure(ctx context.Context, _ *cli.Command) (context.Context, error) {
58+
m.Lock()
59+
defer m.Unlock()
60+
4161
httpClient := client.GetHTTPClient()
4262

4363
if err := configureTLS(httpClient); err != nil {
@@ -49,7 +69,7 @@ func Ensure(ctx context.Context, _ *cli.Command) (context.Context, error) {
4969
return ctx, err
5070
}
5171

52-
Client = client.New(httpClient, session)
72+
auth = client.New(httpClient, session)
5373

5474
return ctx, nil
5575
}

internal/cmd/authenticated/viewer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func CurrentViewer(ctx context.Context) (*Viewer, error) {
1717
var query struct {
1818
Viewer *Viewer
1919
}
20-
if err := Client.Query(ctx, &query, map[string]interface{}{}); err != nil {
20+
if err := Client().Query(ctx, &query, map[string]interface{}{}); err != nil {
2121
return nil, errors.Wrap(err, "failed to query user information")
2222
}
2323
if query.Viewer == nil {

internal/cmd/blueprint/deploy.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (c *deployCommand) deploy(ctx context.Context, cliCmd *cli.Command) error {
7777
} `graphql:"blueprintCreateStack(id: $id, input: $input)"`
7878
}
7979

80-
err = authenticated.Client.Mutate(
80+
err = authenticated.Client().Mutate(
8181
ctx,
8282
&mutation,
8383
map[string]any{
@@ -91,7 +91,7 @@ func (c *deployCommand) deploy(ctx context.Context, cliCmd *cli.Command) error {
9191
return fmt.Errorf("failed to deploy stack from the blueprint: %w", err)
9292
}
9393

94-
url := authenticated.Client.URL("/stack/%s", mutation.BlueprintCreateStack.StackID)
94+
url := authenticated.Client().URL("/stack/%s", mutation.BlueprintCreateStack.StackID)
9595
fmt.Printf("\nCreated stack: %q", url)
9696

9797
return nil

internal/cmd/blueprint/list.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func searchBlueprints(ctx context.Context, input structs.SearchInput) (searchBlu
207207
} `graphql:"searchBlueprints(input: $input)"`
208208
}
209209

210-
if err := authenticated.Client.Query(
210+
if err := authenticated.Client().Query(
211211
ctx,
212212
&query,
213213
map[string]interface{}{"input": input},

internal/cmd/blueprint/show.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ func getBlueprintByID(ctx context.Context, blueprintID string) (blueprint, bool,
164164
"blueprintId": graphql.ID(blueprintID),
165165
}
166166

167-
if err := authenticated.Client.Query(ctx, &query, variables); err != nil {
167+
if err := authenticated.Client().Query(ctx, &query, variables); err != nil {
168168
return blueprint{}, false, errors.Wrapf(err, "failed to query for blueprint ID %q", blueprintID)
169169
}
170170

internal/cmd/draw/data/workerpools.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type WorkerPool struct {
2121

2222
// Selected opens the selected worker pool in the browser.
2323
func (q *WorkerPool) Selected(row table.Row) error {
24-
return browser.OpenURL(authenticated.Client.URL("/stack/%s/run/%s", row[1], row[2]))
24+
return browser.OpenURL(authenticated.Client().URL("/stack/%s/run/%s", row[1], row[2]))
2525
}
2626

2727
// Columns returns the columns of the worker pool table.
@@ -73,7 +73,7 @@ func (q *WorkerPool) getPublicPoolRuns(ctx context.Context) ([]runsEdge, error)
7373
} `graphql:"publicWorkerPool"`
7474
}
7575

76-
if err := authenticated.Client.Query(ctx, &query, q.baseSearchParams()); err != nil {
76+
if err := authenticated.Client().Query(ctx, &query, q.baseSearchParams()); err != nil {
7777
return nil, errors.Wrap(err, "failed to query run list")
7878
}
7979

@@ -90,7 +90,7 @@ func (q *WorkerPool) getPrivatePoolRuns(ctx context.Context) ([]runsEdge, error)
9090
vars := q.baseSearchParams()
9191
vars["id"] = q.WokerPoolID
9292

93-
if err := authenticated.Client.Query(ctx, &query, vars); err != nil {
93+
if err := authenticated.Client().Query(ctx, &query, vars); err != nil {
9494
return nil, errors.Wrap(err, "failed to query run list")
9595
}
9696

internal/cmd/graphql/mcp.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func registerIntrospectSchemaTool(s *server.MCPServer) {
3838
)
3939

4040
s.AddTool(introspectTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
41+
authenticated.Ensure(ctx, nil)
4142
format := request.GetString("format", "summary")
4243

4344
var query struct {
@@ -120,7 +121,7 @@ func registerIntrospectSchemaTool(s *server.MCPServer) {
120121
} `graphql:"__schema"`
121122
}
122123

123-
if err := authenticated.Client.Query(ctx, &query, map[string]any{}); err != nil {
124+
if err := authenticated.Client().Query(ctx, &query, map[string]any{}); err != nil {
124125
return nil, errors.Wrap(err, "failed to introspect GraphQL schema")
125126
}
126127

@@ -144,6 +145,7 @@ func registerGetTypeDetailsTool(s *server.MCPServer) {
144145
)
145146

146147
s.AddTool(typeDetailsTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
148+
authenticated.Ensure(ctx, nil)
147149
typeName, err := request.RequireString("type_name")
148150
if err != nil {
149151
return nil, err
@@ -191,7 +193,7 @@ func registerGetTypeDetailsTool(s *server.MCPServer) {
191193
} `graphql:"__type(name: $name)"`
192194
}
193195

194-
if err := authenticated.Client.Query(ctx, &query, map[string]any{"name": graphql.String(typeName)}); err != nil {
196+
if err := authenticated.Client().Query(ctx, &query, map[string]any{"name": graphql.String(typeName)}); err != nil {
195197
return nil, errors.Wrap(err, "failed to get type details")
196198
}
197199

@@ -223,6 +225,7 @@ func registerSearchSchemaFieldsTool(s *server.MCPServer) {
223225
)
224226

225227
s.AddTool(searchTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
228+
authenticated.Ensure(ctx, nil)
226229
searchTerm, err := request.RequireString("search_term")
227230
if err != nil {
228231
return nil, err
@@ -271,7 +274,7 @@ func registerSearchSchemaFieldsTool(s *server.MCPServer) {
271274
} `graphql:"__schema"`
272275
}
273276

274-
if err := authenticated.Client.Query(ctx, &query, map[string]any{}); err != nil {
277+
if err := authenticated.Client().Query(ctx, &query, map[string]any{}); err != nil {
275278
return nil, errors.Wrap(err, "failed to introspect GraphQL schema")
276279
}
277280

@@ -505,6 +508,7 @@ func registerAuthenticationGuideTool(s *server.MCPServer) {
505508
)
506509

507510
s.AddTool(authTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
511+
authenticated.Ensure(ctx, nil)
508512
authMethod := request.GetString("auth_method", "all")
509513

510514
var guide strings.Builder

internal/cmd/module/create_version.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func createVersion(ctx context.Context, cliCmd *cli.Command) error {
3737
"version": version,
3838
}
3939

40-
if err := authenticated.Client.Mutate(ctx, &mutation, variables); err != nil {
40+
if err := authenticated.Client().Mutate(ctx, &mutation, variables); err != nil {
4141
return err
4242
}
4343

internal/cmd/module/delete_version.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func deleteVersion(ctx context.Context, cliCmd *cli.Command) error {
2525
"module": graphql.ID(moduleID),
2626
}
2727

28-
if err := authenticated.Client.Mutate(ctx, &mutation, variables); err != nil {
28+
if err := authenticated.Client().Mutate(ctx, &mutation, variables); err != nil {
2929
return err
3030
}
3131

0 commit comments

Comments
 (0)