Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion server/cmd/gram/deps.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/speakeasy-api/gram/server/internal/assets"
"github.com/speakeasy-api/gram/server/internal/attr"
"github.com/speakeasy-api/gram/server/internal/billing"
"github.com/speakeasy-api/gram/server/internal/cache"
"github.com/speakeasy-api/gram/server/internal/conv"
"github.com/speakeasy-api/gram/server/internal/encryption"
"github.com/speakeasy-api/gram/server/internal/externalmcp"
Expand Down Expand Up @@ -545,6 +546,7 @@ func newFunctionOrchestrator(
type mcpRegistryClientOptions struct {
pulseTenantID string
pulseAPIKey conv.Secret
cacheImpl cache.Cache
}

func newMCPRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvider, opts mcpRegistryClientOptions) (*externalmcp.RegistryClient, error) {
Expand All @@ -555,5 +557,5 @@ func newMCPRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvid

backend := externalmcp.NewPulseBackend(pulseURL, opts.pulseTenantID, opts.pulseAPIKey)

return externalmcp.NewRegistryClient(logger, tracerProvider, backend), nil
return externalmcp.NewRegistryClient(logger, tracerProvider, backend, opts.cacheImpl), nil
}
1 change: 1 addition & 0 deletions server/cmd/gram/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ func newStartCommand() *cli.Command {
mcpRegistryClient, err := newMCPRegistryClient(logger, tracerProvider, mcpRegistryClientOptions{
pulseTenantID: c.String("pulse-registry-tenant"),
pulseAPIKey: conv.NewSecret([]byte(c.String("pulse-registry-api-key"))),
cacheImpl: cache.NewRedisCacheAdapter(redisClient),
})
if err != nil {
return fmt.Errorf("failed to create mcp registry client: %w", err)
Expand Down
1 change: 1 addition & 0 deletions server/cmd/gram/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ func newWorkerCommand() *cli.Command {
mcpRegistryClient, err := newMCPRegistryClient(logger, tracerProvider, mcpRegistryClientOptions{
pulseTenantID: c.String("pulse-registry-tenant"),
pulseAPIKey: conv.NewSecret([]byte(c.String("pulse-registry-api-key"))),
cacheImpl: cache.NewRedisCacheAdapter(redisClient),
})
if err != nil {
return fmt.Errorf("failed to create mcp registry client: %w", err)
Expand Down
59 changes: 59 additions & 0 deletions server/internal/externalmcp/registry_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package externalmcp

import (
"crypto/sha256"
"fmt"
"net/http"
"sort"
"strings"
"time"

"github.com/speakeasy-api/gram/server/gen/types"
"github.com/speakeasy-api/gram/server/internal/cache"
)

const registryCacheTTL = 24 * time.Hour

// CachedListServersResponse wraps a list of external MCP servers for caching.
type CachedListServersResponse struct {
Key string
Servers []*types.ExternalMCPServer
}

var _ cache.CacheableObject[CachedListServersResponse] = (*CachedListServersResponse)(nil)

func (c CachedListServersResponse) CacheKey() string { return c.Key }
func (c CachedListServersResponse) AdditionalCacheKeys() []string { return []string{} }
func (c CachedListServersResponse) TTL() time.Duration { return registryCacheTTL }

// CachedServerDetailsResponse wraps server details for caching.
type CachedServerDetailsResponse struct {
Key string
Details *ServerDetails
}

var _ cache.CacheableObject[CachedServerDetailsResponse] = (*CachedServerDetailsResponse)(nil)

func (c CachedServerDetailsResponse) CacheKey() string { return c.Key }
func (c CachedServerDetailsResponse) AdditionalCacheKeys() []string { return []string{} }
func (c CachedServerDetailsResponse) TTL() time.Duration { return registryCacheTTL }

// registryCacheKey builds a cache key from a prefix and the request's URL + headers.
// Headers are sorted and hashed with SHA-256 to capture tenant/auth identity.
func registryCacheKey(prefix string, req *http.Request) string {
// Sort header keys for deterministic hashing
keys := make([]string, 0, len(req.Header))
for k := range req.Header {
keys = append(keys, k)
}
sort.Strings(keys)

h := sha256.New()
for _, k := range keys {
vals := req.Header[k]
sort.Strings(vals)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Cache key generation mutates HTTP request headers in place

The registryCacheKey function mutates the original HTTP request's headers by sorting header values in place.

Click to expand

Issue

At server/internal/externalmcp/registry_cache.go:53-54, the code gets a reference to the header values slice and sorts it in place:

vals := req.Header[k]
sort.Strings(vals)

In Go, req.Header[k] returns the actual slice stored in the map, not a copy. When sort.Strings(vals) is called, it modifies the original slice, thereby mutating the HTTP request's headers.

Impact

The cache key is generated before the HTTP request is sent (see registryclient.go:202 and registryclient.go:316). This means the request headers are mutated before c.httpClient.Do(req) is called. While header order typically doesn't affect HTTP semantics, this:

  1. Violates the principle of least surprise - a cache key function shouldn't have side effects
  2. Could cause issues if the HTTP client, transport, or any middleware depends on header order
  3. Could cause issues if the request object is inspected or reused after this call

Expected Behavior

The cache key generation should not modify the input request. A copy of the values should be made before sorting.

Recommendation: Make a copy of the slice before sorting:

vals := req.Header[k]
valsCopy := make([]string, len(vals))
copy(valsCopy, vals)
sort.Strings(valsCopy)
_, _ = fmt.Fprintf(h, "%s=%s\n", k, strings.Join(valsCopy, ","))
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

_, _ = fmt.Fprintf(h, "%s=%s\n", k, strings.Join(vals, ","))
}

return fmt.Sprintf("registry:%s:%s:%x", prefix, req.URL.String(), h.Sum(nil))
}
88 changes: 78 additions & 10 deletions server/internal/externalmcp/registryclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/speakeasy-api/gram/server/gen/types"
"github.com/speakeasy-api/gram/server/internal/attr"
"github.com/speakeasy-api/gram/server/internal/cache"
externalmcptypes "github.com/speakeasy-api/gram/server/internal/externalmcp/repo/types"
"github.com/speakeasy-api/gram/server/internal/o11y"
"github.com/speakeasy-api/gram/server/internal/oops"
Expand All @@ -28,23 +29,46 @@ type RegistryBackend interface {

// RegistryClient handles communication with external MCP registries.
type RegistryClient struct {
httpClient *http.Client
logger *slog.Logger
backend RegistryBackend
httpClient *http.Client
logger *slog.Logger
backend RegistryBackend
listCache *cache.TypedCacheObject[CachedListServersResponse]
detailsCache *cache.TypedCacheObject[CachedServerDetailsResponse]
}

// NewRegistryClient creates a new registry client.
func NewRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvider, backend RegistryBackend) *RegistryClient {
return &RegistryClient{
// NewRegistryClient creates a new registry client. The cacheImpl parameter is
// optional — pass nil to disable caching.
func NewRegistryClient(logger *slog.Logger, tracerProvider trace.TracerProvider, backend RegistryBackend, cacheImpl cache.Cache) *RegistryClient {
rc := &RegistryClient{
httpClient: &http.Client{
Transport: otelhttp.NewTransport(
retryablehttp.NewClient().StandardClient().Transport,
otelhttp.WithTracerProvider(tracerProvider),
),
},
logger: logger.With(attr.SlogComponent("mcp-registry-client")),
backend: backend,
logger: logger.With(attr.SlogComponent("mcp-registry-client")),
backend: backend,
listCache: nil,
detailsCache: nil,
}

if cacheImpl != nil {
listCache := cache.NewTypedObjectCache[CachedListServersResponse](
logger.With(attr.SlogCacheNamespace("registry-list")),
cacheImpl,
cache.SuffixNone,
)
rc.listCache = &listCache

detailsCache := cache.NewTypedObjectCache[CachedServerDetailsResponse](
logger.With(attr.SlogCacheNamespace("registry-details")),
cacheImpl,
cache.SuffixNone,
)
rc.detailsCache = &detailsCache
}

return rc
}

// Registry represents an MCP registry endpoint.
Expand Down Expand Up @@ -173,6 +197,16 @@ func (c *RegistryClient) ListServers(ctx context.Context, registry Registry, par
}
}

// Check cache after authorization so headers are populated.
if c.listCache != nil {
cacheKey := registryCacheKey("list", req)
cached, err := c.listCache.Get(ctx, cacheKey)
if err == nil {
c.logger.DebugContext(ctx, "registry list cache hit", attr.SlogCacheKey(cacheKey))
return cached.Servers, nil
}
}

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch from registry: %w", err)
Expand Down Expand Up @@ -231,6 +265,17 @@ func (c *RegistryClient) ListServers(ctx context.Context, registry Registry, par
servers = append(servers, server)
}

// Store in cache on success.
if c.listCache != nil {
cacheKey := registryCacheKey("list", req)
if storeErr := c.listCache.Store(ctx, CachedListServersResponse{
Key: cacheKey,
Servers: servers,
}); storeErr != nil {
c.logger.WarnContext(ctx, "failed to store registry list in cache", attr.SlogError(storeErr))
}
}

return servers, nil
}

Expand Down Expand Up @@ -266,6 +311,16 @@ func (c *RegistryClient) GetServerDetails(ctx context.Context, registry Registry
}
}

// Check cache after authorization so headers are populated.
if c.detailsCache != nil {
cacheKey := registryCacheKey("details", req)
cached, err := c.detailsCache.Get(ctx, cacheKey)
if err == nil {
c.logger.DebugContext(ctx, "registry details cache hit", attr.SlogCacheKey(cacheKey))
return cached.Details, nil
}
}

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send external mcp server details request: %w", err)
Expand Down Expand Up @@ -326,13 +381,26 @@ func (c *RegistryClient) GetServerDetails(ctx context.Context, registry Registry
tools = serverResp.Meta.Version.FifthRemote.Tools
}

return &ServerDetails{
details := &ServerDetails{
Name: serverResp.Server.Name,
Description: serverResp.Server.Description,
Version: serverResp.Server.Version,
RemoteURL: remoteURL,
TransportType: transportType,
Tools: tools,
Headers: headers,
}, nil
}

// Store in cache on success.
if c.detailsCache != nil {
cacheKey := registryCacheKey("details", req)
if storeErr := c.detailsCache.Store(ctx, CachedServerDetailsResponse{
Key: cacheKey,
Details: details,
}); storeErr != nil {
c.logger.WarnContext(ctx, "failed to store registry details in cache", attr.SlogError(storeErr))
}
}

return details, nil
}
8 changes: 4 additions & 4 deletions server/internal/externalmcp/registryclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestListServers_FiltersDeletedServers(t *testing.T) {
}))
defer server.Close()

