Skip to content

Commit d7cc5bd

Browse files
authored
Merge pull request #15 from redhat-ai-tools/access_tokens
Add access token support for OCM authentication
2 parents 7f10a9e + 2824cb2 commit d7cc5bd

File tree

6 files changed

+345
-101
lines changed

6 files changed

+345
-101
lines changed

go.mod

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ require (
88
github.com/mark3labs/mcp-go v0.37.0
99
github.com/openshift-online/ocm-sdk-go v0.1.473
1010
github.com/spf13/cobra v1.9.1
11+
github.com/spf13/pflag v1.0.6
12+
github.com/stretchr/testify v1.9.0
1113
)
1214

1315
require (
@@ -17,6 +19,7 @@ require (
1719
github.com/buger/jsonparser v1.1.1 // indirect
1820
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
1921
github.com/cespare/xxhash/v2 v2.1.2 // indirect
22+
github.com/davecgh/go-spew v1.1.1 // indirect
2023
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
2124
github.com/golang/protobuf v1.5.3 // indirect
2225
github.com/google/uuid v1.6.0 // indirect
@@ -31,13 +34,13 @@ require (
3134
github.com/modern-go/reflect2 v1.0.2 // indirect
3235
github.com/openshift-online/ocm-api-model/clientapi v0.0.426 // indirect
3336
github.com/openshift-online/ocm-api-model/model v0.0.426 // indirect
37+
github.com/pmezard/go-difflib v1.0.0 // indirect
3438
github.com/prometheus/client_golang v1.12.1 // indirect
3539
github.com/prometheus/client_model v0.2.0 // indirect
3640
github.com/prometheus/common v0.32.1 // indirect
3741
github.com/prometheus/procfs v0.7.3 // indirect
3842
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect
3943
github.com/spf13/cast v1.7.1 // indirect
40-
github.com/spf13/pflag v1.0.6 // indirect
4144
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
4245
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
4346
golang.org/x/net v0.21.0 // indirect

pkg/mcp/server.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,16 @@ func (s *Server) ServeSSE() error {
7777

7878
// getAuthenticatedOCMClient extracts token from context and creates authenticated OCM client
7979
func (s *Server) getAuthenticatedOCMClient(ctx context.Context) (*ocm.Client, error) {
80-
// Extract token based on transport mode
81-
token, err := ocm.ExtractTokenFromContext(ctx, s.config.Transport)
80+
// Extract token info from context
81+
tokenInfo, err := ocm.ExtractTokenInfoFromContext(ctx, s.config.Transport)
8282
if err != nil {
8383
glog.Errorf("Failed to extract OCM token: %v", err)
8484
return nil, err
8585
}
8686

8787
// Create OCM client and authenticate
8888
baseClient := ocm.NewClient(s.config.OCMBaseURL, s.config.OCMClientID)
89-
authenticatedClient, err := baseClient.WithToken(token)
89+
authenticatedClient, err := baseClient.WithToken(tokenInfo)
9090
if err != nil {
9191
authErr := fmt.Errorf("OCM authentication failed: %w", err)
9292
glog.Errorf("OCM client authentication failed: %v", authErr)

pkg/mcp/tools.go

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,8 @@ func (s *Server) handleWhoami(ctx context.Context, ctr mcp.CallToolRequest) (*mc
9191

9292
// Call OCM client to get current account
9393
account, err := client.GetCurrentAccount()
94-
if err != nil {
95-
// Handle OCM API errors with code and reason exposure
96-
if ocmErr, ok := err.(*ocm.OCMError); ok {
97-
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason)), nil
98-
}
99-
return NewTextResult("", errors.New("failed to get account: "+err.Error())), nil
94+
if errorResult := handleOCMError(err, "failed to get account"); errorResult != nil {
95+
return errorResult, nil
10096
}
10197

10298
// Format response using MCP layer formatter
@@ -124,12 +120,8 @@ func (s *Server) handleGetClusters(ctx context.Context, ctr mcp.CallToolRequest)
124120

125121
// Call OCM client to get clusters with state filter
126122
clusters, err := client.GetClusters(state)
127-
if err != nil {
128-
// Handle OCM API errors with code and reason exposure
129-
if ocmErr, ok := err.(*ocm.OCMError); ok {
130-
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason)), nil
131-
}
132-
return NewTextResult("", errors.New("failed to get clusters: "+err.Error())), nil
123+
if errorResult := handleOCMError(err, "failed to get clusters"); errorResult != nil {
124+
return errorResult, nil
133125
}
134126

135127
// Format response using MCP layer formatter
@@ -157,12 +149,8 @@ func (s *Server) handleGetCluster(ctx context.Context, ctr mcp.CallToolRequest)
157149

158150
// Call OCM client to get cluster details
159151
cluster, err := client.GetCluster(clusterID)
160-
if err != nil {
161-
// Handle OCM API errors with code and reason exposure
162-
if ocmErr, ok := err.(*ocm.OCMError); ok {
163-
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason)), nil
164-
}
165-
return NewTextResult("", errors.New("failed to get cluster: "+err.Error())), nil
152+
if errorResult := handleOCMError(err, "failed to get cluster"); errorResult != nil {
153+
return errorResult, nil
166154
}
167155

168156
// Format response using MCP layer formatter
@@ -282,12 +270,8 @@ func (s *Server) handleCreateROSAHCPCluster(ctx context.Context, ctr mcp.CallToo
282270
subnetIDs, availabilityZones, region,
283271
multiArchEnabled,
284272
)
285-
if err != nil {
286-
// Expose OCM API errors directly without modification
287-
if ocmErr, ok := err.(*ocm.OCMError); ok {
288-
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason)), nil
289-
}
290-
return NewTextResult("", errors.New("cluster creation failed: "+err.Error())), nil
273+
if errorResult := handleOCMError(err, "cluster creation"); errorResult != nil {
274+
return errorResult, nil
291275
}
292276

293277
// Format response using MCP layer formatter
@@ -325,3 +309,21 @@ func NewTextResult(content string, err error) *mcp.CallToolResult {
325309
},
326310
}
327311
}
312+
313+
// handleOCMError processes OCM API errors with enhanced token expiration detection
314+
// Returns an appropriate MCP CallToolResult for the error, or nil if no error
315+
func handleOCMError(err error, operation string) *mcp.CallToolResult {
316+
if err == nil {
317+
return nil
318+
}
319+
320+
// Handle OCM API errors with enhanced token expiration detection
321+
if ocmErr, ok := err.(*ocm.OCMError); ok {
322+
// Return specific error for token expiration
323+
if ocm.IsAccessTokenExpiredError(ocmErr) {
324+
return mcp.NewToolResultError("AUTHENTICATION_FAILED: " + ocmErr.Error())
325+
}
326+
return NewTextResult("", errors.New("OCM API Error ["+ocmErr.Code+"]: "+ocmErr.Reason))
327+
}
328+
return NewTextResult("", errors.New(operation+" failed: "+err.Error()))
329+
}

pkg/ocm/auth.go

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net/http"
77
"os"
8+
"strings"
89

910
"github.com/golang/glog"
1011
)
@@ -23,6 +24,9 @@ func RequestHeaderKey() contextKey {
2324
}
2425

