148 changes: 148 additions & 0 deletions router-tests/normalization_cache_test.go
@@ -1,14 +1,162 @@
package integration

import (
    "fmt"
    "testing"

    "github.com/stretchr/testify/require"

    "github.com/wundergraph/cosmo/router-tests/testenv"
    "github.com/wundergraph/cosmo/router/core"
    "github.com/wundergraph/cosmo/router/pkg/config"
)

// cacheHit represents the expected cache hit/miss status for all three normalization stages.
// True values mean the cache hit.
type cacheHit struct {
    normalization bool
    variables bool
    remapping bool
}

// assertCacheHeaders checks all three normalization cache headers
func assertCacheHeaders(t *testing.T, res *testenv.TestResponse, expected cacheHit) {
    t.Helper()
    s := func(hit bool) string {
        if hit {
            return "HIT"
        }
        return "MISS"
    }

    require.Equal(t, s(expected.normalization), res.Response.Header.Get(core.NormalizationCacheHeader),
        "Normalization cache hit mismatch")
    require.Equal(t, s(expected.variables), res.Response.Header.Get(core.VariablesNormalizationCacheHeader),
        "Variables normalization cache hit mismatch")
    require.Equal(t, s(expected.remapping), res.Response.Header.Get(core.VariablesRemappingCacheHeader),
        "Variables remapping cache hit mismatch")
}

func TestAdditionalNormalizationCaches(t *testing.T) {
    t.Parallel()

    t.Run("Basic normalization cache with skip/include", func(t *testing.T) {
        t.Parallel()
        testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
            f := func(expected cacheHit, skipMouse bool) {
                res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                    OperationName: []byte(`"Employee"`),
                    Query: `query Employee( $id: Int! = 4 $withAligators: Boolean! $withCats: Boolean! $skipDogs:Boolean! $skipMouses:Boolean! ) { employee(id: $id) { details { pets { name __typename ...AlligatorFields @include(if: $withAligators) ...CatFields @include(if: $withCats) ...DogFields @skip(if: $skipDogs) ...MouseFields @skip(if: $skipMouses) ...PonyFields @include(if: false) } } } } fragment AlligatorFields on Alligator { __typename class dangerous gender name } fragment CatFields on Cat { __typename class gender name type } fragment DogFields on Dog { __typename breed class gender name } fragment MouseFields on Mouse { __typename class gender name } fragment PonyFields on Pony { __typename class gender name }`,
                    Variables: []byte(fmt.Sprintf(`{"withAligators": true,"withCats": true,"skipDogs": false,"skipMouses": %t}`, skipMouse)),
                })
                assertCacheHeaders(t, res, expected)
                require.Equal(t, `{"data":{"employee":{"details":{"pets":[{"name":"Abby","__typename":"Dog","breed":"GOLDEN_RETRIEVER","class":"MAMMAL","gender":"FEMALE"},{"name":"Survivor","__typename":"Pony"}]}}}}`, res.Body)
            }

            // First request: all caches miss
            f(cacheHit{false, false, false}, true)
            // Second request: all caches hit
            f(cacheHit{true, true, true}, true)
            // Third request: all caches hit
            f(cacheHit{true, true, true}, true)
            // Fourth request: different skip/include value, all caches miss
            f(cacheHit{false, false, false}, false)
            // Fifth request: back to original skip/include value, all caches hit
            f(cacheHit{true, true, true}, true)
        })
    })

    t.Run("Variables normalization cache - inline value extraction", func(t *testing.T) {
        t.Parallel()
        testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
            // Test 1: Inline value gets extracted to variable - all caches miss
            res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query { employee(id: 1) { id details { forename } } }`,
            })
            assertCacheHeaders(t, res, cacheHit{false, false, false})

            // Test 2: Same query - all caches hit
            res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query { employee(id: 1) { id details { forename } } }`,
            })
            assertCacheHeaders(t, res, cacheHit{true, true, true})

            // Test 3: Different inline value - only the variables remapping cache hits
            res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query { employee(id: 2) { id details { forename } } }`,
            })
            assertCacheHeaders(t, res, cacheHit{false, false, true})
        })
    })

    t.Run("Variables normalization cache - query changes, but variables stay the same", func(t *testing.T) {
        t.Parallel()
        testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
            // First query - all caches miss
            res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query MyQuery($id: Int!) { employee(id: $id) { id } }`,
                Variables: []byte(`{"id": 1}`),
            })
            require.Equal(t, `{"data":{"employee":{"id":1}}}`, res.Body)
            assertCacheHeaders(t, res, cacheHit{false, false, false})

            // Different query with the same variable value - all caches miss again.
            res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query MyQuery($id: Int!) { employee(id: $id) { id details { forename }} }`,
                Variables: []byte(`{"id": 1}`),
            })
            require.Equal(t, `{"data":{"employee":{"id":1,"details":{"forename":"Jens"}}}}`, res.Body)
            assertCacheHeaders(t, res, cacheHit{false, false, false})
        })
    })

    t.Run("Cache key isolation - different operations don't collide", func(t *testing.T) {
        testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
            // Test 1: Query A
            res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query A($id: Int!) { employee(id: $id) { id } }`,
                Variables: []byte(`{"id": 1}`),
            })
            assertCacheHeaders(t, res, cacheHit{false, false, false})

            // Test 2: Query B with different structure should miss
            res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query B($id: Int!) { employee(id: $id) { id details { forename } } }`,
                Variables: []byte(`{"id": 1}`),
            })
            assertCacheHeaders(t, res, cacheHit{false, false, false})

            // Test 3: Query A again should hit its own cache
            res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query A($id: Int!) { employee(id: $id) { id } }`,
                Variables: []byte(`{"id": 1}`),
            })
            assertCacheHeaders(t, res, cacheHit{true, true, true})
        })
    })

    t.Run("List coercion with variables normalization cache", func(t *testing.T) {
        testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
            // Test that list coercion works correctly with caching
            res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query MyQuery($arg: [String!]!) { rootFieldWithListArg(arg: $arg) }`,
                Variables: []byte(`{"arg": "single"}`),
            })
            require.Equal(t, `{"data":{"rootFieldWithListArg":["single"]}}`, res.Body)
            assertCacheHeaders(t, res, cacheHit{false, false, false})

            // Same query with a different list value: normalization and remapping caches hit, variables normalization misses
            res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
                Query: `query MyQuery($arg: [String!]!) { rootFieldWithListArg(arg: $arg) }`,
                Variables: []byte(`{"arg": "different"}`),
            })
            require.Equal(t, `{"data":{"rootFieldWithListArg":["different"]}}`, res.Body)
            assertCacheHeaders(t, res, cacheHit{true, false, true})
        })
    })

}

func TestNormalizationCache(t *testing.T) {
    t.Parallel()

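Reviewer note: the cacheHit/assertCacheHeaders helpers above make additional coverage cheap to add. Below is a minimal sketch of a further subtest (it would live inside TestAdditionalNormalizationCaches); the subtest name, query, and expected hit pattern are illustrative assumptions, not part of this PR:

```go
t.Run("Illustrative: repeated identical request hits all caches", func(t *testing.T) {
    t.Parallel()
    testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
        req := testenv.GraphQLRequest{
            Query:     `query MyQuery($id: Int!) { employee(id: $id) { id } }`,
            Variables: []byte(`{"id": 1}`),
        }

        // First request against a fresh environment misses all three caches and populates them.
        res := xEnv.MakeGraphQLRequestOK(req)
        assertCacheHeaders(t, res, cacheHit{false, false, false})

        // An identical second request should hit the normalization, variables
        // normalization, and variables remapping caches.
        res = xEnv.MakeGraphQLRequestOK(req)
        assertCacheHeaders(t, res, cacheHit{true, true, true})
    })
})
```

The pattern mirrors the existing subtests: each testenv.Run starts its own test environment, so the first request in a subtest is always a miss and only identical follow-up requests hit.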
4 changes: 2 additions & 2 deletions router/core/cache_warmup.go
@@ -295,12 +295,12 @@ func (c *CacheWarmupPlanningProcessor) ProcessOperation(ctx context.Context, ope
        return nil, err
    }

    _, err = k.NormalizeVariables()
    _, _, err = k.NormalizeVariables()
    if err != nil {
        return nil, err
    }

    err = k.RemapVariables(c.disableVariablesRemapping)
    _, err = k.RemapVariables(c.disableVariablesRemapping)
    if err != nil {
        return nil, err
    }
6 changes: 4 additions & 2 deletions router/core/context.go
@@ -540,8 +540,10 @@ type operationContext struct {
    sha256Hash string
    protocol OperationProtocol

    persistedOperationCacheHit bool
    normalizationCacheHit bool
    persistedOperationCacheHit bool
    normalizationCacheHit bool
    variablesNormalizationCacheHit bool
    variablesRemappingCacheHit bool

    typeFieldUsageInfo graphqlschemausage.TypeFieldMetrics
    argumentUsageInfo []*graphqlmetrics.ArgumentUsageInfo
87 changes: 51 additions & 36 deletions router/core/graph_server.go
@@ -506,18 +506,22 @@ func (s *graphServer) setupEngineStatistics(baseAttributes []attribute.KeyValue)
}

type graphMux struct {
    mux *chi.Mux
    planCache *ristretto.Cache[uint64, *planWithMetaData]
    persistedOperationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
    normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
    complexityCalculationCache *ristretto.Cache[uint64, ComplexityCacheEntry]
    validationCache *ristretto.Cache[uint64, bool]
    operationHashCache *ristretto.Cache[uint64, string]
    accessLogsFileLogger *logging.BufferedLogger
    metricStore rmetric.Store
    prometheusCacheMetrics *rmetric.CacheMetrics
    otelCacheMetrics *rmetric.CacheMetrics
    streamMetricStore rmetric.StreamMetricStore
    mux *chi.Mux

    planCache *ristretto.Cache[uint64, *planWithMetaData]
    persistedOperationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
    normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
    complexityCalculationCache *ristretto.Cache[uint64, ComplexityCacheEntry]
    variablesNormalizationCache *ristretto.Cache[uint64, VariablesNormalizationCacheEntry]
    remapVariablesCache *ristretto.Cache[uint64, RemapVariablesCacheEntry]
    validationCache *ristretto.Cache[uint64, bool]
    operationHashCache *ristretto.Cache[uint64, string]

    accessLogsFileLogger *logging.BufferedLogger
    metricStore rmetric.Store
    prometheusCacheMetrics *rmetric.CacheMetrics
    otelCacheMetrics *rmetric.CacheMetrics
    streamMetricStore rmetric.StreamMetricStore
}

// buildOperationCaches creates the caches for the graph mux.
@@ -572,6 +576,30 @@ func (s *graphMux) buildOperationCaches(srv *graphServer) (computeSha256 bool, e
        if err != nil {
            return computeSha256, fmt.Errorf("failed to create normalization cache: %w", err)
        }

        variablesNormalizationCacheConfig := &ristretto.Config[uint64, VariablesNormalizationCacheEntry]{
            Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
            MaxCost: srv.engineExecutionConfiguration.NormalizationCacheSize,
            NumCounters: srv.engineExecutionConfiguration.NormalizationCacheSize * 10,
            IgnoreInternalCost: true,
            BufferItems: 64,
        }
        s.variablesNormalizationCache, err = ristretto.NewCache[uint64, VariablesNormalizationCacheEntry](variablesNormalizationCacheConfig)
        if err != nil {
            return computeSha256, fmt.Errorf("failed to create variables normalization cache: %w", err)
        }

        remapVariablesCacheConfig := &ristretto.Config[uint64, RemapVariablesCacheEntry]{
            Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
            MaxCost: srv.engineExecutionConfiguration.NormalizationCacheSize,
            NumCounters: srv.engineExecutionConfiguration.NormalizationCacheSize * 10,
            IgnoreInternalCost: true,
            BufferItems: 64,
        }
        s.remapVariablesCache, err = ristretto.NewCache[uint64, RemapVariablesCacheEntry](remapVariablesCacheConfig)
        if err != nil {
            return computeSha256, fmt.Errorf("failed to create variables remapping cache: %w", err)
        }
    }

    if srv.engineExecutionConfiguration.EnableValidationCache && srv.engineExecutionConfiguration.ValidationCacheSize > 0 {
@@ -709,31 +737,16 @@ func (s *graphMux) configureCacheMetrics(srv *graphServer, baseOtelAttributes []
}

func (s *graphMux) Shutdown(ctx context.Context) error {
    var err error

    if s.planCache != nil {
        s.planCache.Close()
    }
    s.planCache.Close()
    s.persistedOperationCache.Close()
    s.normalizationCache.Close()
    s.variablesNormalizationCache.Close()
    s.remapVariablesCache.Close()
    s.complexityCalculationCache.Close()
    s.validationCache.Close()
    s.operationHashCache.Close()

    if s.persistedOperationCache != nil {
        s.persistedOperationCache.Close()
    }

    if s.normalizationCache != nil {
        s.normalizationCache.Close()
    }

    if s.complexityCalculationCache != nil {
        s.complexityCalculationCache.Close()
    }

    if s.validationCache != nil {
        s.validationCache.Close()
    }

    if s.operationHashCache != nil {
        s.operationHashCache.Close()
    }
    var err error

    if s.accessLogsFileLogger != nil {
        if aErr := s.accessLogsFileLogger.Close(); aErr != nil {
@@ -1227,6 +1240,8 @@ func (s *graphServer) buildGraphMux(
        ValidationCache: gm.validationCache,
        QueryDepthCache: gm.complexityCalculationCache,
        OperationHashCache: gm.operationHashCache,
        VariablesNormalizationCache: gm.variablesNormalizationCache,
        RemapVariablesCache: gm.remapVariablesCache,
        ParseKitPoolSize: s.engineExecutionConfiguration.ParseKitPoolSize,
        IntrospectionEnabled: s.Config.introspection,
        ParserTokenizerLimits: astparser.TokenizerLimits{
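For context on the two new graphMux caches, here is a self-contained sketch of the ristretto lifecycle (NewCache, Set, Wait, Get, Close) using a config shaped like variablesNormalizationCacheConfig above. The import path assumes ristretto's v2 generic API as used in this file; the uint64-keyed string value and the size constant are stand-ins, since the router's real entry types (VariablesNormalizationCacheEntry, RemapVariablesCacheEntry) and key derivation are not shown in this diff:

```go
package main

import (
    "fmt"
    "log"

    "github.com/dgraph-io/ristretto/v2" // assumed: the generic (v2) ristretto API
)

func main() {
    const cacheSize = 1024 // stand-in for NormalizationCacheSize

    cache, err := ristretto.NewCache[uint64, string](&ristretto.Config[uint64, string]{
        MaxCost:            cacheSize,      // with IgnoreInternalCost and cost 1 per entry, this caps the entry count
        NumCounters:        cacheSize * 10, // ristretto recommends roughly 10x counters per expected entry
        IgnoreInternalCost: true,
        BufferItems:        64,
    })
    if err != nil {
        log.Fatalf("failed to create cache: %v", err)
    }
    defer cache.Close()

    // Hypothetical key; the router derives its keys from the operation, which is not shown here.
    var key uint64 = 12345

    cache.Set(key, "normalized variables entry", 1)
    cache.Wait() // Set is buffered/asynchronous; Wait flushes so the next Get can observe it

    if entry, ok := cache.Get(key); ok {
        fmt.Println("cache hit:", entry)
    } else {
        fmt.Println("cache miss")
    }
}
```

Because Set is buffered, the sketch calls Wait before reading back; for a cache, an occasionally dropped or delayed Set is an acceptable trade-off.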
35 changes: 17 additions & 18 deletions router/core/graphql_handler.go
@@ -34,9 +34,11 @@ var (
)

const (
    ExecutionPlanCacheHeader = "X-WG-Execution-Plan-Cache"
    PersistedOperationCacheHeader = "X-WG-Persisted-Operation-Cache"
    NormalizationCacheHeader = "X-WG-Normalization-Cache"
    ExecutionPlanCacheHeader = "X-WG-Execution-Plan-Cache"
    PersistedOperationCacheHeader = "X-WG-Persisted-Operation-Cache"
    NormalizationCacheHeader = "X-WG-Normalization-Cache"
    VariablesNormalizationCacheHeader = "X-WG-Variables-Normalization-Cache"
    VariablesRemappingCacheHeader = "X-WG-Variables-Remapping-Cache"
)

type ReportError interface {
@@ -428,25 +430,22 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv
}

func (h *GraphQLHandler) setDebugCacheHeaders(w http.ResponseWriter, opCtx *operationContext) {
    if h.enableNormalizationCacheResponseHeader {
        if opCtx.normalizationCacheHit {
            w.Header().Set(NormalizationCacheHeader, "HIT")
        } else {
            w.Header().Set(NormalizationCacheHeader, "MISS")
    s := func(hit bool) string {
        if hit {
            return "HIT"
        }
        return "MISS"
    }

    if h.enableNormalizationCacheResponseHeader {
        w.Header().Set(NormalizationCacheHeader, s(opCtx.normalizationCacheHit))
        w.Header().Set(VariablesNormalizationCacheHeader, s(opCtx.variablesNormalizationCacheHit))
        w.Header().Set(VariablesRemappingCacheHeader, s(opCtx.variablesRemappingCacheHit))
    }
    if h.enablePersistedOperationCacheResponseHeader {
        if opCtx.persistedOperationCacheHit {
            w.Header().Set(PersistedOperationCacheHeader, "HIT")
        } else {
            w.Header().Set(PersistedOperationCacheHeader, "MISS")
        }
        w.Header().Set(PersistedOperationCacheHeader, s(opCtx.persistedOperationCacheHit))
    }
    if h.enableExecutionPlanCacheResponseHeader {
        if opCtx.planCacheHit {
            w.Header().Set(ExecutionPlanCacheHeader, "HIT")
        } else {
            w.Header().Set(ExecutionPlanCacheHeader, "MISS")
        }
        w.Header().Set(ExecutionPlanCacheHeader, s(opCtx.planCacheHit))
    }
}
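To observe the new headers end to end, a client only needs to read the response headers of an ordinary GraphQL request while the router runs with the debug cache-header option that drives enableNormalizationCacheResponseHeader enabled (the config surface for that flag is outside this diff). A rough sketch with a placeholder router URL:

```go
package main

import (
    "bytes"
    "fmt"
    "log"
    "net/http"
)

func main() {
    // Placeholder endpoint; adjust to wherever the router is listening.
    const routerURL = "http://localhost:3002/graphql"

    body := []byte(`{"query":"query { employee(id: 1) { id } }"}`)
    res, err := http.Post(routerURL, "application/json", bytes.NewReader(body))
    if err != nil {
        log.Fatalf("request failed: %v", err)
    }
    defer res.Body.Close()

    // The three headers written by setDebugCacheHeaders for the normalization stages.
    for _, h := range []string{
        "X-WG-Normalization-Cache",
        "X-WG-Variables-Normalization-Cache",
        "X-WG-Variables-Remapping-Cache",
    } {
        fmt.Printf("%s: %s\n", h, res.Header.Get(h))
    }
}
```

Each header carries HIT or MISS, which is exactly what assertCacheHeaders checks in the router tests.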