client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{})
client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil)
client.httpClient = server.Client()
registry := Registry{
ID: uuid.New(),
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestGetServerDetails_OnlyStreamableHTTP(t *testing.T) {
}))
defer server.Close()

client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{})
client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil)
client.httpClient = server.Client()
registry := Registry{
ID: uuid.New(),
Expand Down Expand Up @@ -169,7 +169,7 @@ func TestGetServerDetails_OnlySSE(t *testing.T) {
}))
defer server.Close()

client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{})
client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil)
client.httpClient = server.Client()
registry := Registry{
ID: uuid.New(),
Expand Down Expand Up @@ -215,7 +215,7 @@ func TestGetServerDetails_PrefersStreamableHTTPOverSSE(t *testing.T) {
}))
defer server.Close()

client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{})
client := NewRegistryClient(logger, tracernoop.NewTracerProvider(), &PassthroughBackend{}, nil)
client.httpClient = server.Client()
registry := Registry{
ID: uuid.New(),
Expand Down
1 change: 1 addition & 0 deletions server/internal/testenv/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func NewMCPRegistryClient(t *testing.T, logger *slog.Logger, tracerProvider trac
NewLogger(t),
tracerProvider,
externalmcp.NewPulseBackend(pulseURL, "test-tenant-id", conv.NewSecret([]byte("test-api-key"))),
nil,
)
require.NoError(t, err, "expected mcp registry client to initialize without error")

Expand Down
Loading