diff --git a/server/internal/mcp/servepublic_test.go b/server/internal/mcp/servepublic_test.go new file mode 100644 index 000000000..0a9af2a8b --- /dev/null +++ b/server/internal/mcp/servepublic_test.go @@ -0,0 +1,142 @@ +package mcp_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" + + "github.com/speakeasy-api/gram/server/internal/contextvalues" + "github.com/speakeasy-api/gram/server/internal/conv" + toolsets_repo "github.com/speakeasy-api/gram/server/internal/toolsets/repo" +) + +func TestService_ServePublic(t *testing.T) { + t.Parallel() + + t.Run("handles initialize request successfully", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Test MCP Server", + Slug: "test-mcp", + Description: conv.ToPGText("A test MCP server"), + HttpToolNames: []string{}, + DefaultEnvironmentSlug: pgtype.Text{String: "", Valid: false}, + McpSlug: conv.ToPGText("test-mcp"), + McpEnabled: true, + }) + require.NoError(t, err) + + reqBody := []map[string]any{ + { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{ + "name": "test-client", + "version": "1.0.0", + }, + }, + }, + } + bodyBytes, err := json.Marshal(reqBody) + require.NoError(t, err) + + mcpSlug := toolset.McpSlug.String + req := httptest.NewRequest(http.MethodPost, "/mcp/"+mcpSlug, bytes.NewReader(bodyBytes)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("mcpSlug", mcpSlug) + ctx = context.WithValue(ctx, chi.RouteCtxKey, rctx) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + err = ti.service.ServePublic(w, req) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, w.Code) + require.NotEmpty(t, w.Header().Get("Mcp-Session-Id")) + + var response map[string]any + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err, "response body: %s", w.Body.String()) + require.Equal(t, "2.0", response["jsonrpc"]) + require.InDelta(t, 1, response["id"], 0) + require.NotNil(t, response["result"]) + + result, ok := response["result"].(map[string]any) + require.True(t, ok, "result should be a map") + require.Equal(t, "2024-11-05", result["protocolVersion"]) + require.NotNil(t, result["capabilities"]) + require.NotNil(t, result["serverInfo"]) + }) + + t.Run("returns unauthorized for private mcp without authentication", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Private MCP Server", + Slug: "private-mcp", + Description: conv.ToPGText("A private MCP server"), + HttpToolNames: []string{}, + DefaultEnvironmentSlug: pgtype.Text{String: "", Valid: false}, + McpSlug: pgtype.Text{String: "", Valid: false}, + McpEnabled: false, + }) + require.NoError(t, err) + + reqBody := []map[string]any{ + { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + }, + } + bodyBytes, err := json.Marshal(reqBody) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/mcp/"+toolset.Slug, bytes.NewReader(bodyBytes)) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("mcpSlug", toolset.Slug) + unauthCtx := context.WithValue(t.Context(), chi.RouteCtxKey, rctx) + req = req.WithContext(unauthCtx) + + w := httptest.NewRecorder() + + err = ti.service.ServePublic(w, req) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} diff --git a/server/internal/mcp/setup_test.go b/server/internal/mcp/setup_test.go new file mode 100644 index 000000000..7879fd83d --- /dev/null +++ b/server/internal/mcp/setup_test.go @@ -0,0 +1,96 @@ +package mcp_test + +import ( + "context" + "log" + "net/url" + "os" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + + "github.com/speakeasy-api/gram/server/internal/auth/sessions" + "github.com/speakeasy-api/gram/server/internal/billing" + "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/environments" + "github.com/speakeasy-api/gram/server/internal/guardian" + "github.com/speakeasy-api/gram/server/internal/mcp" + "github.com/speakeasy-api/gram/server/internal/oauth" + "github.com/speakeasy-api/gram/server/internal/testenv" + "github.com/speakeasy-api/gram/server/internal/thirdparty/posthog" +) + +var ( + infra *testenv.Environment +) + +func TestMain(m *testing.M) { + res, cleanup, err := testenv.Launch(context.Background()) + if err != nil { + log.Fatalf("Failed to launch test infrastructure: %v", err) + os.Exit(1) + } + + infra = res + + code := m.Run() + + if err := cleanup(); err != nil { + log.Fatalf("Failed to cleanup test infrastructure: %v", err) + os.Exit(1) + } + + os.Exit(code) +} + +type testInstance struct { + service *mcp.Service + conn *pgxpool.Pool + sessionManager *sessions.Manager + serverURL *url.URL +} + +func newTestMCPService(t *testing.T) (context.Context, *testInstance) { + t.Helper() + + ctx := t.Context() + + logger := testenv.NewLogger(t) + tracerProvider := testenv.NewTracerProvider(t) + meterProvider := noop.NewMeterProvider() + + conn, err := infra.CloneTestDatabase(t, "mcptest") + require.NoError(t, err) + + redisClient, err := infra.NewRedisClient(t, 0) + require.NoError(t, err) + + billingClient := billing.NewStubClient(logger, tracerProvider) + + sessionManager, err := sessions.NewUnsafeManager(logger, conn, redisClient, cache.Suffix("gram-test"), "", billingClient) + require.NoError(t, err) + + ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) + + serverURL, err := url.Parse("http://0.0.0.0") + require.NoError(t, err) + + enc := testenv.NewEncryptionClient(t) + env := environments.NewEnvironmentEntries(logger, conn, enc) + posthog := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host") + cacheAdapter := cache.NewRedisCacheAdapter(redisClient) + guardianPolicy := guardian.NewDefaultPolicy() + oauthService := oauth.NewService(logger, tracerProvider, meterProvider, conn, serverURL, cacheAdapter, enc, env) + billingStub := billing.NewStubClient(logger, tracerProvider) + + svc := mcp.NewService(logger, tracerProvider, meterProvider, conn, sessionManager, env, posthog, serverURL, cacheAdapter, guardianPolicy, oauthService, billingStub, billingStub) + + return ctx, &testInstance{ + service: svc, + conn: conn, + sessionManager: sessionManager, + serverURL: serverURL, + } +} diff --git a/server/internal/mcpmetadata/setmetadata_test.go b/server/internal/mcpmetadata/setmetadata_test.go new file mode 100644 index 000000000..4cee775e6 --- /dev/null +++ b/server/internal/mcpmetadata/setmetadata_test.go @@ -0,0 +1,170 @@ +package mcpmetadata_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" + + gen "github.com/speakeasy-api/gram/server/gen/mcp_metadata" + "github.com/speakeasy-api/gram/server/gen/types" + assets_repo "github.com/speakeasy-api/gram/server/internal/assets/repo" + "github.com/speakeasy-api/gram/server/internal/contextvalues" + "github.com/speakeasy-api/gram/server/internal/conv" + toolsets_repo "github.com/speakeasy-api/gram/server/internal/toolsets/repo" +) + +func TestService_SetMcpMetadata(t *testing.T) { + t.Parallel() + + t.Run("creates metadata for toolset", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPMetadataService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Test MCP Server", + Slug: "test-mcp", + Description: conv.ToPGText("A test MCP server"), + HttpToolNames: []string{}, + DefaultEnvironmentSlug: pgtype.Text{String: "", Valid: false}, + McpSlug: pgtype.Text{String: "", Valid: false}, + McpEnabled: false, + }) + require.NoError(t, err) + + payload := &gen.SetMcpMetadataPayload{ + ToolsetSlug: types.Slug(toolset.Slug), + LogoAssetID: nil, + ExternalDocumentationURL: conv.Ptr("https://docs.example.com"), + SessionToken: nil, + ProjectSlugInput: nil, + } + + result, err := ti.service.SetMcpMetadata(ctx, payload) + require.NoError(t, err) + require.NotNil(t, result) + + require.NotEmpty(t, result.ID) + require.Equal(t, toolset.ID.String(), result.ToolsetID) + require.NotNil(t, result.ExternalDocumentationURL) + require.Equal(t, "https://docs.example.com", *result.ExternalDocumentationURL) + require.Nil(t, result.LogoAssetID) + }) + + t.Run("updates existing metadata", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPMetadataService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Test MCP Server", + Slug: "test-mcp-update", + Description: conv.ToPGText("A test MCP server"), + HttpToolNames: []string{}, + DefaultEnvironmentSlug: pgtype.Text{String: "", Valid: false}, + McpSlug: pgtype.Text{String: "", Valid: false}, + McpEnabled: false, + }) + require.NoError(t, err) + + firstPayload := &gen.SetMcpMetadataPayload{ + ToolsetSlug: types.Slug(toolset.Slug), + LogoAssetID: nil, + ExternalDocumentationURL: conv.Ptr("https://docs.example.com/v1"), + SessionToken: nil, + ProjectSlugInput: nil, + } + + firstResult, err := ti.service.SetMcpMetadata(ctx, firstPayload) + require.NoError(t, err) + require.NotNil(t, firstResult) + + secondPayload := &gen.SetMcpMetadataPayload{ + ToolsetSlug: types.Slug(toolset.Slug), + LogoAssetID: nil, + ExternalDocumentationURL: conv.Ptr("https://docs.example.com/v2"), + SessionToken: nil, + ProjectSlugInput: nil, + } + + secondResult, err := ti.service.SetMcpMetadata(ctx, secondPayload) + require.NoError(t, err) + require.NotNil(t, secondResult) + + require.Equal(t, firstResult.ID, secondResult.ID) + require.Equal(t, toolset.ID.String(), secondResult.ToolsetID) + require.NotNil(t, secondResult.ExternalDocumentationURL) + require.Equal(t, "https://docs.example.com/v2", *secondResult.ExternalDocumentationURL) + }) + + t.Run("sets logo asset ID", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPMetadataService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Test MCP Server", + Slug: "test-mcp-logo", + Description: conv.ToPGText("A test MCP server"), + HttpToolNames: []string{}, + DefaultEnvironmentSlug: pgtype.Text{String: "", Valid: false}, + McpSlug: pgtype.Text{String: "", Valid: false}, + McpEnabled: false, + }) + require.NoError(t, err) + + assetsRepo := assets_repo.New(ti.conn) + asset, err := assetsRepo.CreateAsset(ctx, assets_repo.CreateAssetParams{ + Name: "test-logo.png", + Url: "https://example.com/logo.png", + ProjectID: *authCtx.ProjectID, + Sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + Kind: "image", + ContentType: "image/png", + ContentLength: 1024, + }) + require.NoError(t, err) + + logoAssetID := asset.ID.String() + + payload := &gen.SetMcpMetadataPayload{ + ToolsetSlug: types.Slug(toolset.Slug), + LogoAssetID: &logoAssetID, + ExternalDocumentationURL: nil, + SessionToken: nil, + ProjectSlugInput: nil, + } + + result, err := ti.service.SetMcpMetadata(ctx, payload) + require.NoError(t, err) + require.NotNil(t, result) + + require.NotEmpty(t, result.ID) + require.Equal(t, toolset.ID.String(), result.ToolsetID) + require.NotNil(t, result.LogoAssetID) + require.Equal(t, logoAssetID, *result.LogoAssetID) + require.Nil(t, result.ExternalDocumentationURL) + }) +} diff --git a/server/internal/mcpmetadata/setup_test.go b/server/internal/mcpmetadata/setup_test.go new file mode 100644 index 000000000..3298aadcc --- /dev/null +++ b/server/internal/mcpmetadata/setup_test.go @@ -0,0 +1,82 @@ +package mcpmetadata_test + +import ( + "context" + "log" + "net/url" + "os" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" + + "github.com/speakeasy-api/gram/server/internal/auth/sessions" + "github.com/speakeasy-api/gram/server/internal/billing" + "github.com/speakeasy-api/gram/server/internal/cache" + "github.com/speakeasy-api/gram/server/internal/mcpmetadata" + "github.com/speakeasy-api/gram/server/internal/testenv" +) + +var ( + infra *testenv.Environment +) + +func TestMain(m *testing.M) { + res, cleanup, err := testenv.Launch(context.Background()) + if err != nil { + log.Fatalf("Failed to launch test infrastructure: %v", err) + os.Exit(1) + } + + infra = res + + code := m.Run() + + if err := cleanup(); err != nil { + log.Fatalf("Failed to cleanup test infrastructure: %v", err) + os.Exit(1) + } + + os.Exit(code) +} + +type testInstance struct { + service *mcpmetadata.Service + conn *pgxpool.Pool + sessionManager *sessions.Manager + serverURL *url.URL +} + +func newTestMCPMetadataService(t *testing.T) (context.Context, *testInstance) { + t.Helper() + + ctx := t.Context() + + logger := testenv.NewLogger(t) + tracerProvider := testenv.NewTracerProvider(t) + + conn, err := infra.CloneTestDatabase(t, "mcpmetadatatest") + require.NoError(t, err) + + redisClient, err := infra.NewRedisClient(t, 0) + require.NoError(t, err) + + billingClient := billing.NewStubClient(logger, tracerProvider) + + sessionManager, err := sessions.NewUnsafeManager(logger, conn, redisClient, cache.Suffix("gram-test"), "", billingClient) + require.NoError(t, err) + + ctx = testenv.InitAuthContext(t, ctx, conn, sessionManager) + + serverURL, err := url.Parse("http://0.0.0.0") + require.NoError(t, err) + + svc := mcpmetadata.NewService(logger, conn, sessionManager, serverURL) + + return ctx, &testInstance{ + service: svc, + conn: conn, + sessionManager: sessionManager, + serverURL: serverURL, + } +}