Skip to content

Commit 6ec60dd

Browse files
jbayasomaru
authored andcommitted
mcp: pass TokenInfo to server handler (modelcontextprotocol#292)
If there is a TokenInfo in the request context of a StreamableServerTransport, then propagate it through to the ServerRequest that is passed to server methods like callTool. Fixes modelcontextprotocol#317.
1 parent 076baf9 commit 6ec60dd

File tree

7 files changed

+109
-16
lines changed

7 files changed

+109
-16
lines changed

auth/auth.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,41 @@ import (
1313
"time"
1414
)
1515

16+
// TokenInfo holds information from a bearer token.
1617
type TokenInfo struct {
1718
Scopes []string
1819
Expiration time.Time
20+
// TODO: add standard JWT fields
21+
Extra map[string]any
1922
}
2023

24+
// The error that a TokenVerifier should return if the token cannot be verified.
25+
var ErrInvalidToken = errors.New("invalid token")
26+
27+
// A TokenVerifier checks the validity of a bearer token, and extracts information
28+
// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken.
2129
type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error)
2230

31+
// RequireBearerTokenOptions are options for [RequireBearerToken].
2332
type RequireBearerTokenOptions struct {
24-
Scopes []string
33+
// The URL for the resource server metadata OAuth flow, to be returned as part
34+
// of the WWW-Authenticate header.
2535
ResourceMetadataURL string
36+
// The required scopes.
37+
Scopes []string
2638
}
2739

28-
var ErrInvalidToken = errors.New("invalid token")
29-
3040
type tokenInfoKey struct{}
3141

42+
// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none.
43+
func TokenInfoFromContext(ctx context.Context) *TokenInfo {
44+
ti := ctx.Value(tokenInfoKey{})
45+
if ti == nil {
46+
return nil
47+
}
48+
return ti.(*TokenInfo)
49+
}
50+
3251
// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier.
3352
// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds.
3453
// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header
@@ -75,7 +94,7 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke
7594
return nil, err.Error(), http.StatusInternalServerError
7695
}
7796

78-
// Check scopes.
97+
// Check scopes. All must be present.
7998
if opts != nil {
8099
// Note: quadratic, but N is small.
81100
for _, s := range opts.Scopes {

internal/jsonrpc2/messages.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ type Request struct {
5656
Method string
5757
// Params is either a struct or an array with the parameters of the method.
5858
Params json.RawMessage
59+
// Extra is additional information that does not appear on the wire. It can be
60+
// used to pass information from the application to the underlying transport.
61+
Extra any
5962
}
6063

6164
// Response is a Message used as a reply to a call Request.
@@ -67,6 +70,9 @@ type Response struct {
6770
Error error
6871
// id of the request this is a response to.
6972
ID ID
73+
// Extra is additional information that does not appear on the wire. It can be
74+
// used to pass information from the underlying transport to the application.
75+
Extra any
7076
}
7177

7278
// StringID creates a new string request identifier.

mcp/shared.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"strings"
2020
"time"
2121

22+
"github.com/modelcontextprotocol/go-sdk/auth"
2223
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2324
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2425
)
@@ -126,7 +127,8 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ
126127
}
127128

