24 changes: 24 additions & 0 deletions router-tests/normalization_cache_test.go
@@ -4,11 +4,35 @@ import (
"testing"

"github.com/stretchr/testify/require"

"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/core"
"github.com/wundergraph/cosmo/router/pkg/config"
)

func TestAdditionalNormalizationCaches(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
f := func(v string) {
res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
OperationName: []byte(`"Employee"`),
Query: `query Employee( $id: Int! = 4 $withAlligators: Boolean! $withCats: Boolean! $skipDogs:Boolean! $skipMouses:Boolean! ) { employee(id: $id) { details { pets { name __typename ...AlligatorFields @include(if: $withAlligators) ...CatFields @include(if: $withCats) ...DogFields @skip(if: $skipDogs) ...MouseFields @skip(if: $skipMouses) ...PonyFields @include(if: false) } } } } fragment AlligatorFields on Alligator { __typename class dangerous gender name } fragment CatFields on Cat { __typename class gender name type } fragment DogFields on Dog { __typename breed class gender name } fragment MouseFields on Mouse { __typename class gender name } fragment PonyFields on Pony { __typename class gender name }`,
Variables: []byte(`{"withAlligators": true,"withCats": true,"skipDogs": false,"skipMouses": true}`),
})
require.NoError(t, err)
require.Equal(t, v, res.Response.Header.Get(core.NormalizationCacheHeader))
require.Equal(t, `{"data":{"employee":{"details":{"pets":[{"name":"Abby","__typename":"Dog","breed":"GOLDEN_RETRIEVER","class":"MAMMAL","gender":"FEMALE"},{"name":"Survivor","__typename":"Pony"}]}}}}`, res.Body)
}

f("MISS")
f("HIT")
f("HIT")
f("HIT")
f("HIT")
})
}
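
For context on the MISS/HIT sequence asserted above: ristretto applies Set calls through an internal buffer, so a first lookup misses and, once the write is applied, subsequent lookups hit. A minimal sketch of that lifecycle, assuming the generic dgraph-io/ristretto v2 API implied by the type parameters in graph_server.go (the key, value, and sizes below are illustrative):

package main

import (
	"fmt"

	"github.com/dgraph-io/ristretto/v2"
)

func main() {
	// Same shape as the configs in buildOperationCaches: cost-based sizing,
	// ten counters per expected entry, and a small Set buffer.
	cache, err := ristretto.NewCache(&ristretto.Config[uint64, string]{
		NumCounters:        1024 * 10,
		MaxCost:            1024,
		BufferItems:        64,
		IgnoreInternalCost: true,
	})
	if err != nil {
		panic(err)
	}
	defer cache.Close()

	if _, ok := cache.Get(42); !ok {
		fmt.Println("MISS") // first request: nothing cached yet
	}
	cache.Set(42, "normalized operation", 1)
	cache.Wait() // Set is buffered; Wait applies pending writes
	if v, ok := cache.Get(42); ok {
		fmt.Println("HIT:", v) // subsequent requests reuse the entry
	}
}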

func TestNormalizationCache(t *testing.T) {
t.Parallel()

61 changes: 49 additions & 12 deletions router/core/graph_server.go
@@ -506,18 +506,21 @@ func (s *graphServer) setupEngineStatistics(baseAttributes []attribute.KeyValue)
}

type graphMux struct {
mux *chi.Mux
planCache *ristretto.Cache[uint64, *planWithMetaData]
persistedOperationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
complexityCalculationCache *ristretto.Cache[uint64, ComplexityCacheEntry]
validationCache *ristretto.Cache[uint64, bool]
operationHashCache *ristretto.Cache[uint64, string]
accessLogsFileLogger *logging.BufferedLogger
metricStore rmetric.Store
prometheusCacheMetrics *rmetric.CacheMetrics
otelCacheMetrics *rmetric.CacheMetrics
streamMetricStore rmetric.StreamMetricStore
mux *chi.Mux
planCache *ristretto.Cache[uint64, *planWithMetaData]
persistedOperationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
complexityCalculationCache *ristretto.Cache[uint64, ComplexityCacheEntry]
variablesNormalizationCache *ristretto.Cache[uint64, VariablesNormalizationCacheEntry]
remapVariablesCache *ristretto.Cache[uint64, RemapVariablesCacheEntry]

validationCache *ristretto.Cache[uint64, bool]
operationHashCache *ristretto.Cache[uint64, string]
accessLogsFileLogger *logging.BufferedLogger
metricStore rmetric.Store
prometheusCacheMetrics *rmetric.CacheMetrics
otelCacheMetrics *rmetric.CacheMetrics
streamMetricStore rmetric.StreamMetricStore
}

// buildOperationCaches creates the caches for the graph mux.
Expand Down Expand Up @@ -572,6 +575,30 @@ func (s *graphMux) buildOperationCaches(srv *graphServer) (computeSha256 bool, e
if err != nil {
return computeSha256, fmt.Errorf("failed to create normalization cache: %w", err)
}

variablesNormalizationCacheConfig := &ristretto.Config[uint64, VariablesNormalizationCacheEntry]{
Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
MaxCost: srv.engineExecutionConfiguration.NormalizationCacheSize,
NumCounters: srv.engineExecutionConfiguration.NormalizationCacheSize * 10,
IgnoreInternalCost: true,
BufferItems: 64,
}
s.variablesNormalizationCache, err = ristretto.NewCache[uint64, VariablesNormalizationCacheEntry](variablesNormalizationCacheConfig)
if err != nil {
return computeSha256, fmt.Errorf("failed to create variables normalization cache: %w", err)
}

remapVariablesCacheConfig := &ristretto.Config[uint64, RemapVariablesCacheEntry]{
Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
MaxCost: srv.engineExecutionConfiguration.NormalizationCacheSize,
NumCounters: srv.engineExecutionConfiguration.NormalizationCacheSize * 10,
IgnoreInternalCost: true,
BufferItems: 64,
}
s.remapVariablesCache, err = ristretto.NewCache[uint64, RemapVariablesCacheEntry](remapVariablesCacheConfig)
if err != nil {
return computeSha256, fmt.Errorf("failed to create remap variables cache: %w", err)
}
}
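
Side note: the cache constructions in this function differ only in the value type and the size knob they read; a hypothetical generic helper (newOperationCache is not part of this PR) could collapse the repetition:

// newOperationCache deduplicates the repeated ristretto setup above. size is
// used for MaxCost and, times ten, for NumCounters; metrics mirrors the
// OTEL/Prometheus GraphqlCache toggle.
func newOperationCache[V any](size int64, metrics bool) (*ristretto.Cache[uint64, V], error) {
	return ristretto.NewCache(&ristretto.Config[uint64, V]{
		Metrics:            metrics,
		MaxCost:            size,
		NumCounters:        size * 10,
		IgnoreInternalCost: true,
		BufferItems:        64,
	})
}

Each call site above would then reduce to a one-liner, for example s.remapVariablesCache, err = newOperationCache[RemapVariablesCacheEntry](size, metrics).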

if srv.engineExecutionConfiguration.EnableValidationCache && srv.engineExecutionConfiguration.ValidationCacheSize > 0 {
@@ -723,6 +750,14 @@ func (s *graphMux) Shutdown(ctx context.Context) error {
s.normalizationCache.Close()
}

if s.variablesNormalizationCache != nil {
s.variablesNormalizationCache.Close()
}

if s.remapVariablesCache != nil {
s.remapVariablesCache.Close()
}

if s.complexityCalculationCache != nil {
s.complexityCalculationCache.Close()
}
@@ -1226,6 +1261,8 @@ func (s *graphServer) buildGraphMux(
ValidationCache: gm.validationCache,
QueryDepthCache: gm.complexityCalculationCache,
OperationHashCache: gm.operationHashCache,
VariablesNormalizationCache: gm.variablesNormalizationCache,
RemapVariablesCache: gm.remapVariablesCache,
ParseKitPoolSize: s.engineExecutionConfiguration.ParseKitPoolSize,
IntrospectionEnabled: s.Config.introspection,
ParserTokenizerLimits: astparser.TokenizerLimits{
150 changes: 133 additions & 17 deletions router/core/operation_processor.go
@@ -107,6 +107,8 @@ type OperationProcessorOptions struct {
PersistedOpsNormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
NormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
QueryDepthCache *ristretto.Cache[uint64, ComplexityCacheEntry]
VariablesNormalizationCache *ristretto.Cache[uint64, VariablesNormalizationCacheEntry]
RemapVariablesCache *ristretto.Cache[uint64, RemapVariablesCacheEntry]
ValidationCache *ristretto.Cache[uint64, bool]
OperationHashCache *ristretto.Cache[uint64, string]
ParseKitPoolSize int
@@ -160,6 +162,8 @@ type OperationCache struct {

persistedOperationNormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
variablesNormalizationCache *ristretto.Cache[uint64, VariablesNormalizationCacheEntry]
remapVariablesCache *ristretto.Cache[uint64, RemapVariablesCacheEntry]
complexityCache *ristretto.Cache[uint64, ComplexityCacheEntry]
validationCache *ristretto.Cache[uint64, bool]
operationHashCache *ristretto.Cache[uint64, string]
@@ -749,12 +753,25 @@ func (o *OperationKit) normalizePersistedOperation(clientName string, isApq bool
}

type NormalizationCacheEntry struct {
operationID uint64
normalizedRepresentation string
operationType string
operationDefinitionRef int
}

type VariablesNormalizationCacheEntry struct {
normalizedRepresentation string
uploadsMapping []uploads.UploadPathMapping
id uint64
variables json.RawMessage
// reparse signals that normalization changed the operation body, so the
// document must be re-parsed from normalizedRepresentation on a cache hit.
reparse bool
}

type RemapVariablesCacheEntry struct {
internalID uint64
normalizedRepresentation string
remapVariables map[string]string
}

type ComplexityCacheEntry struct {
Depth int
TotalFields int
@@ -772,6 +789,13 @@ func (o *OperationKit) normalizeNonPersistedOperation() (cached bool, err error)
o.parsedOperation.NormalizedRepresentation = entry.normalizedRepresentation
o.parsedOperation.Type = entry.operationType
o.parsedOperation.NormalizationCacheHit = true

// Remove skip/include variables from the request variables: normalization
// strips them from the operation, but they are still present when the
// operation is served from the cache.
for _, varName := range skipIncludeVariableNames {
o.parsedOperation.Request.Variables = jsonparser.Delete(o.parsedOperation.Request.Variables, varName)
}

err = o.setAndParseOperationDoc()
if err != nil {
return false, err
@@ -790,15 +814,9 @@ func (o *OperationKit) normalizeNonPersistedOperation() (cached bool, err error)
}
}

// reset with the original variables
// set variables to the normalized variables, as skip/include variables are removed during normalization
o.parsedOperation.Request.Variables = o.kit.doc.Input.Variables

// Hash the normalized operation with the static operation name & original variables to avoid different IDs for the same operation
err = o.kit.printer.Print(o.kit.doc, o.kit.keyGen)
if err != nil {
return false, errors.WithStack(fmt.Errorf("normalizeNonPersistedOperation (uncached) failed generating operation hash: %w", err))
}

// Print the operation with the original operation name
o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = o.originalOperationNameRef
err = o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation)
@@ -811,7 +829,6 @@ func (o *OperationKit) normalizeNonPersistedOperation() (cached bool, err error)

if o.cache != nil && o.cache.normalizationCache != nil {
entry := NormalizationCacheEntry{
operationID: o.parsedOperation.InternalID,
normalizedRepresentation: o.parsedOperation.NormalizedRepresentation,
operationType: o.parsedOperation.Type,
}
Expand Down Expand Up @@ -840,7 +857,43 @@ func (o *OperationKit) setAndParseOperationDoc() error {
return nil
}
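
As an aside, the skip/include variable cleanup in the cache-hit path above uses buger/jsonparser's Delete, which returns the buffer with the key removed. A minimal sketch (the variable name is illustrative):

package main

import (
	"fmt"

	"github.com/buger/jsonparser"
)

func main() {
	variables := []byte(`{"withCats":true,"id":4}`)
	// "withCats" only feeds an @include directive; normalization inlines it,
	// so the router drops it from the variables forwarded to subgraphs.
	variables = jsonparser.Delete(variables, "withCats")
	fmt.Println(string(variables)) // {"id":4}
}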

func (o *OperationKit) normalizeVariablesCacheKey() uint64 {
// The key covers both the raw variables and the normalized operation, so
// the same operation with different variables yields different entries.
_, _ = o.kit.keyGen.Write(o.kit.doc.Input.Variables)
_, _ = o.kit.keyGen.WriteString(o.parsedOperation.NormalizedRepresentation)
sum := o.kit.keyGen.Sum64()
o.kit.keyGen.Reset()
return sum
}
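
Both cache-key helpers follow the same write/sum/reset pattern, and the TODO on the keyGen.Reset() call further down asks for exactly this: a wrapper that resets automatically. A hypothetical sketch, assuming keyGen is a *xxhash.Digest from github.com/cespare/xxhash/v2 (hashKey is not part of this PR):

// hashKey writes each part into the digest, returns the 64-bit sum, and
// always resets the digest so no caller can observe stale hash state.
func hashKey(h *xxhash.Digest, parts ...[]byte) uint64 {
	defer h.Reset()
	for _, p := range parts {
		_, _ = h.Write(p) // xxhash's Write never returns an error
	}
	return h.Sum64()
}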

func (o *OperationKit) NormalizeVariables() ([]uploads.UploadPathMapping, error) {
cacheKey := o.normalizeVariablesCacheKey()
if o.cache != nil && o.cache.variablesNormalizationCache != nil {
entry, ok := o.cache.variablesNormalizationCache.Get(cacheKey)
if ok {
o.parsedOperation.NormalizedRepresentation = entry.normalizedRepresentation
o.parsedOperation.ID = entry.id
o.parsedOperation.Request.Variables = entry.variables

if entry.reparse {
if err := o.setAndParseOperationDoc(); err != nil {
return nil, err
}
}

return entry.uploadsMapping, nil
}
}

variablesBefore := make([]byte, len(o.kit.doc.Input.Variables))
copy(variablesBefore, o.kit.doc.Input.Variables)

@@ -879,16 +932,28 @@ func (o *OperationKit) NormalizeVariables() ([]uploads.UploadPathMapping, error)
// Reset the doc with the original name
o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = nameRef

o.kit.keyGen.Reset()
o.kit.keyGen.Reset() // TODO: should be unnecessary if every caller resets after producing a key; consider wrapping keyGen in a type that resets itself
_, err = o.kit.keyGen.Write(o.kit.normalizedOperation.Bytes())
if err != nil {
return nil, err
}

o.parsedOperation.ID = o.kit.keyGen.Sum64()
o.kit.keyGen.Reset()

// If the normalized form of the operation didn't change, we don't need to print it again
if bytes.Equal(o.kit.doc.Input.Variables, variablesBefore) && bytes.Equal(o.kit.doc.Input.RawBytes, operationRawBytesBefore) {
if o.cache != nil && o.cache.variablesNormalizationCache != nil {
entry := VariablesNormalizationCacheEntry{
uploadsMapping: uploadsMapping,
id: o.parsedOperation.ID,
normalizedRepresentation: o.parsedOperation.NormalizedRepresentation,
variables: o.parsedOperation.Request.Variables,
reparse: false,
}
o.cache.variablesNormalizationCache.Set(cacheKey, entry, 1)
}

return uploadsMapping, nil
}

@@ -902,10 +967,51 @@ func (o *OperationKit) NormalizeVariables() ([]uploads.UploadPathMapping, error)
o.parsedOperation.NormalizedRepresentation = o.kit.normalizedOperation.String()
o.parsedOperation.Request.Variables = o.kit.doc.Input.Variables

if o.cache != nil && o.cache.variablesNormalizationCache != nil {
entry := VariablesNormalizationCacheEntry{
uploadsMapping: uploadsMapping,
id: o.parsedOperation.ID,
normalizedRepresentation: o.parsedOperation.NormalizedRepresentation,
variables: o.parsedOperation.Request.Variables,
reparse: true,
}
o.cache.variablesNormalizationCache.Set(cacheKey, entry, 1)
}

return uploadsMapping, nil
}
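
To make the two store paths above explicit: when printing leaves the operation and variables byte-identical, the entry is cached with reparse: false and a later hit reuses the already-parsed document; when normalization rewrote them, reparse: true forces setAndParseOperationDoc on the hit. A schematic that collapses the two Set sites (storeVariablesNormalizationEntry and changed are illustrative, not part of this PR):

func (o *OperationKit) storeVariablesNormalizationEntry(cacheKey uint64, mapping []uploads.UploadPathMapping, changed bool) {
	if o.cache == nil || o.cache.variablesNormalizationCache == nil {
		return
	}
	o.cache.variablesNormalizationCache.Set(cacheKey, VariablesNormalizationCacheEntry{
		uploadsMapping:           mapping,
		id:                       o.parsedOperation.ID,
		normalizedRepresentation: o.parsedOperation.NormalizedRepresentation,
		variables:                o.parsedOperation.Request.Variables,
		reparse:                  changed, // a hit must re-parse only if the operation changed
	}, 1)
}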

func (o *OperationKit) remapVariablesCacheKey() uint64 {
// Remapping depends only on the operation shape, so the key is just the
// normalized representation.
_, _ = o.kit.keyGen.WriteString(o.parsedOperation.NormalizedRepresentation)
sum := o.kit.keyGen.Sum64()
o.kit.keyGen.Reset()
return sum
}

func (o *OperationKit) RemapVariables(disabled bool) error {
cacheKey := o.remapVariablesCacheKey()
if o.cache != nil && o.cache.remapVariablesCache != nil {
entry, ok := o.cache.remapVariablesCache.Get(cacheKey)
if ok {
o.parsedOperation.NormalizedRepresentation = entry.normalizedRepresentation
o.parsedOperation.InternalID = entry.internalID
o.parsedOperation.RemapVariables = entry.remapVariables

if err := o.setAndParseOperationDoc(); err != nil {
return err
}

return nil
}
}

report := &operationreport.Report{}

// even if variables remapping is disabled, we still need to execute the rest of the method,
Expand Down Expand Up @@ -955,6 +1061,15 @@ func (o *OperationKit) RemapVariables(disabled bool) error {

o.parsedOperation.NormalizedRepresentation = o.kit.normalizedOperation.String()

if o.cache != nil && o.cache.remapVariablesCache != nil {
entry := RemapVariablesCacheEntry{
internalID: o.parsedOperation.InternalID,
normalizedRepresentation: o.parsedOperation.NormalizedRepresentation,
remapVariables: o.parsedOperation.RemapVariables,
}
o.cache.remapVariablesCache.Set(cacheKey, entry, 1)
}

return nil
}

Expand Down Expand Up @@ -995,7 +1110,6 @@ func (o *OperationKit) handleFoundPersistedOperationEntry(entry NormalizationCac
// as we skip parse for the cached persisted operations
o.parsedOperation.IsPersistedOperation = true
o.parsedOperation.NormalizationCacheHit = true
o.parsedOperation.InternalID = entry.operationID
o.parsedOperation.NormalizedRepresentation = entry.normalizedRepresentation
o.parsedOperation.Type = entry.operationType
// We will always only have a single operation definition in the document
@@ -1042,7 +1156,6 @@ func (o *OperationKit) persistedOperationCacheKeyHasTtl(clientName string, inclu
func (o *OperationKit) savePersistedOperationToCache(clientName string, isApq bool, skipIncludeVariableNames []string) {
cacheKey := o.generatePersistedOperationCacheKey(clientName, skipIncludeVariableNames, o.kit.numOperations > 1)
entry := NormalizationCacheEntry{
operationID: o.parsedOperation.InternalID,
normalizedRepresentation: o.parsedOperation.NormalizedRepresentation,
operationType: o.parsedOperation.Type,
operationDefinitionRef: o.operationDefinitionRef,
@@ -1317,12 +1430,15 @@ func NewOperationProcessor(opts OperationProcessorOptions) *OperationProcessor {
processor.parseKitSemaphore <- i
processor.parseKits[i] = createParseKit(i, processor.parseKitOptions)
}
if opts.NormalizationCache != nil || opts.ValidationCache != nil || opts.QueryDepthCache != nil || opts.OperationHashCache != nil || opts.EnablePersistedOperationsCache {
if opts.NormalizationCache != nil || opts.ValidationCache != nil || opts.QueryDepthCache != nil || opts.OperationHashCache != nil || opts.EnablePersistedOperationsCache ||
opts.VariablesNormalizationCache != nil || opts.RemapVariablesCache != nil {
processor.operationCache = &OperationCache{
normalizationCache: opts.NormalizationCache,
validationCache: opts.ValidationCache,
complexityCache: opts.QueryDepthCache,
operationHashCache: opts.OperationHashCache,
normalizationCache: opts.NormalizationCache,
validationCache: opts.ValidationCache,
complexityCache: opts.QueryDepthCache,
operationHashCache: opts.OperationHashCache,
variablesNormalizationCache: opts.VariablesNormalizationCache,
remapVariablesCache: opts.RemapVariablesCache,
}
}
if opts.EnablePersistedOperationsCache {