Skip to content

Commit 38995e9

Browse files
tiwilliaclaude
andcommitted
Add access token support for OCM authentication
- Add Authorization: Bearer <token> header support for short-lived access tokens - Maintain backward compatibility with X-OCM-OFFLINE-TOKEN header for offline tokens - Access tokens take precedence when both are provided - Add secure logging that only logs header keys, never token values - Add conservative token expiration detection with helpful error messages - Add AUTHENTICATION_FAILED MCP error code for expired tokens - Update all 4 MCP tools to support both authentication methods - Add comprehensive test coverage with realistic JWT tokens 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent baf2d6f commit 38995e9

File tree

6 files changed

+339
-81
lines changed

6 files changed

+339
-81
lines changed

go.mod

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ require (
88
github.com/mark3labs/mcp-go v0.37.0
99
github.com/openshift-online/ocm-sdk-go v0.1.473
1010
github.com/spf13/cobra v1.9.1
11+
github.com/spf13/pflag v1.0.6
12+
github.com/stretchr/testify v1.9.0
1113
)
1214

1315
require (
@@ -17,6 +19,7 @@ require (
1719
github.com/buger/jsonparser v1.1.1 // indirect
1820
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
1921
github.com/cespare/xxhash/v2 v2.1.2 // indirect
22+
github.com/davecgh/go-spew v1.1.1 // indirect
2023
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
2124
github.com/golang/protobuf v1.5.3 // indirect
2225
github.com/google/uuid v1.6.0 // indirect
@@ -31,13 +34,13 @@ require (
3134
github.com/modern-go/reflect2 v1.0.2 // indirect
3235
github.com/openshift-online/ocm-api-model/clientapi v0.0.426 // indirect
3336
github.com/openshift-online/ocm-api-model/model v0.0.426 // indirect
37+
github.com/pmezard/go-difflib v1.0.0 // indirect
3438
github.com/prometheus/client_golang v1.12.1 // indirect
3539
github.com/prometheus/client_model v0.2.0 // indirect
3640
github.com/prometheus/common v0.32.1 // indirect
3741
github.com/prometheus/procfs v0.7.3 // indirect
3842
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect
3943
github.com/spf13/cast v1.7.1 // indirect
40-
github.com/spf13/pflag v1.0.6 // indirect
4144
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
4245
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
4346
golang.org/x/net v0.21.0 // indirect

pkg/mcp/server.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,16 @@ func (s *Server) ServeSSE() error {
7676

7777
// getAuthenticatedOCMClient extracts token from context and creates authenticated OCM client
7878
func (s *Server) getAuthenticatedOCMClient(ctx context.Context) (*ocm.Client, error) {
79-
// Extract token based on transport mode
80-
token, err := ocm.ExtractTokenFromContext(ctx, s.config.Transport)
79+
// Extract token info from context
80+
tokenInfo, err := ocm.ExtractTokenInfoFromContext(ctx, s.config.Transport)
8181
if err != nil {
8282
glog.Errorf("Failed to extract OCM token: %v", err)
8383
return nil, err
8484
}
8585

8686
// Create OCM client and authenticate
8787
baseClient := ocm.NewClient(s.config.OCMBaseURL, s.config.OCMClientID)
88-
authenticatedClient, err := baseClient.WithToken(token)
88+
authenticatedClient, err := baseClient.WithToken(tokenInfo)
8989
if err != nil {
9090
authErr := fmt.Errorf("OCM authentication failed: %w", err)
9191
glog.Errorf("OCM client authentication failed: %v", authErr)

pkg/mcp/tools.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,12 @@ func (s *Server) handleWhoami(ctx context.Context, ctr mcp.CallToolRequest) (*mc
6969
// Call OCM client to get current account
7070
account, err := client.GetCurrentAccount()
7171
if err != nil {
72-
// Handle OCM API errors with code and reason exposure
72+
// Handle OCM API errors with enhanced token expiration detection
7373
if ocmErr, ok := err.(*ocm.OCMError); ok {
74+
// Return specific error for token expiration
75+
if ocm.IsAccessTokenExpiredError(ocmErr) {
76+
return mcp.NewToolResultError("AUTHENTICATION_FAILED: " + ocmErr.Error()), nil
77+
}
7478
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason)), nil
7579
}
7680
return NewTextResult("", errors.New("failed to get account: "+err.Error())), nil
@@ -102,8 +106,12 @@ func (s *Server) handleGetClusters(ctx context.Context, ctr mcp.CallToolRequest)
102106
// Call OCM client to get clusters with state filter
103107
clusters, err := client.GetClusters(state)
104108
if err != nil {
105-
// Handle OCM API errors with code and reason exposure
109+
// Handle OCM API errors with enhanced token expiration detection
106110
if ocmErr, ok := err.(*ocm.OCMError); ok {
111+
// Return specific error for token expiration
112+
if ocm.IsAccessTokenExpiredError(ocmErr) {
113+
return mcp.NewToolResultError("AUTHENTICATION_FAILED: " + ocmErr.Error()), nil
114+
}
107115
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason)), nil
108116
}
109117
return NewTextResult("", errors.New("failed to get clusters: "+err.Error())), nil
@@ -135,8 +143,12 @@ func (s *Server) handleGetCluster(ctx context.Context, ctr mcp.CallToolRequest)
135143
// Call OCM client to get cluster details
136144
cluster, err := client.GetCluster(clusterID)
137145
if err != nil {
138-
// Handle OCM API errors with code and reason exposure
146+
// Handle OCM API errors with enhanced token expiration detection
139147
if ocmErr, ok := err.(*ocm.OCMError); ok {
148+
// Return specific error for token expiration
149+
if ocm.IsAccessTokenExpiredError(ocmErr) {
150+
return mcp.NewToolResultError("AUTHENTICATION_FAILED: " + ocmErr.Error()), nil
151+
}
140152
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason)), nil
141153
}
142154
return NewTextResult("", errors.New("failed to get cluster: "+err.Error())), nil
@@ -260,8 +272,12 @@ func (s *Server) handleCreateROSAHCPCluster(ctx context.Context, ctr mcp.CallToo
260272
multiArchEnabled,
261273
)
262274
if err != nil {
263-
// Expose OCM API errors directly without modification
275+
// Handle OCM API errors with enhanced token expiration detection
264276
if ocmErr, ok := err.(*ocm.OCMError); ok {
277+
// Return specific error for token expiration
278+
if ocm.IsAccessTokenExpiredError(ocmErr) {
279+
return mcp.NewToolResultError("AUTHENTICATION_FAILED: " + ocmErr.Error()), nil
280+
}
265281
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason)), nil
266282
}
267283
return NewTextResult("", errors.New("cluster creation failed: "+err.Error())), nil

pkg/ocm/auth.go

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net/http"
77
"os"
8+
"strings"
89

910
"github.com/golang/glog"
1011
)
@@ -23,6 +24,9 @@ func RequestHeaderKey() contextKey {
2324
}
2425

2526
const (
27+
// Authorization header for access tokens (standard OAuth2)
28+
AuthorizationHeader = "Authorization"
29+
2630
// SSE transport header for OCM offline token
2731
// NOTE: http.Header keys are stored in canonical format, hence the different casing required here.
2832
// The provided X-OCM-OFFLINE-TOKEN header key will be translated to X-Ocm-Offline-Token
@@ -32,19 +36,61 @@ const (
3236
StdioTokenEnv = "OCM_OFFLINE_TOKEN"
3337
)
3438

35-
// ExtractTokenFromSSE extracts OCM offline token from X-OCM-OFFLINE-TOKEN header
36-
func ExtractTokenFromSSE(headers map[string]string) (string, error) {
37-
glog.V(3).Infof("SSE headers received: %+v", headers)
39+
// TokenInfo represents extracted token information
40+
type TokenInfo struct {
41+
Token string
42+
TokenType string // "access" or "offline"
43+
}
44+
45+
// ExtractBearerToken extracts access token from Authorization header
46+
func ExtractBearerToken(headers map[string]string) (string, error) {
47+
authHeader, exists := headers[AuthorizationHeader]
48+
if !exists || authHeader == "" {
49+
return "", fmt.Errorf("missing or empty %s header", AuthorizationHeader)
50+
}
51+
52+
// Check for "Bearer " prefix (case-insensitive)
53+
const bearerPrefix = "Bearer "
54+
if len(authHeader) <= len(bearerPrefix) {
55+
return "", fmt.Errorf("invalid Authorization header format")
56+
}
57+
58+
if !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
59+
return "", fmt.Errorf("Authorization header must use Bearer scheme")
60+
}
61+
62+
token := strings.TrimSpace(authHeader[len(bearerPrefix):])
63+
if token == "" {
64+
return "", fmt.Errorf("empty Bearer token")
65+
}
66+
67+
return token, nil
68+
}
69+
70+
// ExtractTokenInfoFromSSE extracts token info from SSE headers, preferring access tokens
71+
func ExtractTokenInfoFromSSE(headers map[string]string) (*TokenInfo, error) {
72+
// Log only header keys for security (never log header values which may contain tokens)
73+
headerKeys := make([]string, 0, len(headers))
74+
for key := range headers {
75+
headerKeys = append(headerKeys, key)
76+
}
77+
glog.V(3).Infof("SSE headers received (keys only): %v", headerKeys)
78+
79+
// Try Authorization header first (access token)
80+
if token, err := ExtractBearerToken(headers); err == nil {
81+
glog.V(3).Info("Found access token in Authorization header")
82+
return &TokenInfo{Token: token, TokenType: "access"}, nil
83+
}
3884

39-
// Try exact header name first
85+
// Fallback to offline token header
4086
token, exists := headers[SSETokenHeader]
4187
if exists && token != "" {
42-
glog.V(3).Infof("Found OCM token in header %s", SSETokenHeader)
43-
return token, nil
88+
glog.V(3).Infof("Found offline token in header %s", SSETokenHeader)
89+
return &TokenInfo{Token: token, TokenType: "offline"}, nil
4490
}
4591

46-
glog.Warningf("Missing or empty %s header in SSE request. Available headers: %+v", SSETokenHeader, headers)
47-
return "", fmt.Errorf("missing or empty %s header", SSETokenHeader)
92+
glog.Warningf("No valid tokens found in SSE headers. Available header keys: %v", headerKeys)
93+
return nil, fmt.Errorf("missing valid authentication token")
4894
}
4995

5096
// ExtractTokenFromStdio extracts OCM offline token from OCM_OFFLINE_TOKEN environment variable
@@ -56,40 +102,45 @@ func ExtractTokenFromStdio() (string, error) {
56102
return token, nil
57103
}
58104

59-
// ExtractTokenFromContext extracts OCM offline token from context based on transport mode
60-
func ExtractTokenFromContext(ctx context.Context, transport string) (string, error) {
105+
// ExtractTokenInfoFromContext extracts token info from context based on transport mode
106+
func ExtractTokenInfoFromContext(ctx context.Context, transport string) (*TokenInfo, error) {
61107
glog.V(2).Infof("Extracting token for transport mode: %s", transport)
62108

63109
switch transport {
64110
case "stdio":
65-
return ExtractTokenFromStdio()
111+
// Stdio only supports offline tokens via environment variable
112+
token, err := ExtractTokenFromStdio()
113+
if err != nil {
114+
return nil, err
115+
}
116+
return &TokenInfo{Token: token, TokenType: "offline"}, nil
117+
66118
case "sse":
67-
// For SSE transport, extract token from HTTP headers in the context
119+
// For SSE transport, extract from HTTP headers
68120
headers := extractHeadersFromContext(ctx)
69121
if headers != nil {
70-
glog.V(2).Infof("Headers found in context for SSE transport")
71-
if token, err := ExtractTokenFromSSE(headers); err == nil {
72-
return token, nil
122+
if tokenInfo, err := ExtractTokenInfoFromSSE(headers); err == nil {
123+
return tokenInfo, nil
73124
} else {
74125
glog.Warningf("Failed to extract token from SSE headers: %v", err)
75126
}
76-
} else {
77-
glog.Warningf("No headers found in context for SSE transport")
78127
}
79128

80-
// Fallback to environment variable for MVP compatibility
129+
// Fallback to environment variable
81130
token := os.Getenv(StdioTokenEnv)
82131
if token == "" {
83-
err := fmt.Errorf("SSE transport requires %s header or %s environment variable", SSETokenHeader, StdioTokenEnv)
132+
err := fmt.Errorf("SSE transport requires %s header or %s environment variable",
133+
AuthorizationHeader, StdioTokenEnv)
84134
glog.Errorf("Authentication failed: %v", err)
85-
return "", err
135+
return nil, err
86136
}
87137
glog.V(2).Infof("Using fallback environment variable for SSE transport")
88-
return token, nil
138+
return &TokenInfo{Token: token, TokenType: "offline"}, nil
139+
89140
default:
90141
err := fmt.Errorf("unsupported transport mode: %s", transport)
91142
glog.Errorf("Authentication failed: %v", err)
92-
return "", err
143+
return nil, err
93144
}
94145
}
95146

0 commit comments

Comments
 (0)