128129
mh := session.receivingMethodHandler()
129-
req := info.newRequest(session, params)
130+
re, _ := jreq.Extra.(*RequestExtra)
131+
req := info.newRequest(session, params, re)
130132
// mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol.
131133
res, err := mh(ctx, jreq.Method, req)
132134
if err != nil {
@@ -173,7 +175,7 @@ type methodInfo struct {
173175
// Unmarshal params from the wire into a Params struct.
174176
// Used on the receive side.
175177
unmarshalParams func(json.RawMessage) (Params, error)
176-
newRequest func(Session, Params) Request
178+
newRequest func(Session, Params, *RequestExtra) Request
177179
// Run the code when a call to the method is received.
178180
// Used on the receive side.
179181
handleMethod MethodHandler
@@ -208,7 +210,7 @@ const (
208210

209211
func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo {
210212
mi := newMethodInfo[P, R](flags)
211-
mi.newRequest = func(s Session, p Params) Request {
213+
mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request {
212214
r := &ClientRequest[P]{Session: s.(*ClientSession)}
213215
if p != nil {
214216
r.Params = p.(P)
@@ -223,19 +225,15 @@ func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHan
223225

224226
func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo {
225227
mi := newMethodInfo[P, R](flags)
226-
mi.newRequest = func(s Session, p Params) Request {
227-
r := &ServerRequest[P]{Session: s.(*ServerSession)}
228+
mi.newRequest = func(s Session, p Params, re *RequestExtra) Request {
229+
r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re}
228230
if p != nil {
229231
r.Params = p.(P)
230232
}
231233
return r
232234
}
233235
mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) {
234-
rf := &ServerRequest[P]{Session: req.GetSession().(*ServerSession)}
235-
if req.GetParams() != nil {
236-
rf.Params = req.GetParams().(P)
237-
}
238-
return d(ctx, rf)
236+
return d(ctx, req.(*ServerRequest[P]))
239237
})
240238
return mi
241239
}
@@ -391,6 +389,13 @@ type ClientRequest[P Params] struct {
391389
type ServerRequest[P Params] struct {
392390
Session *ServerSession
393391
Params P
392+
Extra *RequestExtra
393+
}
394+
395+
// RequestExtra is extra information included in requests, typically from
396+
// the transport layer.
397+
type RequestExtra struct {
398+
TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any
394399
}
395400

396401
func (*ClientRequest[P]) isRequest() {}

mcp/streamable.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"sync/atomic"
2222
"time"
2323

24+
"github.com/modelcontextprotocol/go-sdk/auth"
2425
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2526
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2627
)
@@ -579,12 +580,17 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
579580
// This also requires access to the negotiated version, which would either be
580581
// set by the MCP-Protocol-Version header, or would require peeking into the
581582
// session.
583+
if err != nil {
584+
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
585+
return
586+
}
582587
incoming, _, err := readBatch(body)
583588
if err != nil {
584589
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
585590
return
586591
}
587592
requests := make(map[jsonrpc.ID]struct{})
593+
tokenInfo := auth.TokenInfoFromContext(req.Context())
588594
for _, msg := range incoming {
589595
if req, ok := msg.(*jsonrpc.Request); ok {
590596
// Preemptively check that this is a valid request, so that we can fail
@@ -594,6 +600,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
594600
http.Error(w, err.Error(), http.StatusBadRequest)
595601
return
596602
}
603+
req.Extra = &RequestExtra{TokenInfo: tokenInfo}
597604
if req.IsCall() {
598605
requests[req.ID] = struct{}{}
599606
}
@@ -1182,6 +1189,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
11821189
return nil
11831190
}
11841191

1192+
// testAuth controls whether a fake Authorization header is added to outgoing requests.
1193+
// TODO: replace with a better mechanism when client-side auth is in place.
1194+
var testAuth = false
1195+
11851196
func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
11861197
c.mu.Lock()
11871198
defer c.mu.Unlock()
@@ -1192,6 +1203,9 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
11921203
if c.sessionID != "" {
11931204
req.Header.Set(sessionIDHeader, c.sessionID)
11941205
}
1206+
if testAuth {
1207+
req.Header.Set("Authorization", "Bearer foo")
1208+
}
11951209
}
11961210

11971211
func (c *streamableClientConn) handleJSON(resp *http.Response) {

mcp/streamable_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/google/go-cmp/cmp"
2727
"github.com/google/go-cmp/cmp/cmpopts"
2828
"github.com/google/jsonschema-go/jsonschema"
29+
"github.com/modelcontextprotocol/go-sdk/auth"
2930
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
3031
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
3132
)
@@ -1098,3 +1099,51 @@ func textContent(t *testing.T, res *CallToolResult) string {
10981099
}
10991100
return text.Text
11001101
}
1102+
1103+
func TestTokenInfo(t *testing.T) {
1104+
defer func(b bool) { testAuth = b }(testAuth)
1105+
testAuth = true
1106+
ctx := context.Background()
1107+
1108+
// Create a server with a tool that returns TokenInfo.
1109+
tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) {
1110+
return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil
1111+
}
1112+
server := NewServer(testImpl, nil)
1113+
AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo)
1114+
1115+
streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
1116+
verifier := func(context.Context, string) (*auth.TokenInfo, error) {
1117+
return &auth.TokenInfo{
1118+
Scopes: []string{"scope"},
1119+
// Expiration is far, far in the future.
1120+
Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC),
1121+
}, nil
1122+
}
1123+
handler := auth.RequireBearerToken(verifier, nil)(streamHandler)
1124+
httpServer := httptest.NewServer(handler)
1125+
defer httpServer.Close()
1126+
1127+
transport := NewStreamableClientTransport(httpServer.URL, nil)
1128+
client := NewClient(testImpl, nil)
1129+
session, err := client.Connect(ctx, transport, nil)
1130+
if err != nil {
1131+
t.Fatalf("client.Connect() failed: %v", err)
1132+
}
1133+
defer session.Close()
1134+
1135+
res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"})
1136+
if err != nil {
1137+
t.Fatal(err)
1138+
}
1139+
if len(res.Content) == 0 {
1140+
t.Fatal("missing content")
1141+
}
1142+
tc, ok := res.Content[0].(*TextContent)
1143+
if !ok {
1144+
t.Fatal("not TextContent")
1145+
}
1146+
if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w {
1147+
t.Errorf("got %q, want %q", g, w)
1148+
}
1149+
}

mcp/tool.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool
6868
res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{
6969
Session: req.Session,
7070
Params: params,
71+
Extra: req.Extra,
7172
})
7273
// TODO(rfindley): investigate why server errors are embedded in this strange way,
7374
// rather than returned as jsonrpc2 server errors.

mcp/transport.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ type serverConnection interface {
8686

8787
// A StdioTransport is a [Transport] that communicates over stdin/stdout using
8888
// newline-delimited JSON.
89-
type StdioTransport struct {
90-
}
89+
type StdioTransport struct{}
9190

9291
// Connect implements the [Transport] interface.
9392
func (*StdioTransport) Connect(context.Context) (Connection, error) {

0 commit comments

Comments
 (0)