diff --git a/router-tests/cache_warmup_test.go b/router-tests/cache_warmup_test.go index 824c869f69..50213d0561 100644 --- a/router-tests/cache_warmup_test.go +++ b/router-tests/cache_warmup_test.go @@ -3,27 +3,34 @@ package integration import ( "context" "net/http" + "net/http/httptest" + "os" + "path/filepath" + "syscall" "testing" "time" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/sdk/metric" - "go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest" - - nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" - - "go.opentelemetry.io/otel/sdk/metric/metricdata" - - "github.com/wundergraph/cosmo/router/pkg/otel" - + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/zap" - "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/core" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/controlplane/configpoller" + "github.com/wundergraph/cosmo/router/pkg/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + "go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest" + "go.uber.org/zap" ) +type fakeSelfRegister struct{} + +func (*fakeSelfRegister) Register(ctx context.Context) (*nodev1.RegistrationInfo, error) { + return nil, nil +} + func TestCacheWarmup(t *testing.T) { t.Parallel() @@ -619,6 +626,11 @@ func TestCacheWarmup(t *testing.T) { RouterOptions: []core.Option{ core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ Enabled: true, + Source: config.CacheWarmupSource{ + CdnSource: config.CacheWarmupCDNSource{ + Enabled: true, + }, + }, }), }, AssertCacheMetrics: &testenv.CacheMetricsAssertions{ @@ -678,6 +690,11 @@ func TestCacheWarmup(t *testing.T) { RouterOptions: []core.Option{ 
core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ Enabled: true, + Source: config.CacheWarmupSource{ + CdnSource: config.CacheWarmupCDNSource{ + Enabled: true, + }, + }, }), }, AssertCacheMetrics: &testenv.CacheMetricsAssertions{ @@ -721,6 +738,11 @@ func TestCacheWarmup(t *testing.T) { RouterOptions: []core.Option{ core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ Enabled: true, + Source: config.CacheWarmupSource{ + CdnSource: config.CacheWarmupCDNSource{ + Enabled: true, + }, + }, }), }, AssertCacheMetrics: &testenv.CacheMetricsAssertions{ @@ -754,6 +776,11 @@ func TestCacheWarmup(t *testing.T) { RouterOptions: []core.Option{ core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ Enabled: true, + Source: config.CacheWarmupSource{ + CdnSource: config.CacheWarmupCDNSource{ + Enabled: true, + }, + }, }), }, AssertCacheMetrics: &testenv.CacheMetricsAssertions{ @@ -914,6 +941,362 @@ func TestCacheWarmup(t *testing.T) { }) } +func TestInMemorySwitchoverCaching(t *testing.T) { + t.Parallel() + + t.Run("Verify the plan is cached on config restart when in memory switchover is enabled", func(t *testing.T) { + t.Parallel() + + pm := ConfigPollerMock{ + ready: make(chan struct{}), + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ + Enabled: true, + InMemoryFallback: true, + Source: config.CacheWarmupSource{ + CdnSource: config.CacheWarmupCDNSource{ + Enabled: true, + }, + }, + }), + core.WithConfigVersionHeader(true), + }, + RouterConfig: &testenv.RouterConfig{ + ConfigPollerFactory: func(config *nodev1.RouterConfig) configpoller.ConfigPoller { + pm.initConfig = config + return &pm + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `{ employees { id } }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, xEnv.RouterConfigVersionMain(), 
res.Response.Header.Get("X-Router-Config-Version")) + require.JSONEq(t, employeesIDData, res.Body) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + + // Wait for the config poller to be ready + <-pm.ready + + pm.initConfig.Version = "updated" + require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `{ employees { id } }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, "updated", res.Response.Header.Get("X-Router-Config-Version")) + require.JSONEq(t, employeesIDData, res.Body) + require.Equal(t, "HIT", res.Response.Header.Get("x-wg-execution-plan-cache")) + + }) + }) + + t.Run("Verify the plan is not cached on config restart when in cache warmer is disabled", func(t *testing.T) { + t.Parallel() + + pm := ConfigPollerMock{ + ready: make(chan struct{}), + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ + Enabled: false, + }), + core.WithConfigVersionHeader(true), + }, + RouterConfig: &testenv.RouterConfig{ + ConfigPollerFactory: func(config *nodev1.RouterConfig) configpoller.ConfigPoller { + pm.initConfig = config + return &pm + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `{ employees { id } }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, xEnv.RouterConfigVersionMain(), res.Response.Header.Get("X-Router-Config-Version")) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + + // Wait for the config poller to be ready + <-pm.ready + + pm.initConfig.Version = "updated" + require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `{ employees { id } }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, "updated", 
res.Response.Header.Get("X-Router-Config-Version")) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + }) + }) + + t.Run("Verify the plan is not cached on config restart when using default cache warmer", func(t *testing.T) { + t.Parallel() + + pm := ConfigPollerMock{ + ready: make(chan struct{}), + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ + Enabled: true, + InMemoryFallback: false, + Source: config.CacheWarmupSource{ + CdnSource: config.CacheWarmupCDNSource{ + Enabled: true, + }, + }, + }), + core.WithConfigVersionHeader(true), + }, + RouterConfig: &testenv.RouterConfig{ + ConfigPollerFactory: func(config *nodev1.RouterConfig) configpoller.ConfigPoller { + pm.initConfig = config + return &pm + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `{ employees { id customDetails: details { forename } } }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, xEnv.RouterConfigVersionMain(), res.Response.Header.Get("X-Router-Config-Version")) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + + // Wait for the config poller to be ready + <-pm.ready + + pm.initConfig.Version = "updated" + require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) + + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `{ employees { id customDetails: details { forename } } }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, "updated", res.Response.Header.Get("X-Router-Config-Version")) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + }) + }) + + t.Run("Verify plan is cached when static execution config is reloaded", func(t *testing.T) { + t.Parallel() + + // Create a temporary file for the router config + configFile := t.TempDir() + "/config.json" + + // Initial 
config with just the employees subgraph + writeTestConfig(t, "initial", configFile) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithConfigVersionHeader(true), + core.WithExecutionConfig(&core.ExecutionConfig{ + Path: configFile, + Watch: true, + WatchInterval: 100 * time.Millisecond, + }), + core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ + Enabled: true, + InMemoryFallback: true, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { hello }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, "initial", res.Response.Header.Get("X-Router-Config-Version")) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + + writeTestConfig(t, "updated", configFile) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { hello }`, + }) + assert.Equal(t, "updated", res.Response.Header.Get("X-Router-Config-Version")) + assert.Equal(t, "HIT", res.Response.Header.Get("x-wg-execution-plan-cache")) + }, 2*time.Second, 100*time.Millisecond) + }) + }) + + t.Run("Verify fallback is used when cdn source is enabled but cdn returns 404 internally", func(t *testing.T) { + t.Parallel() + + // Create a temporary file for the router config + configFile := t.TempDir() + "/config.json" + + // Initial config with just the employees subgraph + writeTestConfig(t, "initial", configFile) + + var impl *fakeSelfRegister = nil + + testenv.Run(t, &testenv.Config{ + CdnSever: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })), + RouterOptions: []core.Option{ + core.WithSelfRegistration(impl), + core.WithConfigVersionHeader(true), + core.WithExecutionConfig(&core.ExecutionConfig{ + Path: configFile, + Watch: true, + WatchInterval: 100 * time.Millisecond, + }), + 
core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ + Enabled: true, + InMemoryFallback: true, + Source: config.CacheWarmupSource{ + CdnSource: config.CacheWarmupCDNSource{ + Enabled: true, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { hello }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, "initial", res.Response.Header.Get("X-Router-Config-Version")) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + + writeTestConfig(t, "updated", configFile) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { hello }`, + }) + assert.Equal(t, "updated", res.Response.Header.Get("X-Router-Config-Version")) + assert.Equal(t, "HIT", res.Response.Header.Get("x-wg-execution-plan-cache")) + }, 2*time.Second, 100*time.Millisecond) + }) + }) + + t.Run("Verify fallback is used when cdn source is enabled but cdn returns unauthorized internally", func(t *testing.T) { + t.Parallel() + + // Create a temporary file for the router config + configFile := t.TempDir() + "/config.json" + + // Initial config with just the employees subgraph + writeTestConfig(t, "initial", configFile) + + var impl *fakeSelfRegister = nil + + testenv.Run(t, &testenv.Config{ + CdnSever: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })), + RouterOptions: []core.Option{ + core.WithSelfRegistration(impl), + core.WithConfigVersionHeader(true), + core.WithExecutionConfig(&core.ExecutionConfig{ + Path: configFile, + Watch: true, + WatchInterval: 100 * time.Millisecond, + }), + core.WithCacheWarmupConfig(&config.CacheWarmupConfiguration{ + Enabled: true, + InMemoryFallback: true, + Source: config.CacheWarmupSource{ + CdnSource: config.CacheWarmupCDNSource{ + Enabled: true, + }, + }, + }), + }, + 
}, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { hello }`, + }) + require.Equal(t, 200, res.Response.StatusCode) + require.Equal(t, "initial", res.Response.Header.Get("X-Router-Config-Version")) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + + writeTestConfig(t, "updated", configFile) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { hello }`, + }) + assert.Equal(t, "updated", res.Response.Header.Get("X-Router-Config-Version")) + assert.Equal(t, "HIT", res.Response.Header.Get("x-wg-execution-plan-cache")) + }, 2*time.Second, 100*time.Millisecond) + }) + }) + + t.Run("Successfully persists cache across config change restarts", func(t *testing.T) { + t.Parallel() + + updateConfig := func(t *testing.T, xEnv *testenv.Environment, ctx context.Context, listenString string, config string) { + f, err := os.Create(filepath.Join(xEnv.GetRouterProcessCwd(), "config.yaml")) + require.NoError(t, err) + + _, err = f.WriteString(config) + require.NoError(t, err) + require.NoError(t, f.Close()) + + err = xEnv.SignalRouterProcess(syscall.SIGHUP) + require.NoError(t, err) + require.NoError(t, xEnv.WaitForServer(ctx, xEnv.RouterURL+"/"+listenString, 600, 60), "healthcheck post-reload failed") + } + + getConfigString := func(listenString string) string { + return ` +version: "1" + +readiness_check_path: "/` + listenString + `" + +cache_warmup: + enabled: true + in_memory_fallback: true + source: + cdn: + enabled: false + +engine: + debug: + enable_cache_response_headers: true +` + } + + err := testenv.RunRouterBinary(t, &testenv.Config{ + DemoMode: true, + }, testenv.RunRouterBinConfigOptions{}, func(t *testing.T, xEnv *testenv.Environment) { + // Verify initial start + t.Logf("running router binary, cwd: %s", xEnv.GetRouterProcessCwd()) + ctx := context.Background() + require.NoError(t, 
xEnv.WaitForServer(ctx, xEnv.RouterURL+"/health/ready", 600, 60), "healthcheck pre-reload failed") + + // Enable cache response headers first + listenString1 := "after1" + updateConfig(t, xEnv, ctx, listenString1, getConfigString(listenString1)) + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `query { hello }`}) + require.Equal(t, "MISS", res.Response.Header.Get("x-wg-execution-plan-cache")) + + // Verify cache persisted on restart + listenString2 := "after2" + updateConfig(t, xEnv, ctx, listenString2, getConfigString(listenString2)) + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `query { hello }`}) + require.Equal(t, "HIT", res.Response.Header.Get("x-wg-execution-plan-cache")) + }) + + require.NoError(t, err) + }) + +} + // findDataPoint finds a data point in a slice of histogram data points by matching // the value of WgEnginePlanCacheHit attribute func findDataPoint(t *testing.T, dataPoints []metricdata.HistogramDataPoint[float64], cacheHit bool) metricdata.HistogramDataPoint[float64] { diff --git a/router-tests/go.mod b/router-tests/go.mod index 3de9ec5fcf..35e7f08b1f 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -63,7 +63,7 @@ require ( github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dgraph-io/ristretto/v2 v2.1.0 // indirect + github.com/dgraph-io/ristretto/v2 v2.4.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/docker/cli v28.2.2+incompatible // indirect github.com/docker/distribution v2.8.3+incompatible // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index dc73f635fd..82e01b2338 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -65,10 +65,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/ristretto/v2 v2.1.0 h1:59LjpOJLNDULHh8MC4UaegN52lC4JnO2dITsie/Pa8I= -github.com/dgraph-io/ristretto/v2 v2.1.0/go.mod h1:uejeqfYXpUomfse0+lO+13ATz4TypQYLJZzBSAemuB4= -github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= -github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dgraph-io/ristretto/v2 v2.4.0 h1:I/w09yLjhdcVD2QV192UJcq8dPBaAJb9pOuMyNy0XlU= +github.com/dgraph-io/ristretto/v2 v2.4.0/go.mod h1:0KsrXtXvnv0EqnzyowllbVJB8yBonswa2lTCK2gGo9E= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54 h1:SG7nF6SRlWhcT7cNTs5R6Hk4V2lcmLz2NsG2VnInyNo= diff --git a/router/cmd/main.go b/router/cmd/main.go index 59cd9b8572..dd04ac292f 100644 --- a/router/cmd/main.go +++ b/router/cmd/main.go @@ -150,7 +150,8 @@ func Main() { } rs, err := core.NewRouterSupervisor(&core.RouterSupervisorOpts{ - BaseLogger: baseLogger, + BaseLogger: baseLogger, + SwitchoverConfig: core.NewSwitchoverConfig(baseLogger), ConfigFactory: func() (*config.Config, error) { result, err := config.LoadConfig(*configPathFlag) if err != nil { diff --git a/router/core/cache_warmup.go b/router/core/cache_warmup.go index 
b914567291..9472f0ab91 100644 --- a/router/core/cache_warmup.go +++ b/router/core/cache_warmup.go @@ -34,6 +34,7 @@ type CacheWarmupProcessor interface { type CacheWarmupConfig struct { Log *zap.Logger Source CacheWarmupSource + FallbackSource CacheWarmupSource Workers int ItemsPerSecond int Timeout time.Duration @@ -45,6 +46,7 @@ func WarmupCaches(ctx context.Context, cfg *CacheWarmupConfig) (err error) { w := &cacheWarmup{ log: cfg.Log.With(zap.String("component", "cache_warmup")), source: cfg.Source, + fallbackSource: cfg.FallbackSource, workers: cfg.Workers, itemsPerSecond: cfg.ItemsPerSecond, timeout: cfg.Timeout, @@ -92,6 +94,7 @@ func WarmupCaches(ctx context.Context, cfg *CacheWarmupConfig) (err error) { type cacheWarmup struct { log *zap.Logger source CacheWarmupSource + fallbackSource CacheWarmupSource workers int itemsPerSecond int timeout time.Duration @@ -105,6 +108,12 @@ func (w *cacheWarmup) run(ctx context.Context) (int, error) { defer cancel() items, err := w.source.LoadItems(ctx, w.log) + + // Try fallback if no items were loaded OR there was an error loading from main source + if len(items) == 0 || err != nil { + items, err = w.loadFromFallbackSource(ctx, err) + } + if err != nil { return 0, err } @@ -197,6 +206,25 @@ func (w *cacheWarmup) run(ctx context.Context) (int, error) { return len(items), nil } +func (w *cacheWarmup) loadFromFallbackSource(ctx context.Context, mainErr error) ([]*nodev1.Operation, error) { + if w.fallbackSource == nil { + return nil, mainErr + } + + fallbackItems, err := w.fallbackSource.LoadItems(ctx, w.log) + if err != nil { + // If fallback source also failed, log the fallback error and return the original error + w.log.Error("Failed to load cache warmup config from fallback source", zap.Error(err)) + return nil, mainErr + } + + // In case we went to the fallback because the main source had an error, log the original error + if mainErr != nil { + w.log.Error("Falling back to PlanSource due to error loading cache warmup 
config from CDN", zap.Error(mainErr)) + } + return fallbackItems, nil +} + type CacheWarmupPlanningProcessorOptions struct { OperationProcessor *OperationProcessor OperationPlanner *OperationPlanner diff --git a/router/core/cache_warmup_plans.go b/router/core/cache_warmup_plans.go new file mode 100644 index 0000000000..d5bd939a77 --- /dev/null +++ b/router/core/cache_warmup_plans.go @@ -0,0 +1,32 @@ +package core + +import ( + "context" + + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "go.uber.org/zap" +) + +var _ CacheWarmupSource = (*PlanSource)(nil) + +// PlanSource is a very basic cache warmup source that relies on the caller of this type to pass in the +// queries to be used for cache warming directly +type PlanSource struct { + queries []*nodev1.Operation +} + +// NewPlanSource creates a new PlanSource with the given queries from the caller +func NewPlanSource(switchoverCacheWarmerQueries []*nodev1.Operation) *PlanSource { + if switchoverCacheWarmerQueries == nil { + switchoverCacheWarmerQueries = make([]*nodev1.Operation, 0) + } + return &PlanSource{queries: switchoverCacheWarmerQueries} +} + +// LoadItems loads the items from the plan source when called by the cache warmer +func (c *PlanSource) LoadItems(_ context.Context, _ *zap.Logger) ([]*nodev1.Operation, error) { + if c == nil { + return nil, nil + } + return c.queries, nil +} diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 87aa96331c..ccaf1e56c7 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -112,12 +112,20 @@ type BuildGraphMuxOptions struct { EngineConfig *nodev1.EngineConfiguration ConfigSubgraphs []*nodev1.Subgraph RoutingUrlGroupings map[string]map[string]bool + SwitchoverConfig *SwitchoverConfig } func (b BuildGraphMuxOptions) IsBaseGraph() bool { return b.FeatureFlagName == "" } +// buildMultiGraphHandlerOptions contains the configuration options for building a multi-graph handler. 
+type buildMultiGraphHandlerOptions struct { + baseMux *chi.Mux + featureFlagConfigs map[string]*nodev1.FeatureFlagRouterExecutionConfig + switchoverConfig *SwitchoverConfig +} + // newGraphServer creates a new server instance. func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterConfig, proxy ProxyFunc) (*graphServer, error) { /* Older versions of composition will not populate a compatibility version. @@ -273,6 +281,7 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC EngineConfig: routerConfig.GetEngineConfig(), ConfigSubgraphs: routerConfig.GetSubgraphs(), RoutingUrlGroupings: routingUrlGroupings, + SwitchoverConfig: r.switchoverConfig, }) if err != nil { return nil, fmt.Errorf("failed to build base mux: %w", err) @@ -283,7 +292,11 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC s.logger.Info("Feature flags enabled", zap.Strings("flags", maps.Keys(featureFlagConfigMap))) } - multiGraphHandler, err := s.buildMultiGraphHandler(ctx, gm.mux, featureFlagConfigMap) + multiGraphHandler, err := s.buildMultiGraphHandler(ctx, buildMultiGraphHandlerOptions{ + baseMux: gm.mux, + featureFlagConfigs: featureFlagConfigMap, + switchoverConfig: r.switchoverConfig, + }) if err != nil { return nil, fmt.Errorf("failed to build feature flag handler: %w", err) } @@ -431,22 +444,22 @@ func getRoutingUrlGroupingForCircuitBreakers( func (s *graphServer) buildMultiGraphHandler( ctx context.Context, - baseMux *chi.Mux, - featureFlagConfigs map[string]*nodev1.FeatureFlagRouterExecutionConfig, + opts buildMultiGraphHandlerOptions, ) (http.HandlerFunc, error) { - if len(featureFlagConfigs) == 0 { - return baseMux.ServeHTTP, nil + if len(opts.featureFlagConfigs) == 0 { + return opts.baseMux.ServeHTTP, nil } - featureFlagToMux := make(map[string]*chi.Mux, len(featureFlagConfigs)) + featureFlagToMux := make(map[string]*chi.Mux, len(opts.featureFlagConfigs)) // Build all the muxes for the feature flags 
in serial to avoid any race conditions - for featureFlagName, executionConfig := range featureFlagConfigs { + for featureFlagName, executionConfig := range opts.featureFlagConfigs { gm, err := s.buildGraphMux(ctx, BuildGraphMuxOptions{ FeatureFlagName: featureFlagName, RouterConfigVersion: executionConfig.GetVersion(), EngineConfig: executionConfig.GetEngineConfig(), ConfigSubgraphs: executionConfig.Subgraphs, + SwitchoverConfig: opts.switchoverConfig, }) if err != nil { return nil, fmt.Errorf("failed to build mux for feature flag '%s': %w", featureFlagName, err) @@ -473,7 +486,7 @@ func (s *graphServer) buildMultiGraphHandler( return } - baseMux.ServeHTTP(w, r) + opts.baseMux.ServeHTTP(w, r) }, nil } @@ -519,12 +532,12 @@ type graphMux struct { validationCache *ristretto.Cache[uint64, bool] operationHashCache *ristretto.Cache[uint64, string] - accessLogsFileLogger *logging.BufferedLogger - metricStore rmetric.Store - prometheusCacheMetrics *rmetric.CacheMetrics - otelCacheMetrics *rmetric.CacheMetrics - streamMetricStore rmetric.StreamMetricStore - prometheusMetricsExporter *graphqlmetrics.PrometheusMetricsExporter + accessLogsFileLogger *logging.BufferedLogger + metricStore rmetric.Store + prometheusCacheMetrics *rmetric.CacheMetrics + otelCacheMetrics *rmetric.CacheMetrics + streamMetricStore rmetric.StreamMetricStore + prometheusMetricsExporter *graphqlmetrics.PrometheusMetricsExporter } // buildOperationCaches creates the caches for the graph mux. @@ -1296,7 +1309,8 @@ func (s *graphServer) buildGraphMux( DisableExposingVariablesContentOnValidationError: s.engineExecutionConfiguration.DisableExposingVariablesContentOnValidationError, ComplexityLimits: s.securityConfiguration.ComplexityLimits, }) - operationPlanner := NewOperationPlanner(executor, gm.planCache) + + operationPlanner := NewOperationPlanner(executor, gm.planCache, opts.SwitchoverConfig.inMemoryPlanCacheFallback.IsEnabled()) // We support the MCP only on the base graph. 
Feature flags are not supported yet. if opts.IsBaseGraph() && s.mcpServer != nil { @@ -1346,16 +1360,35 @@ func (s *graphServer) buildGraphMux( ) } - if s.Config.cacheWarmup.Source.Filesystem != nil { + switch { + case s.cacheWarmup.Source.Filesystem != nil: warmupConfig.Source = NewFileSystemSource(&FileSystemSourceConfig{ RootPath: s.Config.cacheWarmup.Source.Filesystem.Path, }) - } else { + // Enable in-memory switchover fallback when: + // - Router has cache warmer with inMemoryFallback enabled, AND + // - Either: + // - Using static execution config (not Cosmo): s.selfRegister == nil + // - OR CDN cache warmer is explicitly disabled + case s.cacheWarmup.InMemoryFallback && (s.selfRegister == nil || !s.Config.cacheWarmup.Source.CdnSource.Enabled): + // We first utilize the existing plan cache (if it was already set, i.e., not on the first start) to create a list of queries + // and then reset the plan cache to the new plan cache for this start afterwards. + warmupConfig.Source = NewPlanSource(opts.SwitchoverConfig.inMemoryPlanCacheFallback.getPlanCacheForFF(opts.FeatureFlagName)) + opts.SwitchoverConfig.inMemoryPlanCacheFallback.setPlanCacheForFF(opts.FeatureFlagName, gm.planCache) + case s.Config.cacheWarmup.Source.CdnSource.Enabled: + // We use the in-memory cache as a fallback if enabled + // This is useful for when an issue occurs with the CDN when retrieving the required manifest + if s.cacheWarmup.InMemoryFallback { + warmupConfig.FallbackSource = NewPlanSource(opts.SwitchoverConfig.inMemoryPlanCacheFallback.getPlanCacheForFF(opts.FeatureFlagName)) + opts.SwitchoverConfig.inMemoryPlanCacheFallback.setPlanCacheForFF(opts.FeatureFlagName, gm.planCache) + } cdnSource, err := NewCDNSource(s.Config.cdnConfig.URL, s.graphApiToken, s.logger) if err != nil { return nil, fmt.Errorf("failed to create cdn source: %w", err) } warmupConfig.Source = cdnSource + default: + return nil, fmt.Errorf("unexpected cache warmer source provided") } err = WarmupCaches(ctx, 
warmupConfig) diff --git a/router/core/operation_planner.go b/router/core/operation_planner.go index 38c3b6aac5..12ff2f6929 100644 --- a/router/core/operation_planner.go +++ b/router/core/operation_planner.go @@ -21,13 +21,19 @@ type planWithMetaData struct { operationDocument, schemaDocument *ast.Document typeFieldUsageInfo []*graphqlschemausage.TypeFieldUsageInfo argumentUsageInfo []*graphqlmetricsv1.ArgumentUsageInfo + content string } type OperationPlanner struct { - sf singleflight.Group - planCache ExecutionPlanCache[uint64, *planWithMetaData] - executor *Executor - trackUsageInfo bool + sf singleflight.Group + planCache ExecutionPlanCache[uint64, *planWithMetaData] + executor *Executor + trackUsageInfo bool + operationContent bool +} + +type operationPlannerOpts struct { + operationContent bool } type ExecutionPlanCache[K any, V any] interface { @@ -35,19 +41,22 @@ type ExecutionPlanCache[K any, V any] interface { Get(key K) (V, bool) // Set the value in the cache with a cost. The cost depends on the cache implementation Set(key K, value V, cost int64) bool + // Iterate over all items in the cache (non-deterministic) + IterValues(cb func(v V) (stop bool)) // Close the cache and free resources Close() } -func NewOperationPlanner(executor *Executor, planCache ExecutionPlanCache[uint64, *planWithMetaData]) *OperationPlanner { +func NewOperationPlanner(executor *Executor, planCache ExecutionPlanCache[uint64, *planWithMetaData], storeContent bool) *OperationPlanner { return &OperationPlanner{ - planCache: planCache, - executor: executor, - trackUsageInfo: executor.TrackUsageInfo, + planCache: planCache, + executor: executor, + trackUsageInfo: executor.TrackUsageInfo, + operationContent: storeContent, } } -func (p *OperationPlanner) preparePlan(ctx *operationContext) (*planWithMetaData, error) { +func (p *OperationPlanner) preparePlan(ctx *operationContext, opts operationPlannerOpts) (*planWithMetaData, error) { doc, report := 
astparser.ParseGraphqlDocumentString(ctx.content) if report.HasErrors() { return nil, &reportError{report: &report} @@ -81,6 +90,10 @@ func (p *OperationPlanner) preparePlan(ctx *operationContext) (*planWithMetaData schemaDocument: p.executor.RouterSchema, } + if opts.operationContent { + out.content = ctx.Content() + } + if p.trackUsageInfo { out.typeFieldUsageInfo = graphqlschemausage.GetTypeFieldUsageInfo(preparedPlan) out.argumentUsageInfo, err = graphqlschemausage.GetArgumentUsageInfo(&doc, p.executor.RouterSchema, ctx.variables, preparedPlan, ctx.remapVariables) @@ -106,7 +119,7 @@ func (p *OperationPlanner) plan(opContext *operationContext, options PlanOptions skipCache := options.TraceOptions.Enable || options.ExecutionOptions.IncludeQueryPlanInResponse if skipCache { - prepared, err := p.preparePlan(opContext) + prepared, err := p.preparePlan(opContext, operationPlannerOpts{operationContent: false}) if err != nil { return err } @@ -134,7 +147,7 @@ func (p *OperationPlanner) plan(opContext *operationContext, options PlanOptions // this ensures that we only prepare the plan once for this operation ID operationIDStr := strconv.FormatUint(operationID, 10) sharedPreparedPlan, err, _ := p.sf.Do(operationIDStr, func() (interface{}, error) { - prepared, err := p.preparePlan(opContext) + prepared, err := p.preparePlan(opContext, operationPlannerOpts{operationContent: p.operationContent}) if err != nil { return nil, err } diff --git a/router/core/restart_switchover_config.go b/router/core/restart_switchover_config.go new file mode 100644 index 0000000000..bad7e98c4c --- /dev/null +++ b/router/core/restart_switchover_config.go @@ -0,0 +1,172 @@ +package core + +import ( + "sync" + + "github.com/dgraph-io/ristretto/v2" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "go.uber.org/zap" +) + +type planCache = *ristretto.Cache[uint64, *planWithMetaData] + +// SwitchoverConfig This file describes any configuration which should persist or be 
shared across router restarts +type SwitchoverConfig struct { + inMemoryPlanCacheFallback *InMemoryPlanCacheFallback +} + +func NewSwitchoverConfig(logger *zap.Logger) *SwitchoverConfig { + return &SwitchoverConfig{ + inMemoryPlanCacheFallback: &InMemoryPlanCacheFallback{ + logger: logger, + }, + } +} + +// UpdateSwitchoverConfig updates the switchover config based on the provided config. +func (s *SwitchoverConfig) UpdateSwitchoverConfig(config *Config) { + s.inMemoryPlanCacheFallback.updateStateFromConfig(config) +} + +// CleanupFeatureFlags cleans up anything related to unused feature flags due to being now excluded +// from the execution config +func (s *SwitchoverConfig) CleanupFeatureFlags(routerCfg *nodev1.RouterConfig) { + s.inMemoryPlanCacheFallback.cleanupUnusedFeatureFlags(routerCfg) +} + +func (s *SwitchoverConfig) OnRouterConfigReload() { + // For cases of router config changes (not execution config), we shut down before creating the + // graph mux, because we need to initialize everything from the start + // This causes problems in using the previous planCache reference as it gets closed, so we need to + // copy it over before it gets closed, and we restart with config changes + + // There can be inflight requests when this is called even though it's called in the restart path, + // This is because this is called before the router instance is shutdown before being reloaded + s.inMemoryPlanCacheFallback.extractQueriesAndOverridePlanCache() +} + +// InMemoryPlanCacheFallback is a store that stores either queries or references to the planner cache for use with the cache warmer +type InMemoryPlanCacheFallback struct { + mu sync.RWMutex + queriesForFeatureFlag map[string]any + logger *zap.Logger +} + +// updateStateFromConfig updates the internal state of the in-memory switchover cache based on the provided config +func (c *InMemoryPlanCacheFallback) updateStateFromConfig(config *Config) { + enabled := config.cacheWarmup != nil && + 
config.cacheWarmup.Enabled && + config.cacheWarmup.InMemoryFallback + + c.mu.Lock() + defer c.mu.Unlock() + + // If the configuration change occurred which disabled or enabled the switchover cache, we need to update the internal state + if enabled { + // Only initialize if its nil because its a first start, we dont want to override any old data in a map + if c.queriesForFeatureFlag == nil { + c.queriesForFeatureFlag = make(map[string]any) + } + return + } + + // Reset the map to free up memory + c.queriesForFeatureFlag = nil +} + +// IsEnabled returns whether the in-memory switchover cache is enabled +func (c *InMemoryPlanCacheFallback) IsEnabled() bool { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.queriesForFeatureFlag != nil +} + +// getPlanCacheForFF gets the plan cache in the []*nodev1.Operation format for a specific feature flag key +func (c *InMemoryPlanCacheFallback) getPlanCacheForFF(featureFlagKey string) []*nodev1.Operation { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.queriesForFeatureFlag == nil { + return nil + } + + switch cache := c.queriesForFeatureFlag[featureFlagKey].(type) { + case planCache: + return convertToNodeOperation(cache) + case []*nodev1.Operation: + return cache + // This would occur during the first start (we add this case to specifically log any other cases) + case nil: + return nil + // This should not happen as we cannot have any types other than the above + default: + c.logger.Error("unexpected type") + return nil + } +} + +// setPlanCacheForFF sets the plan cache for a specific feature flag key +func (c *InMemoryPlanCacheFallback) setPlanCacheForFF(featureFlagKey string, cache planCache) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.queriesForFeatureFlag == nil || cache == nil { + return + } + c.queriesForFeatureFlag[featureFlagKey] = cache +} + +// extractQueriesAndOverridePlanCache extracts the queries from the plan cache and overrides the internal map +func (c *InMemoryPlanCacheFallback) 
extractQueriesAndOverridePlanCache() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.queriesForFeatureFlag == nil { + return + } + + switchoverMap := make(map[string]any) + for k, v := range c.queriesForFeatureFlag { + if cache, ok := v.(planCache); ok { + switchoverMap[k] = convertToNodeOperation(cache) + } + } + c.queriesForFeatureFlag = switchoverMap +} + +// cleanupUnusedFeatureFlags removes any feature flags that were removed from the execution config +// after a schema / execution config change +func (c *InMemoryPlanCacheFallback) cleanupUnusedFeatureFlags(routerCfg *nodev1.RouterConfig) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.queriesForFeatureFlag == nil || routerCfg.FeatureFlagConfigs == nil { + return + } + + for ffName := range c.queriesForFeatureFlag { + // Skip the base which is "" + if ffName == "" { + continue + } + if _, exists := routerCfg.FeatureFlagConfigs.ConfigByFeatureFlagName[ffName]; !exists { + delete(c.queriesForFeatureFlag, ffName) + } + } +} + +func convertToNodeOperation(data planCache) []*nodev1.Operation { + items := make([]*nodev1.Operation, 0) + + // Ensure any buffered writes have been applied + data.Wait() + + data.IterValues(func(v *planWithMetaData) (stop bool) { + items = append(items, &nodev1.Operation{ + Request: &nodev1.OperationRequest{Query: v.content}, + }) + return false + }) + return items +} diff --git a/router/core/restart_switchover_config_test.go b/router/core/restart_switchover_config_test.go new file mode 100644 index 0000000000..33105a240e --- /dev/null +++ b/router/core/restart_switchover_config_test.go @@ -0,0 +1,433 @@ +package core + +import ( + "testing" + + "github.com/dgraph-io/ristretto/v2" + "github.com/stretchr/testify/require" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "go.uber.org/zap" +) + +func TestInMemorySwitchoverCache_UpdateInMemorySwitchoverCacheForConfigChanges(t *testing.T) { + t.Parallel() + 
t.Run("enable cache from disabled state", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{} + cfg := &Config{ + cacheWarmup: &config.CacheWarmupConfiguration{ + Enabled: true, + InMemoryFallback: true, + }, + } + + cache.updateStateFromConfig(cfg) + + require.NotNil(t, cache.queriesForFeatureFlag) + require.Empty(t, cache.queriesForFeatureFlag) + }) + + t.Run("disable cache from enabled state", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + cache.queriesForFeatureFlag["test"] = nil + + cfg := &Config{ + cacheWarmup: &config.CacheWarmupConfiguration{ + Enabled: false, + }, + } + + cache.updateStateFromConfig(cfg) + + require.Nil(t, cache.queriesForFeatureFlag) + }) + + t.Run("update when already enabled keeps existing data", func(t *testing.T) { + t.Parallel() + existingMap := make(map[string]any) + existingMap["test"] = nil + + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: existingMap, + } + + cfg := &Config{ + cacheWarmup: &config.CacheWarmupConfiguration{ + Enabled: true, + InMemoryFallback: true, + }, + } + + cache.updateStateFromConfig(cfg) + + require.NotNil(t, cache.queriesForFeatureFlag) + require.Len(t, cache.queriesForFeatureFlag, 1) + require.Contains(t, cache.queriesForFeatureFlag, "test") + }) + + t.Run("update when already disabled", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: nil, + } + + cfg := &Config{ + cacheWarmup: &config.CacheWarmupConfiguration{ + Enabled: false, + }, + } + + cache.updateStateFromConfig(cfg) + + require.Nil(t, cache.queriesForFeatureFlag) + }) + + t.Run("nil cacheWarmup config disables cache", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + + cfg := &Config{ + cacheWarmup: nil, + } + + cache.updateStateFromConfig(cfg) + + require.Nil(t, cache.queriesForFeatureFlag) + }) + + 
t.Run("cacheWarmup enabled but InMemoryFallback disabled", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{} + + cfg := &Config{ + cacheWarmup: &config.CacheWarmupConfiguration{ + Enabled: true, + InMemoryFallback: false, + }, + } + + cache.updateStateFromConfig(cfg) + + require.Nil(t, cache.queriesForFeatureFlag) + }) +} + +func TestInMemorySwitchOverCache_GetPlanCacheForFF(t *testing.T) { + t.Parallel() + t.Run("returns operations for existing feature flag when enabled with ristretto cache", func(t *testing.T) { + t.Parallel() + mockCache, err := ristretto.NewCache(&ristretto.Config[uint64, *planWithMetaData]{ + MaxCost: 10000, + NumCounters: 10000000, + IgnoreInternalCost: true, + BufferItems: 64, + }) + require.NoError(t, err) + + query1 := "query { test1 }" + query2 := "query { test2 }" + + mockCache.Set(1, &planWithMetaData{content: query1}, 1) + mockCache.Set(2, &planWithMetaData{content: query2}, 1) + mockCache.Wait() + + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + cache.queriesForFeatureFlag["test-ff"] = mockCache + + result := cache.getPlanCacheForFF("test-ff") + + require.NotNil(t, result) + require.IsType(t, []*nodev1.Operation{}, result) + require.Len(t, result, 2) + + // Verify the operations contain the expected queries (order may vary) + queries := make([]string, len(result)) + for i, op := range result { + queries[i] = op.Request.Query + } + require.ElementsMatch(t, []string{query1, query2}, queries) + }) + + t.Run("returns operations for existing feature flag when enabled with operation slice", func(t *testing.T) { + t.Parallel() + expectedOps := []*nodev1.Operation{ + {Request: &nodev1.OperationRequest{Query: "query { test1 }"}}, + {Request: &nodev1.OperationRequest{Query: "query { test2 }"}}, + } + + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + cache.queriesForFeatureFlag["test-ff"] = expectedOps + + result := 
cache.getPlanCacheForFF("test-ff") + + require.NotNil(t, result) + require.Equal(t, expectedOps, result) + }) + + t.Run("returns empty slice for non-existent feature flag", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + logger: zap.NewNop(), + queriesForFeatureFlag: make(map[string]any), + } + + result := cache.getPlanCacheForFF("non-existent") + require.Nil(t, result) + }) + + t.Run("returns nil when cache is disabled", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: nil, + } + + result := cache.getPlanCacheForFF("test-ff") + + require.Nil(t, result) + }) +} + +func TestInMemorySwitchOverCache_SetPlanCacheForFF(t *testing.T) { + t.Parallel() + t.Run("sets cache for feature flag when enabled", func(t *testing.T) { + t.Parallel() + mockCache, err := ristretto.NewCache(&ristretto.Config[uint64, *planWithMetaData]{ + MaxCost: 100, + NumCounters: 10000, + BufferItems: 64, + }) + require.NoError(t, err) + + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + + cache.setPlanCacheForFF("test-ff", mockCache) + + require.Contains(t, cache.queriesForFeatureFlag, "test-ff") + // Verify it's the same cache by comparing the underlying pointer + require.Equal(t, cache.queriesForFeatureFlag["test-ff"], mockCache) + }) + + t.Run("does not set cache when disabled", func(t *testing.T) { + t.Parallel() + mockCache, err := ristretto.NewCache(&ristretto.Config[uint64, *planWithMetaData]{ + MaxCost: 100, + NumCounters: 10000, + BufferItems: 64, + }) + require.NoError(t, err) + + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: nil, + } + + cache.setPlanCacheForFF("test-ff", mockCache) + + require.Nil(t, cache.queriesForFeatureFlag) + }) + + t.Run("does not set nil cache", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + + cache.setPlanCacheForFF("test-ff", nil) + + 
require.NotContains(t, cache.queriesForFeatureFlag, "test-ff") + }) +} + +func TestInMemorySwitchOverCache_CleanupUnusedFeatureFlags(t *testing.T) { + t.Parallel() + t.Run("removes unused feature flags", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + cache.queriesForFeatureFlag["ff1"] = nil + cache.queriesForFeatureFlag["ff2"] = nil + cache.queriesForFeatureFlag["ff3"] = nil + + routerCfg := &nodev1.RouterConfig{ + FeatureFlagConfigs: &nodev1.FeatureFlagRouterExecutionConfigs{ + ConfigByFeatureFlagName: map[string]*nodev1.FeatureFlagRouterExecutionConfig{ + "ff1": {}, + "ff2": {}, + }, + }, + } + + cache.cleanupUnusedFeatureFlags(routerCfg) + + require.Len(t, cache.queriesForFeatureFlag, 2) + require.Contains(t, cache.queriesForFeatureFlag, "ff1") + require.Contains(t, cache.queriesForFeatureFlag, "ff2") + require.NotContains(t, cache.queriesForFeatureFlag, "ff3") + }) + + t.Run("keeps empty string feature flag", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + cache.queriesForFeatureFlag[""] = nil + cache.queriesForFeatureFlag["ff1"] = nil + + routerCfg := &nodev1.RouterConfig{ + FeatureFlagConfigs: &nodev1.FeatureFlagRouterExecutionConfigs{ + ConfigByFeatureFlagName: map[string]*nodev1.FeatureFlagRouterExecutionConfig{}, + }, + } + + cache.cleanupUnusedFeatureFlags(routerCfg) + + require.Len(t, cache.queriesForFeatureFlag, 1) + require.Contains(t, cache.queriesForFeatureFlag, "") + require.NotContains(t, cache.queriesForFeatureFlag, "ff1") + }) + + t.Run("does nothing when cache is disabled", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: nil, + } + + routerCfg := &nodev1.RouterConfig{ + FeatureFlagConfigs: &nodev1.FeatureFlagRouterExecutionConfigs{ + ConfigByFeatureFlagName: map[string]*nodev1.FeatureFlagRouterExecutionConfig{}, + }, + } + + 
cache.cleanupUnusedFeatureFlags(routerCfg) + + // Should still be nil because cleanup is skipped when disabled + require.Nil(t, cache.queriesForFeatureFlag) + }) + + t.Run("does nothing when FeatureFlagConfigs is nil", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + cache.queriesForFeatureFlag["ff1"] = nil + + routerCfg := &nodev1.RouterConfig{ + FeatureFlagConfigs: nil, + } + + cache.cleanupUnusedFeatureFlags(routerCfg) + + // Should still have ff1 because FeatureFlagConfigs is nil + require.Len(t, cache.queriesForFeatureFlag, 1) + require.Contains(t, cache.queriesForFeatureFlag, "ff1") + }) +} + +func TestInMemorySwitchOverCache_ProcessOnConfigChangeRestart(t *testing.T) { + t.Parallel() + t.Run("converts ristretto caches to operation slices", func(t *testing.T) { + t.Parallel() + mockCache1, err := ristretto.NewCache(&ristretto.Config[uint64, *planWithMetaData]{ + MaxCost: 10000, + NumCounters: 10000000, + IgnoreInternalCost: true, + BufferItems: 64, + }) + require.NoError(t, err) + + mockCache2, err := ristretto.NewCache(&ristretto.Config[uint64, *planWithMetaData]{ + MaxCost: 10000, + NumCounters: 10000000, + IgnoreInternalCost: true, + BufferItems: 64, + }) + require.NoError(t, err) + + query1 := "query { test1 }" + query2 := "query { test2 }" + + mockCache1.Set(1, &planWithMetaData{content: query1}, 1) + mockCache1.Wait() + mockCache2.Set(2, &planWithMetaData{content: query2}, 1) + mockCache2.Wait() + + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + cache.queriesForFeatureFlag["ff1"] = mockCache1 + cache.queriesForFeatureFlag["ff2"] = mockCache2 + + cache.extractQueriesAndOverridePlanCache() + + // Verify both caches have been converted to operation slices + require.IsType(t, []*nodev1.Operation{}, cache.queriesForFeatureFlag["ff1"]) + require.IsType(t, []*nodev1.Operation{}, cache.queriesForFeatureFlag["ff2"]) + + ff1Ops := 
cache.queriesForFeatureFlag["ff1"].([]*nodev1.Operation) + ff2Ops := cache.queriesForFeatureFlag["ff2"].([]*nodev1.Operation) + + require.Len(t, ff1Ops, 1) + require.Len(t, ff2Ops, 1) + require.Equal(t, query1, ff1Ops[0].Request.Query) + require.Equal(t, query2, ff2Ops[0].Request.Query) + }) + + t.Run("does nothing when cache is disabled", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: nil, + } + + cache.extractQueriesAndOverridePlanCache() + + // Should remain nil since processing is skipped + require.Nil(t, cache.queriesForFeatureFlag) + }) + + t.Run("handles empty cache", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + + require.NotPanics(t, func() { + cache.extractQueriesAndOverridePlanCache() + }) + + require.Empty(t, cache.queriesForFeatureFlag) + }) +} + +func TestInMemorySwitchOverCache_IsEnabled(t *testing.T) { + t.Parallel() + t.Run("returns true when cache is enabled", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: make(map[string]any), + } + + require.True(t, cache.IsEnabled()) + }) + + t.Run("returns false when cache is disabled", func(t *testing.T) { + t.Parallel() + cache := &InMemoryPlanCacheFallback{ + queriesForFeatureFlag: nil, + } + + require.False(t, cache.IsEnabled()) + }) + +} diff --git a/router/core/router.go b/router/core/router.go index ad4b77cc33..a54dcdd4bf 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -90,6 +90,7 @@ type ( proxy ProxyFunc disableUsageTracking bool usage UsageTracker + switchoverConfig *SwitchoverConfig } UsageTracker interface { @@ -604,6 +605,9 @@ func (r *Router) newServer(ctx context.Context, cfg *nodev1.RouterConfig) error r.httpServer.SwapGraphServer(ctx, server) + // Cleanup any unused feature flags in case a feature flag was removed + r.switchoverConfig.CleanupFeatureFlags(cfg) + return nil } @@ -724,56 +728,6 
@@ func (r *Router) BaseURL() string { return r.baseURL } -// NewServer prepares a new server instance but does not start it. The method should only be used when you want to bootstrap -// the server manually otherwise you can use Router.Start(). You're responsible for setting health checks status to ready with Server.HealthChecks(). -// The server can be shutdown with Router.Shutdown(). Use core.WithExecutionConfig to pass the initial config otherwise the Router will -// try to fetch the config from the control plane. You can swap the router config by using Router.newGraphServer(). -func (r *Router) NewServer(ctx context.Context) (Server, error) { - if r.shutdown.Load() { - return nil, fmt.Errorf("router is shutdown. Create a new instance with router.NewRouter()") - } - - if err := r.bootstrap(ctx); err != nil { - return nil, fmt.Errorf("failed to bootstrap application: %w", err) - } - - r.httpServer = newServer(&httpServerOptions{ - addr: r.listenAddr, - logger: r.logger, - tlsConfig: r.tlsConfig, - tlsServerConfig: r.tlsServerConfig, - healthcheck: r.healthcheck, - baseURL: r.baseURL, - maxHeaderBytes: int(r.routerTrafficConfig.MaxHeaderBytes.Uint64()), - livenessCheckPath: r.livenessCheckPath, - readinessCheckPath: r.readinessCheckPath, - healthCheckPath: r.healthCheckPath, - }) - - // Start the server with the static config without polling - if r.staticExecutionConfig != nil { - r.logger.Info("Static execution config provided. Polling is disabled. Updating execution config is only possible by providing a config.") - return r.httpServer, r.newServer(ctx, r.staticExecutionConfig) - } - - // when no static config is provided and no poller is configured, we can't start the server - if r.configPoller == nil { - return nil, fmt.Errorf("config fetcher not provided. 
Please provide a static execution config instead") - } - - cfg, err := r.configPoller.GetRouterConfig(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get initial execution config: %w", err) - } - - if err := r.newServer(ctx, cfg.Config); err != nil { - r.logger.Error("Failed to start server with initial config", zap.Error(err)) - return nil, err - } - - return r.httpServer, nil -} - // bootstrap initializes the Router. It is called by Start() and NewServer(). // It should only be called once for a Router instance. func (r *Router) bootstrap(ctx context.Context) error { @@ -1221,6 +1175,13 @@ func (r *Router) Start(ctx context.Context) error { healthCheckPath: r.healthCheckPath, }) + if r.switchoverConfig == nil { + // This is only applicable for tests since we do not call here via the supervisor + r.switchoverConfig = NewSwitchoverConfig(r.logger) + } + + r.switchoverConfig.UpdateSwitchoverConfig(&r.Config) + // Start the server with the static config without polling if r.staticExecutionConfig != nil { @@ -2095,6 +2056,12 @@ func WithConfigPollerConfig(cfg *RouterConfigPollerConfig) Option { } } +func WithSwitchoverConfig(cfg *SwitchoverConfig) Option { + return func(r *Router) { + r.switchoverConfig = cfg + } +} + func WithPersistedOperationsConfig(cfg config.PersistedOperationsConfig) Option { return func(r *Router) { r.persistedOperationsConfig = cfg diff --git a/router/core/router_config.go b/router/core/router_config.go index 319216a18a..57a4821d8c 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -317,6 +317,7 @@ func (c *Config) Usage() map[string]any { } else { usage["cache_warmup_source"] = "cdn" } + usage["cache_warmup_in_memory_fallback_enabled"] = c.cacheWarmup.InMemoryFallback usage["cache_warmup_workers"] = c.cacheWarmup.Workers usage["cache_warmup_items_per_second"] = c.cacheWarmup.ItemsPerSecond usage["cache_warmup_timeout"] = c.cacheWarmup.Timeout.String() diff --git a/router/core/supervisor.go 
b/router/core/supervisor.go index 30e7e0fca8..577bbcb902 100644 --- a/router/core/supervisor.go +++ b/router/core/supervisor.go @@ -30,15 +30,17 @@ type RouterSupervisor struct { // RouterResources is a struct for holding resources used by the router. type RouterResources struct { - Config *config.Config - Logger *zap.Logger + Config *config.Config + Logger *zap.Logger + SwitchoverConfig *SwitchoverConfig } // RouterSupervisorOpts is a struct for configuring the router supervisor. type RouterSupervisorOpts struct { - BaseLogger *zap.Logger - ConfigFactory func() (*config.Config, error) - RouterFactory func(ctx context.Context, res *RouterResources) (*Router, error) + BaseLogger *zap.Logger + ConfigFactory func() (*config.Config, error) + RouterFactory func(ctx context.Context, res *RouterResources) (*Router, error) + SwitchoverConfig *SwitchoverConfig } // NewRouterSupervisor creates a new RouterSupervisor instance. @@ -48,7 +50,8 @@ func NewRouterSupervisor(opts *RouterSupervisorOpts) (*RouterSupervisor, error) logger: opts.BaseLogger.With(zap.String("component", "supervisor")), configFactory: opts.ConfigFactory, resources: &RouterResources{ - Logger: opts.BaseLogger, + Logger: opts.BaseLogger, + SwitchoverConfig: opts.SwitchoverConfig, }, } @@ -154,6 +157,11 @@ func (rs *RouterSupervisor) Start() error { shutdown := <-rs.shutdownChan rs.logger.Debug("Got shutdown signal", zap.Bool("shutdown", shutdown)) + + if !shutdown { + rs.router.switchoverConfig.OnRouterConfigReload() + } + if err := rs.stopRouter(); err != nil { if errors.Is(err, context.DeadlineExceeded) { rs.logger.Warn("Router shutdown deadline exceeded. Consider increasing the shutdown delay") diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 0fafc833e2..a18c843c28 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -50,7 +50,7 @@ func newRouter(ctx context.Context, params RouterResources, additionalOptions .. 
} } - options := optionsFromResources(logger, cfg) + options := optionsFromResources(logger, cfg, params.SwitchoverConfig) options = append(options, additionalOptions...) authenticators, err := setupAuthenticators(ctx, logger, cfg) @@ -181,7 +181,7 @@ func newRouter(ctx context.Context, params RouterResources, additionalOptions .. return NewRouter(options...) } -func optionsFromResources(logger *zap.Logger, config *config.Config) []Option { +func optionsFromResources(logger *zap.Logger, config *config.Config, switchoverConfig *SwitchoverConfig) []Option { options := []Option{ WithListenerAddr(config.ListenAddr), WithOverrideRoutingURL(config.OverrideRoutingURL), @@ -272,6 +272,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config) []Option { WithPlugins(config.Plugins), WithDemoMode(config.DemoMode), WithStreamsHandlerConfiguration(config.Events.Handlers), + WithSwitchoverConfig(switchoverConfig), } return options diff --git a/router/go.mod b/router/go.mod index 415918d050..9e97513a6c 100644 --- a/router/go.mod +++ b/router/go.mod @@ -63,7 +63,7 @@ require ( github.com/alicebob/miniredis/v2 v2.34.0 github.com/caarlos0/env/v11 v11.3.1 github.com/cep21/circuit/v4 v4.0.0 - github.com/dgraph-io/ristretto/v2 v2.1.0 + github.com/dgraph-io/ristretto/v2 v2.4.0 github.com/expr-lang/expr v1.17.7 github.com/goccy/go-json v0.10.3 github.com/google/go-containerregistry v0.20.3 diff --git a/router/go.sum b/router/go.sum index 715b552f47..9a15517ef1 100644 --- a/router/go.sum +++ b/router/go.sum @@ -54,10 +54,10 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/ristretto/v2 v2.1.0 h1:59LjpOJLNDULHh8MC4UaegN52lC4JnO2dITsie/Pa8I= 
-github.com/dgraph-io/ristretto/v2 v2.1.0/go.mod h1:uejeqfYXpUomfse0+lO+13ATz4TypQYLJZzBSAemuB4= -github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= -github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dgraph-io/ristretto/v2 v2.4.0 h1:I/w09yLjhdcVD2QV192UJcq8dPBaAJb9pOuMyNy0XlU= +github.com/dgraph-io/ristretto/v2 v2.4.0/go.mod h1:0KsrXtXvnv0EqnzyowllbVJB8yBonswa2lTCK2gGo9E= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 8eb71bc5f1..51d74822d1 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -973,20 +973,24 @@ type ClientHeader struct { type CacheWarmupSource struct { Filesystem *CacheWarmupFileSystemSource `yaml:"filesystem,omitempty"` + CdnSource CacheWarmupCDNSource `yaml:"cdn,omitempty"` } type CacheWarmupFileSystemSource struct { Path string `yaml:"path" env:"CACHE_WARMUP_SOURCE_FILESYSTEM_PATH"` } -type CacheWarmupCDNSource struct{} +type CacheWarmupCDNSource struct { + Enabled bool `yaml:"enabled" envDefault:"true" env:"CACHE_WARMUP_SOURCE_CDN_ENABLED"` +} type CacheWarmupConfiguration struct { - Enabled bool `yaml:"enabled" envDefault:"false" env:"CACHE_WARMUP_ENABLED"` - Source CacheWarmupSource `yaml:"source" env:"CACHE_WARMUP_SOURCE"` - Workers int `yaml:"workers" envDefault:"8" env:"CACHE_WARMUP_WORKERS"` - ItemsPerSecond 
int `yaml:"items_per_second" envDefault:"50" env:"CACHE_WARMUP_ITEMS_PER_SECOND"` - Timeout time.Duration `yaml:"timeout" envDefault:"30s" env:"CACHE_WARMUP_TIMEOUT"` + Enabled bool `yaml:"enabled" envDefault:"false" env:"CACHE_WARMUP_ENABLED"` + Source CacheWarmupSource `yaml:"source" env:"CACHE_WARMUP_SOURCE"` + Workers int `yaml:"workers" envDefault:"8" env:"CACHE_WARMUP_WORKERS"` + ItemsPerSecond int `yaml:"items_per_second" envDefault:"50" env:"CACHE_WARMUP_ITEMS_PER_SECOND"` + Timeout time.Duration `yaml:"timeout" envDefault:"30s" env:"CACHE_WARMUP_TIMEOUT"` + InMemoryFallback bool `yaml:"in_memory_fallback" envDefault:"true" env:"CACHE_WARMUP_IN_MEMORY_FALLBACK"` } type MCPConfiguration struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index a531fa4af3..fe09b965a9 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2405,11 +2405,26 @@ "format": "file-path" } } + }, + "cdn": { + "type": "object", + "description": "The CDN source of the cache warmup items.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable or disable the CDN source for cache warmup.", + "default": true + } + } } }, "oneOf": [ { "required": ["filesystem"] + }, + { + "required": ["cdn"] } ] }, @@ -2430,6 +2445,11 @@ "duration": { "minimum": "1s" } + }, + "in_memory_fallback": { + "type": "boolean", + "description": "Enable in-memory fallback. When enabled, the router will reuse the cached query plans in memory and use it to rewarm the cache on schema changes and hot config reloads. 
The default value is true.", + "default": true } } }, diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index b4ddad685e..72b65b81b5 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -312,11 +312,15 @@ "CacheWarmup": { "Enabled": false, "Source": { - "Filesystem": null + "Filesystem": null, + "CdnSource": { + "Enabled": true + } }, "Workers": 8, "ItemsPerSecond": 50, - "Timeout": 30000000000 + "Timeout": 30000000000, + "InMemoryFallback": true }, "RouterConfigPath": "", "RouterRegistration": true, diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index d4707aa1a8..815fb5c6e5 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -661,11 +661,15 @@ "CacheWarmup": { "Enabled": false, "Source": { - "Filesystem": null + "Filesystem": null, + "CdnSource": { + "Enabled": true + } }, "Workers": 8, "ItemsPerSecond": 50, - "Timeout": 30000000000 + "Timeout": 30000000000, + "InMemoryFallback": true }, "RouterConfigPath": "latest.json", "RouterRegistration": true,