diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index 53b4f2c5e8..8b57b9c11d 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -38,12 +38,6 @@ func newInternalExecutionContext() *internalExecutionContext { } } -func (e *internalExecutionContext) prepare(ctx context.Context, variables []byte, request resolve.Request) { - e.setContext(ctx) - e.setVariables(variables) - e.setRequest(request) -} - func (e *internalExecutionContext) setRequest(request resolve.Request) { e.resolveContext.Request = request } @@ -194,7 +188,10 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques } execContext := newInternalExecutionContext() - execContext.prepare(ctx, operation.Variables, operation.InternalRequest()) + execContext.setContext(ctx) + execContext.setVariables(operation.Variables) + execContext.setRequest(operation.InternalRequest()) + for i := range options { options[i](execContext) } @@ -210,10 +207,13 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques } var report operationreport.Report - cachedPlan := e.getCachedPlan(execContext, operation.Document(), e.config.schema.Document(), operation.OperationName, &report) + cachedPlan, costCalculator := e.getCachedPlan(execContext, operation.Document(), e.config.schema.Document(), operation.OperationName, &report) if report.HasErrors() { return report } + operation.ComputeStaticCost(costCalculator, e.config.plannerConfig, execContext.resolveContext.Variables) + // Debugging of cost trees. Do not remove. + // fmt.Println(costCalculator.DebugPrint(e.config.plannerConfig, execContext.resolveContext.Variables)) if execContext.resolveContext.TracingOptions.Enable && !execContext.resolveContext.TracingOptions.ExcludePlannerStats { planningTime := resolve.GetDurationNanoSinceTraceStart(execContext.resolveContext.Context()) - tracePlanStart @@ -236,33 +236,33 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques } } -func (e *ExecutionEngine) getCachedPlan(ctx *internalExecutionContext, operation, definition *ast.Document, operationName string, report *operationreport.Report) plan.Plan { +func (e *ExecutionEngine) getCachedPlan(ctx *internalExecutionContext, operation, definition *ast.Document, operationName string, report *operationreport.Report) (plan.Plan, *plan.CostCalculator) { hash := pool.Hash64.Get() hash.Reset() defer pool.Hash64.Put(hash) err := astprinter.Print(operation, hash) if err != nil { report.AddInternalError(err) - return nil + return nil, nil } cacheKey := hash.Sum64() if cached, ok := e.executionPlanCache.Get(cacheKey); ok { if p, ok := cached.(plan.Plan); ok { - return p + return p, p.GetStaticCostCalculator() } } planner, _ := plan.NewPlanner(e.config.plannerConfig) planResult := planner.Plan(operation, definition, operationName, report) if report.HasErrors() { - return nil + return nil, nil } ctx.postProcessor.Process(planResult) e.executionPlanCache.Add(cacheKey, planResult) - return planResult + return planResult, planResult.GetStaticCostCalculator() } func (e *ExecutionEngine) GetWebsocketBeforeStartHook() WebsocketBeforeStartHook { diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 35ca1c0963..9ac77cd131 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -201,17 +201,19 @@ func TestWithAdditionalHttpHeaders(t *testing.T) { } type ExecutionEngineTestCase struct { - schema *graphql.Schema - operation func(t *testing.T) graphql.Request - dataSources []plan.DataSource - fields plan.FieldConfigurations - engineOptions []ExecutionOptions + schema *graphql.Schema + operation func(t *testing.T) graphql.Request + dataSources []plan.DataSource + fields plan.FieldConfigurations + engineOptions []ExecutionOptions + customResolveMap map[string]resolve.CustomResolve + skipReason string + indentJSON bool + expectedResponse string expectedJSONResponse string expectedFixture string - customResolveMap map[string]resolve.CustomResolve - skipReason string - indentJSON bool + expectedStaticCost int } type _executionTestOptions struct { @@ -220,6 +222,7 @@ type _executionTestOptions struct { apolloRouterCompatibilitySubrequestHTTPError bool propagateFetchReasons bool validateRequiredExternalFields bool + computeStaticCost bool } type executionTestOptions func(*_executionTestOptions) @@ -244,6 +247,12 @@ func validateRequiredExternalFields() executionTestOptions { } } +func computeStaticCost() executionTestOptions { + return func(options *_executionTestOptions) { + options.computeStaticCost = true + } +} + func TestExecutionEngine_Execute(t *testing.T) { run := func(testCase ExecutionEngineTestCase, withError bool, expectedErrorMessage string, options ...executionTestOptions) func(t *testing.T) { t.Helper() @@ -278,13 +287,16 @@ func TestExecutionEngine_Execute(t *testing.T) { } engineConf.plannerConfig.BuildFetchReasons = opts.propagateFetchReasons engineConf.plannerConfig.ValidateRequiredExternalFields = opts.validateRequiredExternalFields - engine, err := NewExecutionEngine(ctx, abstractlogger.Noop{}, engineConf, resolve.ResolverOptions{ + engineConf.plannerConfig.ComputeStaticCost = opts.computeStaticCost + engineConf.plannerConfig.StaticCostDefaultListSize = 10 + resolveOpts := resolve.ResolverOptions{ MaxConcurrency: 1024, ResolvableOptions: opts.resolvableOptions, ApolloRouterCompatibilitySubrequestHTTPError: opts.apolloRouterCompatibilitySubrequestHTTPError, PropagateFetchReasons: opts.propagateFetchReasons, ValidateRequiredExternalFields: opts.validateRequiredExternalFields, - }) + } + engine, err := NewExecutionEngine(ctx, abstractlogger.Noop{}, engineConf, resolveOpts) require.NoError(t, err) operation := testCase.operation(t) @@ -306,6 +318,15 @@ func TestExecutionEngine_Execute(t *testing.T) { return } + if withError { + require.Error(t, err) + if expectedErrorMessage != "" { + assert.Contains(t, err.Error(), expectedErrorMessage) + } + } else { + require.NoError(t, err) + } + if testCase.expectedJSONResponse != "" { assert.JSONEq(t, testCase.expectedJSONResponse, actualResponse) } @@ -314,14 +335,11 @@ func TestExecutionEngine_Execute(t *testing.T) { assert.Equal(t, testCase.expectedResponse, actualResponse) } - if withError { - require.Error(t, err) - if expectedErrorMessage != "" { - assert.Contains(t, err.Error(), expectedErrorMessage) - } - } else { - require.NoError(t, err) + if testCase.expectedStaticCost != 0 { + gotCost := operation.StaticCost() + require.Equal(t, testCase.expectedStaticCost, gotCost) } + } } @@ -872,7 +890,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, }, }, ChildNodes: []plan.TypeField{ @@ -931,7 +949,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, // Only for this field propagate the fetch reasons, // even if a user has asked for the interface in the query. FetchReasonFields: []string{"name"}, @@ -991,7 +1009,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, FetchReasonFields: []string{"name"}, // implementing is marked }, }, @@ -1062,7 +1080,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, }, }, ChildNodes: []plan.TypeField{ @@ -1135,7 +1153,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, FetchReasonFields: []string{"name"}, // implementing is marked }, }, @@ -4635,16 +4653,49 @@ func TestExecutionEngine_Execute(t *testing.T) { } ` - makeDataSource := func(t *testing.T, expectFetchReasons bool) []plan.DataSource { + type makeDataSourceOpts struct { + expectFetchReasons bool + includeCostConfig bool + } + + makeDataSource := func(t *testing.T, opts makeDataSourceOpts) []plan.DataSource { var expectedBody1 string var expectedBody2 string - if !expectFetchReasons { + if !opts.expectFetchReasons { expectedBody1 = `{"query":"{accounts {__typename ... on User {some {__typename id}} ... on Admin {some {__typename id}}}}"}` } else { expectedBody1 = `{"query":"{accounts {__typename ... on User {some {__typename id}} ... on Admin {some {__typename id}}}}","extensions":{"fetch_reasons":[{"typename":"Admin","field":"some","by_user":true},{"typename":"User","field":"id","by_subgraphs":["id-2"],"by_user":true,"is_key":true}]}}` } expectedBody2 = `{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {__typename title}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"3"}]}}` + // Cost config for DS1 (first subgraph): accounts service + var ds1CostConfig *plan.DataSourceCostConfig + if opts.includeCostConfig { + ds1CostConfig = &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "accounts"}: {HasWeight: true, Weight: 5}, + {TypeName: "User", FieldName: "some"}: {HasWeight: true, Weight: 2}, + {TypeName: "Admin", FieldName: "some"}: {HasWeight: true, Weight: 3}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "accounts"}: {AssumedSize: 3}, + }, + } + } + + // Cost config for DS2 (second subgraph): extends User/Admin with title + var ds2CostConfig *plan.DataSourceCostConfig + if opts.includeCostConfig { + ds2CostConfig = &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "User", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "User", FieldName: "title"}: {HasWeight: true, Weight: 4}, + {TypeName: "Admin", FieldName: "adminName"}: {HasWeight: true, Weight: 3}, + {TypeName: "Admin", FieldName: "title"}: {HasWeight: true, Weight: 5}, + }, + } + } + return []plan.DataSource{ mustGraphqlDataSourceConfiguration(t, "id-1", @@ -4680,6 +4731,7 @@ func TestExecutionEngine_Execute(t *testing.T) { FieldNames: []string{"id", "title", "some"}, }, }, + CostConfig: ds1CostConfig, FederationMetaData: plan.FederationMetaData{ Keys: plan.FederationFieldConfigurations{ { @@ -4730,6 +4782,7 @@ func TestExecutionEngine_Execute(t *testing.T) { FieldNames: []string{"id", "adminName", "title"}, }, }, + CostConfig: ds2CostConfig, FederationMetaData: plan.FederationMetaData{ Keys: plan.FederationFieldConfigurations{ { @@ -4772,24 +4825,24 @@ func TestExecutionEngine_Execute(t *testing.T) { return graphql.Request{ OperationName: "Accounts", Query: ` - query Accounts { - accounts { - ... on User { - some { - title + query Accounts { + accounts { + ... on User { + some { + title + } } - } - ... on Admin { - some { - __typename - id + ... on Admin { + some { + __typename + id + } } } - } - }`, + }`, } }, - dataSources: makeDataSource(t, false), + dataSources: makeDataSource(t, makeDataSourceOpts{expectFetchReasons: false}), expectedResponse: `{"data":{"accounts":[{"some":{"title":"User1"}},{"some":{"__typename":"User","id":"2"}},{"some":{"title":"User3"}}]}}`, })) @@ -4805,28 +4858,79 @@ func TestExecutionEngine_Execute(t *testing.T) { return graphql.Request{ OperationName: "Accounts", Query: ` - query Accounts { - accounts { - ... on User { - some { - title - } - } - ... on Admin { - some { - __typename - id + query Accounts { + accounts { + ... on User { + some { + title + } + } + ... on Admin { + some { + __typename + id + } + } } - } - } - }`, + }`, } }, - dataSources: makeDataSource(t, true), + dataSources: makeDataSource(t, makeDataSourceOpts{expectFetchReasons: true}), expectedResponse: `{"data":{"accounts":[{"some":{"title":"User1"}},{"some":{"__typename":"User","id":"2"}},{"some":{"title":"User3"}}]}}`, }, withFetchReasons(), )) + + t.Run("run with static cost computation", runWithoutError( + ExecutionEngineTestCase{ + schema: func(t *testing.T) *graphql.Schema { + t.Helper() + parseSchema, err := graphql.NewSchemaFromString(definition) + require.NoError(t, err) + return parseSchema + }(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + OperationName: "Accounts", + Query: ` + query Accounts { + accounts { + ... on User { + some { + title + } + } + ... on Admin { + some { + __typename + id + } + } + } + }`, + } + }, + dataSources: makeDataSource(t, makeDataSourceOpts{includeCostConfig: true}), + expectedResponse: `{"data":{"accounts":[{"some":{"title":"User1"}},{"some":{"__typename":"User","id":"2"}},{"some":{"title":"User3"}}]}}`, + // Cost breakdown with federation: + // Query.accounts: fieldCost=5, multiplier=3 (listSize) + // accounts returns interface [Node!]! with implementing types [User, Admin] + // + // Children (per interface member type): + // User.some: User: fieldCost=3 (DS1:2 + DS2:1 summed) + // User.title: 4 (DS2, resolved via _entities federation) + // cost = 3 + 4 = 7 + // + // Admin.some: User: fieldCost=3 (DS1 only) + // cost = 3 + // + // Children total = 7 + 3 = 10 + // (is it possible to improve accuracy here by using the largest fragment instead of the sum?) + // Total = (5 + 10) * 3 = 45 + expectedStaticCost: 45, + }, + computeStaticCost(), + )) }) t.Run("validation of optional @requires dependencies", func(t *testing.T) { @@ -5527,6 +5631,1187 @@ func TestExecutionEngine_Execute(t *testing.T) { }, withFetchReasons(), validateRequiredExternalFields())) }) }) + + t.Run("static cost computation", func(t *testing.T) { + t.Run("common on star wars scheme", func(t *testing.T) { + rootNodes := []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"hero", "droid"}}, + {TypeName: "Human", FieldNames: []string{"name", "height", "friends"}}, + {TypeName: "Droid", FieldNames: []string{"name", "primaryFunction", "friends"}}, + } + childNodes := []plan.TypeField{ + {TypeName: "Character", FieldNames: []string{"name", "friends"}}, + } + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", + }, + SchemaConfiguration: mustSchemaConfig( + t, + nil, + string(graphql.StarwarsSchema(t).RawSchema()), + ), + }) + + t.Run("droid with weighted plain fields", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + droid(id: "R2D2") { + name + primaryFunction + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + }}, + customConfig, + ), + }, + fields: []plan.FieldConfiguration{ + { + TypeName: "Query", FieldName: "droid", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "id", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, + }, + }, + }, + }, + expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + expectedStaticCost: 18, // Query.droid (1) + droid.name (17) + }, + computeStaticCost(), + )) + + t.Run("droid with weighted plain fields and an argument", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + droid(id: "R2D2") { + name + primaryFunction + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "droid"}: { + ArgumentWeights: map[string]int{"id": 3}, + HasWeight: false, + }, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + }}, + customConfig, + ), + }, + fields: []plan.FieldConfiguration{ + { + TypeName: "Query", FieldName: "droid", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "id", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, + }, + }, + }, + }, + expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + expectedStaticCost: 21, // Query.droid (1) + Query.droid.id (3) + droid.name (17) + }, + computeStaticCost(), + )) + + t.Run("hero field has weight (returns interface) and with concrete fragment", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + name + ... on Human { height } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker","height":"12"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 3}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + Types: map[string]int{ + "Human": 13, + }, + }}, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker","height":"12"}}}`, + expectedStaticCost: 22, // Query.hero (2) + Human.height (3) + Droid.name (17=max(7, 17)) + }, + computeStaticCost(), + )) + + t.Run("hero field has no weight (returns interface) and with concrete fragment", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { name } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + Types: map[string]int{ + "Human": 13, + "Droid": 11, + }, + }}, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker"}}}`, + expectedStaticCost: 30, // Query.Human (13) + Droid.name (17=max(7, 17)) + }, + computeStaticCost(), + )) + + t.Run("query hero without assumedSize on friends", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + {"__typename":"Human","name":"Luke Skywalker","height":"12"}, + {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} + ]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 127, // Query.hero(max(7,5))+10*(Human(max(7,5))+Human.name(2)+Human.height(1)+Droid.name(2)) + }, + computeStaticCost(), + )) + + t.Run("query hero with assumedSize on friends", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + {"__typename":"Human","name":"Luke Skywalker","height":"12"}, + {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} + ]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 5}, + {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 20}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 247, // Query.hero(max(7,5))+ 20 * (7+2+2+1) + // We pick maximum on every path independently. This is to reveal the upper boundary. + // Query.hero: picked maximum weight (Human=7) out of two types (Human, Droid) + // Query.hero.friends: the max possible weight (7) is for implementing class Human + // of the returned type of Character; the multiplier picked for the Droid since + // it is the maximum possible value - we considered the enclosing type that contains it. + }, + computeStaticCost(), + )) + + t.Run("query hero with assumedSize on friends and weight defined", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + {"__typename":"Human","name":"Luke Skywalker","height":"12"}, + {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} + ]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "friends"}: {HasWeight: true, Weight: 3}, + {TypeName: "Droid", FieldName: "friends"}: {HasWeight: true, Weight: 4}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 5}, + {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 20}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 187, // Query.hero(max(7,5))+ 20 * (4+2+2+1) + }, + computeStaticCost(), + )) + + t.Run("query hero with empty cost structures", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + {"__typename":"Human","name":"Luke Skywalker","height":"12"}, + {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} + ]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{}, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 11, // Query.hero(max(1,1))+ 10 * 1 + }, + computeStaticCost(), + )) + + t.Run("named fragment on interface", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: ` + fragment CharacterFields on Character { + name + friends { name } + } + { hero { ...CharacterFields } } + `, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke","friends":[{"name":"Leia"}]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 3}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 5}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 4}, + {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 6}, + }, + Types: map[string]int{ + "Human": 2, + "Droid": 3, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke","friends":[{"name":"Leia"}]}}}`, + // Cost calculation: + // Query.hero: 2 + // Character.name: max(Human.name=3, Droid.name=5) = 5 + // friends listSize: max(4, 6) = 6 + // Character type: max(Human=2, Droid=3) = 3 + // name: max(Human.name=3, Droid.name=5) = 5 + // Total: 2 + 5 + 6 * (3 + 5) + expectedStaticCost: 55, + }, + computeStaticCost(), + )) + + t.Run("named fragment with concrete type", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: ` + fragment HumanFields on Human { + name + height + } + { hero { ...HumanFields } } + `, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke","height":"1.72"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 3}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 5}, + }, + Types: map[string]int{ + "Human": 1, + "Droid": 1, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke","height":"1.72"}}}`, + // Total: 2 + 3 + 7 + expectedStaticCost: 12, + }, + computeStaticCost(), + )) + + }) + + t.Run("union types", func(t *testing.T) { + unionSchema := ` + type Query { + search(term: String!): [SearchResult!] + } + union SearchResult = User | Post | Comment + type User @key(fields: "id") { + id: ID! + name: String! + email: String! + } + type Post @key(fields: "id") { + id: ID! + title: String! + body: String! + } + type Comment @key(fields: "id") { + id: ID! + text: String! + } + ` + schema, err := graphql.NewSchemaFromString(unionSchema) + require.NoError(t, err) + + rootNodes := []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"search"}}, + {TypeName: "User", FieldNames: []string{"id", "name", "email"}}, + {TypeName: "Post", FieldNames: []string{"id", "title", "body"}}, + {TypeName: "Comment", FieldNames: []string{"id", "text"}}, + } + childNodes := []plan.TypeField{} + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", + }, + SchemaConfiguration: mustSchemaConfig(t, nil, unionSchema), + }) + fieldConfig := []plan.FieldConfiguration{ + { + TypeName: "Query", + FieldName: "search", + Path: []string{"search"}, + Arguments: []plan.ArgumentConfiguration{ + {Name: "term", SourceType: plan.FieldArgumentSource, RenderConfig: plan.RenderArgumentAsGraphQLValue}, + }, + }, + } + + t.Run("union with all member types", runWithoutError( + ExecutionEngineTestCase{ + schema: schema, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + search(term: "test") { + ... on User { name email } + ... on Post { title body } + ... on Comment { text } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"search":[{"__typename":"User","name":"John","email":"john@test.com"}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "User", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "User", FieldName: "email"}: {HasWeight: true, Weight: 3}, + {TypeName: "Post", FieldName: "title"}: {HasWeight: true, Weight: 4}, + {TypeName: "Post", FieldName: "body"}: {HasWeight: true, Weight: 5}, + {TypeName: "Comment", FieldName: "text"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "search"}: {AssumedSize: 5}, + }, + Types: map[string]int{ + "User": 2, + "Post": 3, + "Comment": 1, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"search":[{"name":"John","email":"john@test.com"}]}}`, + // search listSize: 10 + // For each SearchResult, use max across all union members: + // Type weight: max(User=2, Post=3, Comment=1) = 3 + // Fields: all fields from all fragments are counted + // (2 + 3) + (4 + 5) + (1) = 15 + // TODO: this is not correct, we should pick a maximum sum among types implementing union. + // 9 should be used instead of 15 + // Total: 5 * (3 + 15) + expectedStaticCost: 90, + }, + computeStaticCost(), + )) + + t.Run("union with weighted search field", runWithoutError( + ExecutionEngineTestCase{ + schema: schema, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + search(term: "test") { + ... on User { name } + ... on Post { title } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"search":[{"__typename":"User","name":"John"}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "User", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Post", FieldName: "title"}: {HasWeight: true, Weight: 5}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "search"}: {AssumedSize: 3}, + }, + Types: map[string]int{ + "User": 6, + "Post": 10, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"search":[{"name":"John"}]}}`, + // Query.search: max(User=10, Post=6) + // search listSize: 3 + // Union members: + // All fields from all fragments: User.name(2) + Post.title(5) + // Total: 3 * (10+2+5) + // TODO: we might correct this by counting only members of one implementing types + // of a union when fragments are used. + expectedStaticCost: 51, + }, + computeStaticCost(), + )) + }) + + t.Run("listSize", func(t *testing.T) { + listSchema := ` + type Query { + items(first: Int, last: Int): [Item!] + } + type Item @key(fields: "id") { + id: ID + } + ` + schemaSlicing, err := graphql.NewSchemaFromString(listSchema) + require.NoError(t, err) + rootNodes := []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"items"}}, + {TypeName: "Item", FieldNames: []string{"id"}}, + } + childNodes := []plan.TypeField{} + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", + }, + SchemaConfiguration: mustSchemaConfig(t, nil, listSchema), + }) + fieldConfig := []plan.FieldConfiguration{ + { + TypeName: "Query", + FieldName: "items", + Path: []string{"items"}, + Arguments: []plan.ArgumentConfiguration{ + { + Name: "first", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, + }, + { + Name: "last", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, + }, + }, + }, + } + t.Run("multiple slicing arguments as literals", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query MultipleSlicingArguments { + items(first: 5, last: 12) { id } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[ {"id":"2"}, {"id":"3"} ]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 8, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 3, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[{"id":"2"},{"id":"3"}]}}`, + expectedStaticCost: 48, // slicingArgument(12) * (Item(3)+Item.id(1)) + }, + computeStaticCost(), + )) + t.Run("slicing argument as a variable", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query SlicingWithVariable($limit: Int!) { + items(first: $limit) { id } + }`, + Variables: []byte(`{"limit": 25}`), + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[ {"id":"2"}, {"id":"3"} ]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 8, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 3, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[{"id":"2"},{"id":"3"}]}}`, + expectedStaticCost: 100, // slicingArgument($limit=25) * (Item(3)+Item.id(1)) + }, + computeStaticCost(), + )) + t.Run("slicing argument not provided falls back to assumedSize", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query NoSlicingArg { + items { id } + }`, + // No slicing arguments provided - should fall back to assumedSize + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[{"id":"1"},{"id":"2"}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 15, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[{"id":"1"},{"id":"2"}]}}`, + expectedStaticCost: 45, // Total: 15 * (2 + 1) + }, + computeStaticCost(), + )) + t.Run("zero slicing argument falls back to assumedSize", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query ZeroSlicing { + items(first: 0) { id } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 20, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[]}}`, + expectedStaticCost: 60, // 20 * (2 + 1) + }, + computeStaticCost(), + )) + t.Run("negative slicing argument falls back to assumedSize", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query NegativeSlicing { + items(first: -5) { id } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 25, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[]}}`, + expectedStaticCost: 75, // 25 * (2 + 1) + }, + computeStaticCost(), + )) + + }) + + t.Run("nested lists with compounding multipliers", func(t *testing.T) { + nestedSchema := ` + type Query { + users(first: Int): [User!] + } + type User @key(fields: "id") { + id: ID! + posts(first: Int): [Post!] + } + type Post @key(fields: "id") { + id: ID! + comments(first: Int): [Comment!] + } + type Comment @key(fields: "id") { + id: ID! + text: String! + } + ` + schemaNested, err := graphql.NewSchemaFromString(nestedSchema) + require.NoError(t, err) + + rootNodes := []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"users"}}, + {TypeName: "User", FieldNames: []string{"id", "posts"}}, + {TypeName: "Post", FieldNames: []string{"id", "comments"}}, + {TypeName: "Comment", FieldNames: []string{"id", "text"}}, + } + childNodes := []plan.TypeField{} + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", + }, + SchemaConfiguration: mustSchemaConfig(t, nil, nestedSchema), + }) + fieldConfig := []plan.FieldConfiguration{ + { + TypeName: "Query", FieldName: "users", Path: []string{"users"}, + Arguments: []plan.ArgumentConfiguration{ + {Name: "first", SourceType: plan.FieldArgumentSource, RenderConfig: plan.RenderArgumentAsGraphQLValue}, + }, + }, + { + TypeName: "User", FieldName: "posts", Path: []string{"posts"}, + Arguments: []plan.ArgumentConfiguration{ + {Name: "first", SourceType: plan.FieldArgumentSource, RenderConfig: plan.RenderArgumentAsGraphQLValue}, + }, + }, + { + TypeName: "Post", FieldName: "comments", Path: []string{"comments"}, + Arguments: []plan.ArgumentConfiguration{ + {Name: "first", SourceType: plan.FieldArgumentSource, RenderConfig: plan.RenderArgumentAsGraphQLValue}, + }, + }, + } + + t.Run("nested lists with slicing arguments", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaNested, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + users(first: 10) { + posts(first: 5) { + comments(first: 3) { text } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"users":[{"posts":[{"comments":[{"text":"hello"}]}]}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Comment", FieldName: "text"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "users"}: { + AssumedSize: 100, + SlicingArguments: []string{"first"}, + }, + {TypeName: "User", FieldName: "posts"}: { + AssumedSize: 50, + SlicingArguments: []string{"first"}, + }, + {TypeName: "Post", FieldName: "comments"}: { + AssumedSize: 20, + SlicingArguments: []string{"first"}, + }, + }, + Types: map[string]int{ + "User": 4, + "Post": 3, + "Comment": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"users":[{"posts":[{"comments":[{"text":"hello"}]}]}]}}`, + // Cost calculation: + // users(first:10): multiplier 10 + // User type weight: 4 + // posts(first:5): multiplier 5 + // Post type weight: 3 + // comments(first:3): multiplier 3 + // Comment type weight: 2 + // text weight: 1 + // Total: 10 * (4 + 5 * (3 + 3 * (2 + 1))) + expectedStaticCost: 640, + }, + computeStaticCost(), + )) + + t.Run("nested lists fallback to assumedSize when slicing arg not provided", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaNested, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + users(first: 2) { + posts { + comments(first: 4) { text } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"users":[{"posts":[{"comments":[{"text":"hi"}]}]}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Comment", FieldName: "text"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "users"}: { + AssumedSize: 100, + SlicingArguments: []string{"first"}, + }, + {TypeName: "User", FieldName: "posts"}: { + AssumedSize: 50, // no slicing arg, should use this + }, + {TypeName: "Post", FieldName: "comments"}: { + AssumedSize: 20, + SlicingArguments: []string{"first"}, + }, + }, + Types: map[string]int{ + "User": 4, + "Post": 3, + "Comment": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"users":[{"posts":[{"comments":[{"text":"hi"}]}]}]}}`, + // Cost calculation: + // users(first:2): multiplier 2 + // User type weight: 4 + // posts (no arg): assumedSize 50 + // Post type weight: 3 + // comments(first:4): multiplier 4 + // Comment type weight: 2 + // text weight: 1 + // Total: 2 * (4 + 50 * (3 + 4 * (2 + 1))) + expectedStaticCost: 1508, + }, + computeStaticCost(), + )) + }) + + }) } func testNetHttpClient(t *testing.T, testCase roundTripperTestCase) *http.Client { @@ -5632,7 +6917,7 @@ func TestExecutionEngine_GetCachedPlan(t *testing.T) { } report := operationreport.Report{} - cachedPlan := engine.getCachedPlan(firstInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) + cachedPlan, _ := engine.getCachedPlan(firstInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) _, oldestCachedPlan, _ := engine.executionPlanCache.GetOldest() assert.False(t, report.HasErrors()) assert.Equal(t, 1, engine.executionPlanCache.Len()) @@ -5643,7 +6928,7 @@ func TestExecutionEngine_GetCachedPlan(t *testing.T) { http.CanonicalHeaderKey("Authorization"): []string{"123abc"}, } - cachedPlan = engine.getCachedPlan(secondInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) + cachedPlan, _ = engine.getCachedPlan(secondInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) _, oldestCachedPlan, _ = engine.executionPlanCache.GetOldest() assert.False(t, report.HasErrors()) assert.Equal(t, 1, engine.executionPlanCache.Len()) @@ -5660,7 +6945,7 @@ func TestExecutionEngine_GetCachedPlan(t *testing.T) { } report := operationreport.Report{} - cachedPlan := engine.getCachedPlan(firstInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) + cachedPlan, _ := engine.getCachedPlan(firstInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) _, oldestCachedPlan, _ := engine.executionPlanCache.GetOldest() assert.False(t, report.HasErrors()) assert.Equal(t, 1, engine.executionPlanCache.Len()) @@ -5671,7 +6956,7 @@ func TestExecutionEngine_GetCachedPlan(t *testing.T) { http.CanonicalHeaderKey("Authorization"): []string{"xyz098"}, } - cachedPlan = engine.getCachedPlan(secondInternalExecCtx, differentGqlRequest.Document(), schema.Document(), differentGqlRequest.OperationName, &report) + cachedPlan, _ = engine.getCachedPlan(secondInternalExecCtx, differentGqlRequest.Document(), schema.Document(), differentGqlRequest.OperationName, &report) _, oldestCachedPlan, _ = engine.executionPlanCache.GetOldest() assert.False(t, report.HasErrors()) assert.Equal(t, 2, engine.executionPlanCache.Len()) diff --git a/execution/graphql/request.go b/execution/graphql/request.go index a3ab0888d0..85a7051d80 100644 --- a/execution/graphql/request.go +++ b/execution/graphql/request.go @@ -6,8 +6,11 @@ import ( "io" "net/http" + "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) @@ -42,6 +45,8 @@ type Request struct { request resolve.Request validForSchema map[uint64]ValidationResult + + staticCost int } func UnmarshalRequest(reader io.Reader, request *Request) error { @@ -189,3 +194,15 @@ func (r *Request) OperationType() (OperationType, error) { return OperationTypeUnknown, nil } + +func (r *Request) ComputeStaticCost(calc *plan.CostCalculator, config plan.Configuration, variables *astjson.Value) { + if calc != nil { + r.staticCost = calc.GetStaticCost(config, variables) + } else { + r.staticCost = 0 + } +} + +func (r *Request) StaticCost() int { + return r.staticCost +} diff --git a/v2/pkg/ast/ast_type.go b/v2/pkg/ast/ast_type.go index 6c22f33fe3..92aff65a7b 100644 --- a/v2/pkg/ast/ast_type.go +++ b/v2/pkg/ast/ast_type.go @@ -250,6 +250,7 @@ func (d *Document) ResolveTypeNameString(ref int) string { return unsafebytes.BytesToString(d.ResolveTypeNameBytes(ref)) } +// ResolveUnderlyingType unwraps the ref type until it finds the named type. func (d *Document) ResolveUnderlyingType(ref int) (typeRef int) { typeRef = ref graphqlType := d.Types[ref] @@ -261,6 +262,7 @@ func (d *Document) ResolveUnderlyingType(ref int) (typeRef int) { return } +// ResolveListOrNameType unwraps the ref type until it finds the named or list type. func (d *Document) ResolveListOrNameType(ref int) (typeRef int) { typeRef = ref graphqlType := d.Types[ref] diff --git a/v2/pkg/engine/plan/configuration.go b/v2/pkg/engine/plan/configuration.go index dafcb021c5..cff8109235 100644 --- a/v2/pkg/engine/plan/configuration.go +++ b/v2/pkg/engine/plan/configuration.go @@ -46,6 +46,12 @@ type Configuration struct { // entity. // This option requires BuildFetchReasons set to true. ValidateRequiredExternalFields bool + + // ComputeStaticCost enables static cost computation for operations. + ComputeStaticCost bool + + // When the list size is unknown from directives, this value is used as a default for static cost. + StaticCostDefaultListSize int } type DebugConfiguration struct { diff --git a/v2/pkg/engine/plan/datasource_configuration.go b/v2/pkg/engine/plan/datasource_configuration.go index f196a00bc8..afc922be8d 100644 --- a/v2/pkg/engine/plan/datasource_configuration.go +++ b/v2/pkg/engine/plan/datasource_configuration.go @@ -51,6 +51,9 @@ type DataSourceMetadata struct { Directives *DirectiveConfigurations + // CostConfig holds IBM static cost configuration for this data source + CostConfig *DataSourceCostConfig + rootNodesIndex map[string]fieldsIndex // maps TypeName to fieldsIndex childNodesIndex map[string]fieldsIndex // maps TypeName to fieldsIndex @@ -287,6 +290,9 @@ type DataSource interface { Hash() DSHash FederationConfiguration() FederationMetaData CreatePlannerConfiguration(logger abstractlogger.Logger, fetchConfig *objectFetchConfiguration, pathConfig *plannerPathsConfiguration, configuration *Configuration) PlannerConfiguration + + // GetCostConfig returns the IBM static cost configuration for this data source + GetCostConfig() *DataSourceCostConfig } func (d *dataSourceConfiguration[T]) CustomConfiguration() T { @@ -335,6 +341,13 @@ func (d *dataSourceConfiguration[T]) Hash() DSHash { return d.hash } +func (d *dataSourceConfiguration[T]) GetCostConfig() *DataSourceCostConfig { + if d.DataSourceMetadata == nil { + return nil + } + return d.DataSourceMetadata.CostConfig +} + type DataSourcePlannerConfiguration struct { RequiredFields FederationFieldConfigurations ParentPath string diff --git a/v2/pkg/engine/plan/node_selection_visitor.go b/v2/pkg/engine/plan/node_selection_visitor.go index 8f9f80fb48..bbcffd3c43 100644 --- a/v2/pkg/engine/plan/node_selection_visitor.go +++ b/v2/pkg/engine/plan/node_selection_visitor.go @@ -34,7 +34,7 @@ type nodeSelectionVisitor struct { visitedFieldsRequiresChecks map[fieldIndexKey]struct{} // visitedFieldsRequiresChecks is a map[fieldIndexKey] of already processed fields which we check for presence of @requires directive visitedFieldsKeyChecks map[fieldIndexKey]struct{} // visitedFieldsKeyChecks is a map[fieldIndexKey] of already processed fields which we check for @key requirements - visitedFieldsAbstractChecks map[int]struct{} // visitedFieldsAbstractChecks is a map[FieldRef] of already processed fields which we check for abstract type, e.g. union or interface + visitedFieldsAbstractChecks map[int]struct{} // visitedFieldsAbstractChecks is a map[fieldRef] of already processed fields which we check for abstract type, e.g. union or interface fieldDependsOn map[fieldIndexKey][]int // fieldDependsOn is a map[fieldIndexKey][]fieldRef - holds list of field refs which are required by a field ref, e.g. field should be planned only after required fields were planned fieldRefDependsOn map[int][]int // fieldRefDependsOn is a map[fieldRef][]fieldRef - holds list of field refs which are required by a field ref, it is a second index without datasource hash fieldRequirementsConfigs map[fieldIndexKey][]FederationFieldConfiguration // fieldRequirementsConfigs is a map[fieldIndexKey]FederationFieldConfiguration - holds a list of required configuratuibs for a field ref to later built representation variables diff --git a/v2/pkg/engine/plan/path_builder_visitor.go b/v2/pkg/engine/plan/path_builder_visitor.go index 8348f6d5e6..7ccb90bb60 100644 --- a/v2/pkg/engine/plan/path_builder_visitor.go +++ b/v2/pkg/engine/plan/path_builder_visitor.go @@ -43,7 +43,7 @@ type pathBuilderVisitor struct { addedPathTracker []pathConfiguration // addedPathTracker is a list of paths which were added addedPathTrackerIndex map[string][]int // addedPathTrackerIndex is a map of path to index in addedPathTracker - fieldDependenciesForPlanners map[int][]int // fieldDependenciesForPlanners is a map[FieldRef][]plannerIdx holds list of planner ids which depends on a field ref. Used for @key dependencies + fieldDependenciesForPlanners map[int][]int // fieldDependenciesForPlanners is a map[fieldRef][]plannerIdx holds list of planner ids which depends on a field ref. Used for @key dependencies fieldsPlannedOn map[int][]int // fieldsPlannedOn is a map[fieldRef][]plannerIdx holds list of planner ids which planned a field ref secondaryRun bool // secondaryRun is a flag to indicate that we're running the pathBuilderVisitor not the first time diff --git a/v2/pkg/engine/plan/plan.go b/v2/pkg/engine/plan/plan.go index 15f97769f0..1cca76d896 100644 --- a/v2/pkg/engine/plan/plan.go +++ b/v2/pkg/engine/plan/plan.go @@ -14,11 +14,14 @@ const ( type Plan interface { PlanKind() Kind SetFlushInterval(interval int64) + GetStaticCostCalculator() *CostCalculator + SetStaticCostCalculator(calc *CostCalculator) } type SynchronousResponsePlan struct { - Response *resolve.GraphQLResponse - FlushInterval int64 + Response *resolve.GraphQLResponse + FlushInterval int64 + StaticCostCalculator *CostCalculator } func (s *SynchronousResponsePlan) SetFlushInterval(interval int64) { @@ -29,9 +32,18 @@ func (*SynchronousResponsePlan) PlanKind() Kind { return SynchronousResponseKind } +func (s *SynchronousResponsePlan) GetStaticCostCalculator() *CostCalculator { + return s.StaticCostCalculator +} + +func (s *SynchronousResponsePlan) SetStaticCostCalculator(c *CostCalculator) { + s.StaticCostCalculator = c +} + type SubscriptionResponsePlan struct { - Response *resolve.GraphQLSubscription - FlushInterval int64 + Response *resolve.GraphQLSubscription + FlushInterval int64 + StaticCostCalculator *CostCalculator } func (s *SubscriptionResponsePlan) SetFlushInterval(interval int64) { @@ -41,3 +53,11 @@ func (s *SubscriptionResponsePlan) SetFlushInterval(interval int64) { func (*SubscriptionResponsePlan) PlanKind() Kind { return SubscriptionResponseKind } + +func (s *SubscriptionResponsePlan) GetStaticCostCalculator() *CostCalculator { + return s.StaticCostCalculator +} + +func (s *SubscriptionResponsePlan) SetStaticCostCalculator(c *CostCalculator) { + s.StaticCostCalculator = c +} diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index 7b948ab6bd..73df967228 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -18,6 +18,7 @@ type Planner struct { planningWalker *astvisitor.Walker planningVisitor *Visitor + costVisitor *StaticCostVisitor nodeSelectionBuilder *NodeSelectionBuilder planningPathBuilder *PathBuilder @@ -59,6 +60,7 @@ func NewPlanner(config Configuration) (*Planner, error) { // planning planningWalker := astvisitor.NewWalkerWithID(48, "PlanningWalker") + planningVisitor := &Visitor{ Walker: &planningWalker, fieldConfigs: map[int]*FieldConfiguration{}, @@ -75,14 +77,6 @@ func NewPlanner(config Configuration) (*Planner, error) { return p, nil } -func (p *Planner) SetConfig(config Configuration) { - p.config = config -} - -func (p *Planner) SetDebugConfig(config DebugConfiguration) { - p.config.Debug = config -} - type _opts struct { includeQueryPlanInResponse bool } @@ -165,6 +159,20 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string p.planningWalker.RegisterEnterDirectiveVisitor(p.planningVisitor) p.planningWalker.RegisterInlineFragmentVisitor(p.planningVisitor) + // Register cost visitor on the same walker (will be invoked after planningVisitor hooks). + // We have to register it last in the walker, as it depends on the fieldPlanners field of the + // visitor. That field is populated in the AllowVisitor callback. Walker calls Enter* callbacks + // in the order they were registered, and Leave* callbacks in the reverse order. + if p.config.ComputeStaticCost { + p.costVisitor = NewStaticCostVisitor(p.planningWalker, operation, definition) + p.costVisitor.planners = plannersConfigurations + p.costVisitor.fieldPlanners = &p.planningVisitor.fieldPlanners + p.costVisitor.operationDefinition = &p.planningVisitor.operationDefinitionRef + + p.planningWalker.RegisterEnterFieldVisitor(p.costVisitor) + p.planningWalker.RegisterLeaveFieldVisitor(p.costVisitor) + } + for key := range p.planningVisitor.planners { if p.config.MinifySubgraphOperations { if dataSourceWithMinify, ok := p.planningVisitor.planners[key].Planner().(SubgraphRequestMinifier); ok { @@ -195,12 +203,18 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string } } - // create raw execution plan + // create a raw execution plan p.planningWalker.Walk(operation, definition, report) if report.HasErrors() { return } + if p.config.ComputeStaticCost { + costCalc := NewCostCalculator() + costCalc.tree = p.costVisitor.finalCostTree() + p.planningVisitor.plan.SetStaticCostCalculator(costCalc) + } + return p.planningVisitor.plan } diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go new file mode 100644 index 0000000000..805deaa8f7 --- /dev/null +++ b/v2/pkg/engine/plan/static_cost.go @@ -0,0 +1,562 @@ +package plan + +/* + +Static Cost Analysis. + +Planning visitor collects information for the costCalculator via EnterField and LeaveField hooks. +Calculator builds a tree of nodes, each node corresponding to the requested field. +After the planning is done, a callee could get a ref to the calculator and request cost calculation. +Cost calculation walks the previously built tree and using variables provided with operation, +estimates the static cost. + +https://ibm.github.io/graphql-specs/cost-spec.html + +It builds on top of IBM spec for @cost and @listSize directive with a few changes. + +* We use Int! for weights instead of floats packed in String!. +* When weight is specified for the type and a field returns the list of that type, +this weight (along with children's costs) is multiplied too. + +A few things on the TBD list: + +* Support of SizedFields of @listSize +* Weights on fields of InputObjects with recursion +* Weights on arguments of directives + +*/ + +import ( + "fmt" + "strings" + + "github.com/wundergraph/astjson" +) + +// We don't allow configuring default weights for enums, scalars and objects. +// But they could be in the future. + +const DefaultEnumScalarWeight = 0 +const DefaultObjectWeight = 1 + +// FieldWeight defines cost configuration for a specific field of an object or input object. +type FieldWeight struct { + + // Weight is the cost of this field definition. It could be negative or zero. + // Should be used only if HasWeight is true. + Weight int + + // Means that there was weight attached to the field definition. + HasWeight bool + + // ArgumentWeights maps an argument name to its weight. + // Location: ARGUMENT_DEFINITION + ArgumentWeights map[string]int +} + +// FieldListSize contains parsed data from the @listSize directive for an object field. +type FieldListSize struct { + // AssumedSize is the default assumed size when no slicing argument is provided. + // If 0, the global default list cost is used. + AssumedSize int + + // SlicingArguments are argument names that control list size (e.g., "first", "last", "limit") + // The value of these arguments will be used as the multiplier. + SlicingArguments []string + + // SizedFields are contains field names in the returned object that returns lists. + // For these lists we estimate the size based on the value of the slicing arguments or AssumedSize. + SizedFields []string + + // RequireOneSlicingArgument if true, at least one slicing argument must be provided. + // If false and no slicing argument is provided, AssumedSize is used. + // It is not used right now since it is required only for validation. + RequireOneSlicingArgument bool +} + +// multiplier returns the multiplier based on arguments and variables. +// It picks the maximum value among slicing arguments, otherwise it tries to use AssumedSize. +// If neither is available, it falls back to defaultListSize. +// +// Does not take into account the SizedFields; TBD later. +func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo, vars *astjson.Value, defaultListSize int) int { + multiplier := -1 + for _, slicingArg := range ls.SlicingArguments { + arg, ok := arguments[slicingArg] + if !ok || !arg.isSimple { + continue + } + + var value int + // Argument could be a variable or literal value. + if arg.hasVariable { + if vars == nil { + continue + } + if v := vars.Get(arg.varName); v == nil || v.Type() != astjson.TypeNumber { + continue + } + value = vars.GetInt(arg.varName) + } else if arg.intValue > 0 { + value = arg.intValue + } + + if value > 0 && value > multiplier { + multiplier = value + } + } + + if multiplier == -1 && ls.AssumedSize > 0 { + multiplier = ls.AssumedSize + } + if multiplier == -1 { + multiplier = defaultListSize + } + return multiplier +} + +// DataSourceCostConfig holds all cost configurations for a data source. +// This data is passed from the composition. +type DataSourceCostConfig struct { + // Weights maps field coordinate to its weights. Cannot be on fields of interfaces. + // Location: FIELD_DEFINITION, INPUT_FIELD_DEFINITION + Weights map[FieldCoordinate]*FieldWeight + + // ListSizes maps field coordinates to their respective list size configurations. + // Location: FIELD_DEFINITION + ListSizes map[FieldCoordinate]*FieldListSize + + // Types maps TypeName to the weight of the object, scalar or enum definition. + // If TypeName is not present, the default value for Enums and Scalars is 0, otherwise 1. + // Weight assigned to the field or argument definitions overrides the weight of type definition. + // Location: ENUM, OBJECT, SCALAR + Types map[string]int + + // Arguments on directives is a special case. They use a special kind of coordinate: + // directive name + argument name. That should be the key mapped to the weight. + // + // Directives can be used on [input] object fields and arguments of fields. This creates + // mutual recursion between them; it complicates cost calculation. + // We avoid them intentionally in the first iteration. +} + +// NewDataSourceCostConfig creates a new cost config with defaults +func NewDataSourceCostConfig() *DataSourceCostConfig { + return &DataSourceCostConfig{ + Weights: make(map[FieldCoordinate]*FieldWeight), + ListSizes: make(map[FieldCoordinate]*FieldListSize), + Types: make(map[string]int), + } +} + +// EnumScalarTypeWeight returns the cost for an enum or scalar types +func (c *DataSourceCostConfig) EnumScalarTypeWeight(enumName string) int { + if c == nil { + return 0 + } + if cost, ok := c.Types[enumName]; ok { + return cost + } + return DefaultEnumScalarWeight +} + +// ObjectTypeWeight returns the default object cost +func (c *DataSourceCostConfig) ObjectTypeWeight(name string) int { + if c == nil { + return DefaultObjectWeight + } + if cost, ok := c.Types[name]; ok { + return cost + } + return DefaultObjectWeight +} + +// CostTreeNode represents a node in the cost calculation tree +// Based on IBM GraphQL Cost Specification: https://ibm.github.io/graphql-specs/cost-spec.html +type CostTreeNode struct { + parent *CostTreeNode + + // dataSourceHashes identifies which data sources resolve this field. + dataSourceHashes []DSHash + + // children contain child field costs + children []*CostTreeNode + + // The data below is stored for deferred cost calculation. + // We populate these fields in EnterField and use them as a source of truth in LeaveField. + + // fieldRef is the AST field reference. Used by the visitor to build the tree. + fieldRef int + + // Enclosing type name and field name + fieldCoord FieldCoordinate + + // fieldTypeName contains the name of an unwrapped (named) type that is returned by this field. + fieldTypeName string + + // implementTypeNames contains the names of all types that implement this interface/union field. + implementingTypeNames []string + + // arguments contain the values of arguments passed to the field. + arguments map[string]ArgumentInfo + + returnsListType bool + returnsSimpleType bool + returnsAbstractType bool + isEnclosingTypeAbstract bool +} + +func (node *CostTreeNode) maxWeightImplementingField(config *DataSourceCostConfig, fieldName string) *FieldWeight { + var maxWeight *FieldWeight + for _, implTypeName := range node.implementingTypeNames { + // Get the cost config for the field of an implementing type. + coord := FieldCoordinate{implTypeName, fieldName} + fieldWeight := config.Weights[coord] + + if fieldWeight != nil { + if fieldWeight.HasWeight && (maxWeight == nil || fieldWeight.Weight > maxWeight.Weight) { + maxWeight = fieldWeight + } + } + } + return maxWeight +} + +func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostConfig, fieldName string, arguments map[string]ArgumentInfo, vars *astjson.Value, defaultListSize int) *FieldListSize { + var maxMultiplier int + var maxListSize *FieldListSize + for _, implTypeName := range node.implementingTypeNames { + coord := FieldCoordinate{implTypeName, fieldName} + listSize := config.ListSizes[coord] + + if listSize != nil { + multiplier := listSize.multiplier(arguments, vars, defaultListSize) + if maxListSize == nil || multiplier > maxMultiplier { + maxMultiplier = multiplier + maxListSize = listSize + } + } + } + return maxListSize +} + +// staticCost calculates the static cost of this node and all descendants +func (node *CostTreeNode) staticCost(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, defaultListSize int) int { + if node == nil { + return 0 + } + + fieldCost, argsCost, directivesCost, multiplier := node.costsAndMultiplier(configs, variables, defaultListSize) + + // Sum children (fields) costs + var childrenCost int + for _, child := range node.children { + childrenCost += child.staticCost(configs, variables, defaultListSize) + } + + // Apply multiplier to children cost (for list fields) + if multiplier == 0 { + multiplier = 1 + } + cost := argsCost + directivesCost + if cost < 0 { + // If arguments and directive weights decrease the field cost, floor it to zero. + cost = 0 + } + // Here we do not follow IBM spec. IBM spec does not use the cost of the object itself + // in multiplication. It assumes that the weight of the type should be just summed up + // without regard to the size of the list. + // + // We, instead, multiply with field cost. + // If there is a weight attached to the type that is returned (resolved) by the field, + // the more objects are requested, the more expensive it should be. + // This, in turn, has some ambiguity for definitions of the weights for the list types. + // "A: [Obj] @cost(weight: 5)" means that the cost of the field is 5 for each object in the list. + // "type Object @cost(weight: 5) { ... }" does exactly the same thing. + // Weight defined on a field has priority over the weight defined on a type. + cost += (fieldCost + childrenCost) * multiplier + + return cost +} + +// costsAndMultiplier fills in the cost values for a node based on its data sources. +// +// For this node we sum weights of the field or its returned type for all the data sources. +// Each data source can have its own cost configuration. If we plan field on two data sources, +// it means more work for the router: we should sum the costs. +// +// fieldCost is the weight of this field or its returned type +// argsCost is the sum of argument weights and input fields used on this field. +// Weights on directives ignored for now. +// For the multiplier we pick the maximum field weight of implementing types and then +// the maximum among slicing arguments. +func (node *CostTreeNode) costsAndMultiplier(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, defaultListSize int) (fieldCost, argsCost, directiveCost, multiplier int) { + if len(node.dataSourceHashes) <= 0 { + // no data source is responsible for this field + return + } + + parent := node.parent + fieldCost = 0 + argsCost = 0 + directiveCost = 0 + multiplier = 0 + + for _, dsHash := range node.dataSourceHashes { + dsCostConfig, ok := configs[dsHash] + if !ok || dsCostConfig == nil { + dsCostConfig = &DataSourceCostConfig{} + // Save it for later use by other fields: + configs[dsHash] = dsCostConfig + } + + fieldWeight := dsCostConfig.Weights[node.fieldCoord] + listSize := dsCostConfig.ListSizes[node.fieldCoord] + // The cost directive is not allowed on fields in an interface. + // The cost of a field on an interface can be calculated based on the costs of + // the corresponding field on each concrete type implementing that interface, + // either directly or indirectly through other interfaces. + if fieldWeight != nil && node.isEnclosingTypeAbstract && parent.returnsAbstractType { + // Composition should not let interface fields have weights, so we assume that + // the enclosing type is concrete. + fmt.Printf("WARNING: cost directive on field %v of interface %v\n", node.fieldCoord, parent.fieldCoord) + } + if node.isEnclosingTypeAbstract && parent.returnsAbstractType { + // This field is part of the enclosing interface/union. + // We look into implementing types and find the max-weighted field. + // Found fieldWeight can be used for all the calculations. + fieldWeight = parent.maxWeightImplementingField(dsCostConfig, node.fieldCoord.FieldName) + // If this field has listSize defined, then do not look into implementing types. + if listSize == nil && node.returnsListType { + listSize = parent.maxMultiplierImplementingField(dsCostConfig, node.fieldCoord.FieldName, node.arguments, variables, defaultListSize) + } + } + + if fieldWeight != nil && fieldWeight.HasWeight { + fieldCost += fieldWeight.Weight + } else { + // Use the weight of the type returned by this field + switch { + case node.returnsSimpleType: + fieldCost += dsCostConfig.EnumScalarTypeWeight(node.fieldTypeName) + case node.returnsAbstractType: + // For the abstract field, find the max weight among all implementing types + maxWeight := 0 + for _, implTypeName := range node.implementingTypeNames { + weight := dsCostConfig.ObjectTypeWeight(implTypeName) + if weight > maxWeight { + maxWeight = weight + } + } + fieldCost += maxWeight + default: + fieldCost += dsCostConfig.ObjectTypeWeight(node.fieldTypeName) + } + } + + for argName, arg := range node.arguments { + if fieldWeight != nil { + if weight, ok := fieldWeight.ArgumentWeights[argName]; ok { + argsCost += weight + continue + } + } + // Take into account the type of the argument. + // If the argument definition itself does not have weight attached, + // but the type of the argument does have weight attached to it. + if arg.isSimple { + argsCost += dsCostConfig.EnumScalarTypeWeight(arg.typeName) + } else if arg.isInputObject { + // TODO: arguments should include costs of input object fields + } else { + argsCost += dsCostConfig.ObjectTypeWeight(arg.typeName) + } + + } + + // Return early, since we do not support sizedFields yet. That parameter means + // that lisSize could be applied to fields that return non-lists. + if !node.returnsListType { + continue + } + + // Compute multiplier as the maximum of data sources. + if listSize != nil { + localMultiplier := listSize.multiplier(node.arguments, variables, defaultListSize) + // If this node returns a list of abstract types, then it could have listSize defined. + // Spec allows defining listSize on the fields of interfaces. + if localMultiplier > multiplier { + multiplier = localMultiplier + } + } + + } + + if multiplier == 0 && node.returnsListType { + multiplier = defaultListSize + } + return +} + +type ArgumentInfo struct { + intValue int + + // The name of an unwrapped type. + typeName string + + // If argument is passed an input object, we want to gather counts + // for all the field coordinates with non-null values used in the argument. + // TBD later when input objects are supported. + // + // For example, for + // "input A { x: Int, rec: A! }" + // following value is passed: + // { x: 1, rec: { x: 2, rec: { x: 3 } } }, + // then coordCounts will be: + // { {"A", "rec"}: 2, {"A", "x"}: 3 } + // + coordCounts map[FieldCoordinate]int + + // isInputObject is true for an input object passed to the argument, + // otherwise the argument is Scalar or Enum. + isInputObject bool + + isSimple bool + + // When the argument points to a variable, it contains the name of the variable. + hasVariable bool + + // The name of the variable that has value for this argument. + varName string +} + +// CostCalculator manages cost calculation during AST traversal +type CostCalculator struct { + // tree points to the root of the complete cost tree. Calculator tree is built once per query, + // then it is cached as part of the plan cache and + // not supposed to change again even during lifetime of a process. + tree *CostTreeNode +} + +// NewCostCalculator creates a new cost calculator. The defaultListSize is floored to 1. +func NewCostCalculator() *CostCalculator { + c := CostCalculator{} + return &c +} + +// GetStaticCost returns the calculated total static cost. +// config should be static per process or instance. variables could change between requests. +func (c *CostCalculator) GetStaticCost(config Configuration, variables *astjson.Value) int { + // costConfigs maps data source hash to its cost configuration. At the runtime we do not change + // this at all. It could be set once per router process. + costConfigs := make(map[DSHash]*DataSourceCostConfig) + for _, ds := range config.DataSources { + if costConfig := ds.GetCostConfig(); costConfig != nil { + costConfigs[ds.Hash()] = costConfig + } + } + defaultListSize := config.StaticCostDefaultListSize + if defaultListSize < 1 { + defaultListSize = 1 + } + return c.tree.staticCost(costConfigs, variables, defaultListSize) + +} + +// DebugPrint prints the cost tree structure for debugging purposes. +// It shows each node's field coordinate, costs, multipliers, and computed totals. +func (c *CostCalculator) DebugPrint(config Configuration, variables *astjson.Value) string { + if c.tree == nil || len(c.tree.children) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString("Cost Tree Debug\n") + sb.WriteString("===============\n") + costConfigs := make(map[DSHash]*DataSourceCostConfig) + for _, ds := range config.DataSources { + if costConfig := ds.GetCostConfig(); costConfig != nil { + costConfigs[ds.Hash()] = costConfig + } + } + defaultListSize := config.StaticCostDefaultListSize + if defaultListSize < 1 { + defaultListSize = 1 + } + c.tree.children[0].debugPrint(&sb, costConfigs, variables, defaultListSize, 0) + return sb.String() +} + +// debugPrint recursively prints a node and its children with indentation. +func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, defaultListSize int, depth int) { + // implementation is a bit crude and redundant, we could skip calculating nodes all over again. + // but it should suffice for debugging tests. + if node == nil { + return + } + + indent := strings.Repeat(" ", depth) + + fieldInfo := fmt.Sprintf("%s.%s", node.fieldCoord.TypeName, node.fieldCoord.FieldName) + + fmt.Fprintf(sb, "%s* %s", indent, fieldInfo) + + if node.fieldTypeName != "" { + fmt.Fprintf(sb, " : %s", node.fieldTypeName) + } + + var flags []string + if node.returnsListType { + flags = append(flags, "list") + } + if node.returnsAbstractType { + flags = append(flags, "abstract") + } + if node.returnsSimpleType { + flags = append(flags, "simple") + } + if len(flags) > 0 { + fmt.Fprintf(sb, " [%s]", strings.Join(flags, ",")) + } + sb.WriteString("\n") + + // Compute costs for this node to display in debug output + fieldCost, argsCost, dirsCost, multiplier := node.costsAndMultiplier(configs, variables, defaultListSize) + if fieldCost != 0 || argsCost != 0 || dirsCost != 0 || multiplier != 0 { + fmt.Fprintf(sb, "%s fieldCost=%d, argsCost=%d, directivesCost=%d, multiplier=%d", + indent, fieldCost, argsCost, dirsCost, multiplier) + + // Show data sources + if len(node.dataSourceHashes) > 0 { + fmt.Fprintf(sb, ", dataSources=%d", len(node.dataSourceHashes)) + } + sb.WriteString("\n") + } + + if len(node.arguments) > 0 { + var argStrs []string + for name, arg := range node.arguments { + if arg.hasVariable { + argStrs = append(argStrs, fmt.Sprintf("%s=$%s", name, arg.varName)) + } else if arg.isSimple { + argStrs = append(argStrs, fmt.Sprintf("%s=%d", name, arg.intValue)) + } else { + argStrs = append(argStrs, fmt.Sprintf("%s=", name)) + } + } + fmt.Fprintf(sb, "%s args: {%s}\n", indent, strings.Join(argStrs, ", ")) + } + + if len(node.implementingTypeNames) > 0 { + fmt.Fprintf(sb, "%s implements: [%s]\n", indent, strings.Join(node.implementingTypeNames, ", ")) + } + + // This is somewhat redundant, but it should not be used in production. + // If there is a need to present cost tree to the user, + // printing should be embedded into the tree calculation process. + subtreeCost := node.staticCost(configs, variables, defaultListSize) + fmt.Fprintf(sb, "%s cost=%d\n", indent, subtreeCost) + + for _, child := range node.children { + child.debugPrint(sb, configs, variables, defaultListSize, depth+1) + } +} diff --git a/v2/pkg/engine/plan/static_cost_visitor.go b/v2/pkg/engine/plan/static_cost_visitor.go new file mode 100644 index 0000000000..a566e6ae93 --- /dev/null +++ b/v2/pkg/engine/plan/static_cost_visitor.go @@ -0,0 +1,213 @@ +package plan + +import ( + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" +) + +// StaticCostVisitor builds the cost tree during AST traversal. +// It is registered on the same walker as the planning Visitor and uses +// data from the planning visitor (fieldPlanners, planners) to determine +// which data sources are responsible for each field. +type StaticCostVisitor struct { + Walker *astvisitor.Walker + + // AST documents - set before walking + Operation *ast.Document + Definition *ast.Document + + // References to planning visitor data - set before walking + planners []PlannerConfiguration + fieldPlanners *map[int][]int // Pointer to Visitor.fieldPlanners + + // Pointer to the main visitor's operationDefinition (set during EnterDocument) + operationDefinition *int + + // stack to keep track of the current node + stack []*CostTreeNode + + // The final cost tree that is built during plan traversal. + tree *CostTreeNode +} + +// NewStaticCostVisitor creates a new cost tree visitor +func NewStaticCostVisitor(walker *astvisitor.Walker, operation, definition *ast.Document) *StaticCostVisitor { + stack := make([]*CostTreeNode, 0, 16) + rootNode := CostTreeNode{ + fieldCoord: FieldCoordinate{"_none", "_root"}, + } + stack = append(stack, &rootNode) + return &StaticCostVisitor{ + Walker: walker, + Operation: operation, + Definition: definition, + stack: stack, + tree: &rootNode, + } +} + +// EnterField creates a partial cost node when entering a field. +// The node is filled in full in the LeaveField when fieldPlanners data is available. +func (v *StaticCostVisitor) EnterField(fieldRef int) { + typeName := v.Walker.EnclosingTypeDefinition.NameString(v.Definition) + fieldName := v.Operation.FieldNameUnsafeString(fieldRef) + + fieldDefinitionRef, ok := v.Walker.FieldDefinition(fieldRef) + if !ok { + // Push the sentinel node, so the LeaveField would pop the stack correctly. + v.stack = append(v.stack, &CostTreeNode{fieldRef: fieldRef}) + return + } + fieldDefinitionTypeRef := v.Definition.FieldDefinitionType(fieldDefinitionRef) + isListType := v.Definition.TypeIsList(fieldDefinitionTypeRef) + isSimpleType := v.Definition.TypeIsEnum(fieldDefinitionTypeRef, v.Definition) || v.Definition.TypeIsScalar(fieldDefinitionTypeRef, v.Definition) + unwrappedTypeName := v.Definition.ResolveTypeNameString(fieldDefinitionTypeRef) + + arguments := v.extractFieldArguments(fieldRef) + + // Check and push through if the unwrapped type of this field is interface or union. + unwrappedTypeNode, exists := v.Definition.NodeByNameStr(unwrappedTypeName) + var implementingTypeNames []string + var isAbstractType bool + if exists { + if unwrappedTypeNode.Kind == ast.NodeKindInterfaceTypeDefinition { + impl, ok := v.Definition.InterfaceTypeDefinitionImplementedByObjectWithNames(unwrappedTypeNode.Ref) + if ok { + implementingTypeNames = append(implementingTypeNames, impl...) + isAbstractType = true + } + } + if unwrappedTypeNode.Kind == ast.NodeKindUnionTypeDefinition { + impl, ok := v.Definition.UnionTypeDefinitionMemberTypeNames(unwrappedTypeNode.Ref) + if ok { + implementingTypeNames = append(implementingTypeNames, impl...) + isAbstractType = true + } + } + } + + isEnclosingTypeAbstract := v.Walker.EnclosingTypeDefinition.Kind.IsAbstractType() + // Create a skeleton node. dataSourceHashes will be filled in leaveFieldCost + node := CostTreeNode{ + fieldRef: fieldRef, + fieldCoord: FieldCoordinate{typeName, fieldName}, + fieldTypeName: unwrappedTypeName, + implementingTypeNames: implementingTypeNames, + returnsListType: isListType, + returnsSimpleType: isSimpleType, + returnsAbstractType: isAbstractType, + isEnclosingTypeAbstract: isEnclosingTypeAbstract, + arguments: arguments, + } + + // Attach to parent + if len(v.stack) > 0 { + parent := v.stack[len(v.stack)-1] + parent.children = append(parent.children, &node) + } + + v.stack = append(v.stack, &node) +} + +// LeaveField fills DataSource hashes for the current node and pop it from the cost stack. +func (v *StaticCostVisitor) LeaveField(fieldRef int) { + dsHashes := v.getFieldDataSourceHashes(fieldRef) + + if len(v.stack) <= 1 { // Keep root on stack + return + } + + // Find the current node (should match fieldRef) + lastIndex := len(v.stack) - 1 + current := v.stack[lastIndex] + if current.fieldRef != fieldRef { + return + } + + current.dataSourceHashes = dsHashes + current.parent = v.stack[lastIndex-1] + + v.stack = v.stack[:lastIndex] +} + +// getFieldDataSourceHashes returns all data source hashes for the field. +// A field can be planned on multiple data sources in federation scenarios. +func (v *StaticCostVisitor) getFieldDataSourceHashes(fieldRef int) []DSHash { + plannerIDs, ok := (*v.fieldPlanners)[fieldRef] + if !ok || len(plannerIDs) == 0 { + return nil + } + + dsHashes := make([]DSHash, 0, len(plannerIDs)) + for _, plannerID := range plannerIDs { + if plannerID >= 0 && plannerID < len(v.planners) { + dsHash := v.planners[plannerID].DataSourceConfiguration().Hash() + dsHashes = append(dsHashes, dsHash) + } + } + return dsHashes +} + +// extractFieldArguments extracts arguments from a field for cost calculation +// This implementation does not go deep for input objects yet. +// It should return unwrapped type names for arguments and that is it for now. +func (v *StaticCostVisitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { + argRefs := v.Operation.FieldArguments(fieldRef) + if len(argRefs) == 0 { + return nil + } + + arguments := make(map[string]ArgumentInfo, len(argRefs)) + for _, argRef := range argRefs { + argName := v.Operation.ArgumentNameString(argRef) + argValue := v.Operation.ArgumentValue(argRef) + argInfo := ArgumentInfo{} + + switch argValue.Kind { + case ast.ValueKindVariable: + variableValue := v.Operation.VariableValueNameString(argValue.Ref) + if !v.Operation.OperationDefinitionHasVariableDefinition(*v.operationDefinition, variableValue) { + continue // omit optional argument when the variable is not defined + } + + // We cannot read values of variables from the context here. Save it for later. + argInfo.hasVariable = true + argInfo.varName = variableValue + + variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(*v.operationDefinition, v.Operation.VariableValueNameBytes(argValue.Ref)) + if !exists { + continue + } + variableTypeRef := v.Operation.VariableDefinitions[variableDefinition].Type + unwrappedVarTypeRef := v.Operation.ResolveUnderlyingType(variableTypeRef) + argInfo.typeName = v.Operation.TypeNameString(unwrappedVarTypeRef) + node, exists := v.Definition.NodeByNameStr(argInfo.typeName) + if !exists { + continue + } + + // fmt.Printf("variableTypeRef = %v unwrappedVarTypeRef = %v typeName = %v nodeKind = %v varVal = %v\n", variableTypeRef, unwrappedVarTypeRef, argInfo.typeName, node.Kind, variableValue) + + // Analyze the node to see what kind of variable was passed. + switch node.Kind { + case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: + argInfo.isSimple = true + case ast.NodeKindInputObjectTypeDefinition: + argInfo.isInputObject = true + + } + + // TODO: we need to analyze variables that contains input object fields. + // If these fields has weight attached, use them for calculation. + // Inline values extracted into variables here, so we need to inspect them via AST. + } + + arguments[argName] = argInfo + } + + return arguments +} + +func (v *StaticCostVisitor) finalCostTree() *CostTreeNode { + return v.tree +} diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 1f8a469d05..34d7014fe3 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -39,7 +39,7 @@ type Visitor struct { response *resolve.GraphQLResponse subscription *resolve.GraphQLSubscription OperationName string - operationDefinition int + operationDefinitionRef int objects []*resolve.Object currentFields []objectFields currentField *resolve.Field @@ -57,9 +57,11 @@ type Visitor struct { pathCache map[astvisitor.VisitorKind]map[int]string // plannerFields maps plannerID to fieldRefs planned on this planner. + // Values added in AllowVisitor callback which is fired before calling LeaveField plannerFields map[int][]int // fieldPlanners maps fieldRef to the plannerIDs where it was planned on. + // Values added in AllowVisitor callback which is fired before calling LeaveField fieldPlanners map[int][]int // fieldEnclosingTypeNames maps fieldRef to the enclosing type name. @@ -134,6 +136,10 @@ func (v *Visitor) AllowVisitor(kind astvisitor.VisitorKind, ref int, visitor any // main planner visitor should always be allowed return true } + if _, isCostVisitor := visitor.(*StaticCostVisitor); isCostVisitor { + // cost tree visitor should always be allowed + return true + } var ( path string isFragmentPath bool @@ -610,24 +616,24 @@ func (v *Visitor) addInterfaceObjectNameToTypeNames(fieldRef int, typeName []byt return onTypeNames } -func (v *Visitor) LeaveField(ref int) { - v.debugOnLeaveNode(ast.NodeKindField, ref) +func (v *Visitor) LeaveField(fieldRef int) { + v.debugOnLeaveNode(ast.NodeKindField, fieldRef) - if v.skipField(ref) { + if v.skipField(fieldRef) { // we should also check skips on field leave // cause on nested keys we could mistakenly remove wrong object // from the stack of the current objects return } - if v.currentFields[len(v.currentFields)-1].popOnField == ref { + if v.currentFields[len(v.currentFields)-1].popOnField == fieldRef { v.currentFields = v.currentFields[:len(v.currentFields)-1] } - fieldDefinition, ok := v.Walker.FieldDefinition(ref) + fieldDefinitionRef, ok := v.Walker.FieldDefinition(fieldRef) if !ok { return } - fieldDefinitionTypeNode := v.Definition.FieldDefinitionTypeNode(fieldDefinition) + fieldDefinitionTypeNode := v.Definition.FieldDefinitionTypeNode(fieldDefinitionRef) switch fieldDefinitionTypeNode.Kind { case ast.NodeKindObjectTypeDefinition, ast.NodeKindInterfaceTypeDefinition, ast.NodeKindUnionTypeDefinition: v.objects = v.objects[:len(v.objects)-1] @@ -969,14 +975,14 @@ func (v *Visitor) valueRequiresExportedVariable(value ast.Value) bool { } } -func (v *Visitor) EnterOperationDefinition(ref int) { - operationName := v.Operation.OperationDefinitionNameString(ref) +func (v *Visitor) EnterOperationDefinition(opRef int) { + operationName := v.Operation.OperationDefinitionNameString(opRef) if v.OperationName != operationName { v.Walker.SkipNode() return } - v.operationDefinition = ref + v.operationDefinitionRef = opRef rootObject := &resolve.Object{ Fields: []*resolve.Field{}, @@ -1168,10 +1174,10 @@ func (v *Visitor) resolveInputTemplates(config *objectFetchConfiguration, input return v.renderJSONValueTemplate(value, variables, inputValueDefinition) } variableValue := v.Operation.VariableValueNameString(value.Ref) - if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { + if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinitionRef, variableValue) { break // omit optional argument when variable is not defined } - variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinition, v.Operation.VariableValueNameBytes(value.Ref)) + variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinitionRef, v.Operation.VariableValueNameBytes(value.Ref)) if !exists { break } diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 474dd66a60..25bfeac79a 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -16,12 +16,19 @@ import ( // Context should not ever be initialized directly, and should be initialized via the NewContext function type Context struct { - ctx context.Context - Variables *astjson.Value + ctx context.Context + + // Variables contains the variables to be used to render values of variables for the subgraph. + // Resolver takes into account RemapVariables for variable names. + Variables *astjson.Value + + // RemapVariables contains a map from new names to old names. When variables are renamed, + // the resolver will use the new name to look up the old name to render the variable in the query. + RemapVariables map[string]string + Files []*httpclient.FileUpload Request Request RenameTypeNames []RenameTypeName - RemapVariables map[string]string TracingOptions TraceOptions RateLimitOptions RateLimitOptions ExecutionOptions ExecutionOptions