2526
const (
27+
// Authorization header for access tokens (standard OAuth2)
28+
AuthorizationHeader = "Authorization"
29+
2630
// SSE transport header for OCM offline token
2731
// NOTE: http.Header keys are stored in canonical format, hence the different casing required here.
2832
// The provided X-OCM-OFFLINE-TOKEN header key will be translated to X-Ocm-Offline-Token
@@ -32,19 +36,61 @@ const (
3236
StdioTokenEnv = "OCM_OFFLINE_TOKEN"
3337
)
3438

35-
// ExtractTokenFromSSE extracts OCM offline token from X-OCM-OFFLINE-TOKEN header
36-
func ExtractTokenFromSSE(headers map[string]string) (string, error) {
37-
glog.V(3).Infof("SSE headers received: %+v", headers)
39+
// TokenInfo represents extracted token information
40+
type TokenInfo struct {
41+
Token string
42+
TokenType string // "access" or "offline"
43+
}
44+
45+
// ExtractBearerToken extracts access token from Authorization header
46+
func ExtractBearerToken(headers map[string]string) (string, error) {
47+
authHeader, exists := headers[AuthorizationHeader]
48+
if !exists || authHeader == "" {
49+
return "", fmt.Errorf("missing or empty %s header", AuthorizationHeader)
50+
}
51+
52+
// Check for "Bearer " prefix (case-insensitive)
53+
const bearerPrefix = "Bearer "
54+
if len(authHeader) <= len(bearerPrefix) {
55+
return "", fmt.Errorf("invalid Authorization header format")
56+
}
57+
58+
if !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
59+
return "", fmt.Errorf("Authorization header must use Bearer scheme")
60+
}
61+
62+
token := strings.TrimSpace(authHeader[len(bearerPrefix):])
63+
if token == "" {
64+
return "", fmt.Errorf("empty Bearer token")
65+
}
66+
67+
return token, nil
68+
}
69+
70+
// ExtractTokenInfoFromSSE extracts token info from SSE headers, preferring access tokens
71+
func ExtractTokenInfoFromSSE(headers map[string]string) (*TokenInfo, error) {
72+
// Log only header keys for security (never log header values which may contain tokens)
73+
headerKeys := make([]string, 0, len(headers))
74+
for key := range headers {
75+
headerKeys = append(headerKeys, key)
76+
}
77+
glog.V(3).Infof("SSE headers received (keys only): %v", headerKeys)
78+
79+
// Try Authorization header first (access token)
80+
if token, err := ExtractBearerToken(headers); err == nil {
81+
glog.V(3).Info("Found access token in Authorization header")
82+
return &TokenInfo{Token: token, TokenType: "access"}, nil
83+
}
3884

39-
// Try exact header name first
85+
// Fallback to offline token header
4086
token, exists := headers[SSETokenHeader]
4187
if exists && token != "" {
42-
glog.V(3).Infof("Found OCM token in header %s", SSETokenHeader)
43-
return token, nil
88+
glog.V(3).Infof("Found offline token in header %s", SSETokenHeader)
89+
return &TokenInfo{Token: token, TokenType: "offline"}, nil
4490
}
4591

46-
glog.Warningf("Missing or empty %s header in SSE request. Available headers: %+v", SSETokenHeader, headers)
47-
return "", fmt.Errorf("missing or empty %s header", SSETokenHeader)
92+
glog.Warningf("No valid tokens found in SSE headers. Available header keys: %v", headerKeys)
93+
return nil, fmt.Errorf("missing valid authentication token")
4894
}
4995

5096
// ExtractTokenFromStdio extracts OCM offline token from OCM_OFFLINE_TOKEN environment variable
@@ -56,40 +102,45 @@ func ExtractTokenFromStdio() (string, error) {
56102
return token, nil
57103
}
58104

59-
// ExtractTokenFromContext extracts OCM offline token from context based on transport mode
60-
func ExtractTokenFromContext(ctx context.Context, transport string) (string, error) {
105+
// ExtractTokenInfoFromContext extracts token info from context based on transport mode
106+
func ExtractTokenInfoFromContext(ctx context.Context, transport string) (*TokenInfo, error) {
61107
glog.V(2).Infof("Extracting token for transport mode: %s", transport)
62108

63109
switch transport {
64110
case "stdio":
65-
return ExtractTokenFromStdio()
111+
// Stdio only supports offline tokens via environment variable
112+
token, err := ExtractTokenFromStdio()
113+
if err != nil {
114+
return nil, err
115+
}
116+
return &TokenInfo{Token: token, TokenType: "offline"}, nil
117+
66118
case "sse":
67-
// For SSE transport, extract token from HTTP headers in the context
119+
// For SSE transport, extract from HTTP headers
68120
headers := extractHeadersFromContext(ctx)
69121
if headers != nil {
70-
glog.V(2).Infof("Headers found in context for SSE transport")
71-
if token, err := ExtractTokenFromSSE(headers); err == nil {
72-
return token, nil
122+
if tokenInfo, err := ExtractTokenInfoFromSSE(headers); err == nil {
123+
return tokenInfo, nil
73124
} else {
74125
glog.Warningf("Failed to extract token from SSE headers: %v", err)
75126
}
76-
} else {
77-
glog.Warningf("No headers found in context for SSE transport")
78127
}
79128

80-
// Fallback to environment variable for MVP compatibility
129+
// Fallback to environment variable
81130
token := os.Getenv(StdioTokenEnv)
82131
if token == "" {
83-
err := fmt.Errorf("SSE transport requires %s header or %s environment variable", SSETokenHeader, StdioTokenEnv)
132+
err := fmt.Errorf("SSE transport requires %s header or %s environment variable",
133+
AuthorizationHeader, StdioTokenEnv)
84134
glog.Errorf("Authentication failed: %v", err)
85-
return "", err
135+
return nil, err
86136
}
87137
glog.V(2).Infof("Using fallback environment variable for SSE transport")
88-
return token, nil
138+
return &TokenInfo{Token: token, TokenType: "offline"}, nil
139+
89140
default:
90141
err := fmt.Errorf("unsupported transport mode: %s", transport)
91142
glog.Errorf("Authentication failed: %v", err)
92-
return "", err
143+
return nil, err
93144
}
94145
}
95146

0 commit comments

Comments
 (0)