From 8cbe5df75b933a079d3bee047f2c84ce90300505 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:55:09 +0200 Subject: [PATCH 01/43] feat: compute static (estimated) cost WIP --- execution/engine/execution_engine_test.go | 14 +- .../engine/plan/datasource_configuration.go | 13 + v2/pkg/engine/plan/planner.go | 11 + v2/pkg/engine/plan/static_cost.go | 501 ++++++++++++++++++ v2/pkg/engine/plan/static_cost_test.go | 319 +++++++++++ v2/pkg/engine/plan/visitor.go | 103 ++++ 6 files changed, 954 insertions(+), 7 deletions(-) create mode 100644 v2/pkg/engine/plan/static_cost.go create mode 100644 v2/pkg/engine/plan/static_cost_test.go diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 35ca1c0963..133115f71a 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -261,13 +261,13 @@ func TestExecutionEngine_Execute(t *testing.T) { engineConf.SetCustomResolveMap(testCase.customResolveMap) engineConf.plannerConfig.Debug = plan.DebugConfiguration{ - // PrintOperationTransformations: true, - // PrintPlanningPaths: true, - // PrintNodeSuggestions: true, - // PrintQueryPlans: true, - // ConfigurationVisitor: true, - // PlanningVisitor: true, - // DatasourceVisitor: true, + PrintOperationTransformations: true, + PrintPlanningPaths: true, + PrintNodeSuggestions: true, + PrintQueryPlans: true, + ConfigurationVisitor: true, + PlanningVisitor: true, + DatasourceVisitor: true, } ctx, cancel := context.WithCancel(context.Background()) diff --git a/v2/pkg/engine/plan/datasource_configuration.go b/v2/pkg/engine/plan/datasource_configuration.go index f196a00bc8..afc922be8d 100644 --- a/v2/pkg/engine/plan/datasource_configuration.go +++ b/v2/pkg/engine/plan/datasource_configuration.go @@ -51,6 +51,9 @@ type DataSourceMetadata struct { Directives *DirectiveConfigurations + // CostConfig holds IBM static cost configuration for this data source + CostConfig *DataSourceCostConfig + rootNodesIndex map[string]fieldsIndex // maps TypeName to fieldsIndex childNodesIndex map[string]fieldsIndex // maps TypeName to fieldsIndex @@ -287,6 +290,9 @@ type DataSource interface { Hash() DSHash FederationConfiguration() FederationMetaData CreatePlannerConfiguration(logger abstractlogger.Logger, fetchConfig *objectFetchConfiguration, pathConfig *plannerPathsConfiguration, configuration *Configuration) PlannerConfiguration + + // GetCostConfig returns the IBM static cost configuration for this data source + GetCostConfig() *DataSourceCostConfig } func (d *dataSourceConfiguration[T]) CustomConfiguration() T { @@ -335,6 +341,13 @@ func (d *dataSourceConfiguration[T]) Hash() DSHash { return d.hash } +func (d *dataSourceConfiguration[T]) GetCostConfig() *DataSourceCostConfig { + if d.DataSourceMetadata == nil { + return nil + } + return d.DataSourceMetadata.CostConfig +} + type DataSourcePlannerConfiguration struct { RequiredFields FederationFieldConfigurations ParentPath string diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index 7b948ab6bd..0cc3e46fdb 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -59,10 +59,21 @@ func NewPlanner(config Configuration) (*Planner, error) { // planning planningWalker := astvisitor.NewWalkerWithID(48, "PlanningWalker") + + // Initialize cost calculator and configure from data sources + costCalc := NewCostCalculator() + for _, ds := range config.DataSources { + if costConfig := ds.GetCostConfig(); costConfig != nil { + costCalc.SetDataSourceCostConfig(ds.Hash(), costConfig) + costCalc.Enable() + } + } + planningVisitor := &Visitor{ Walker: &planningWalker, fieldConfigs: map[int]*FieldConfiguration{}, disableResolveFieldPositions: config.DisableResolveFieldPositions, + costCalculator: costCalc, } p := &Planner{ diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go new file mode 100644 index 0000000000..9c397dde42 --- /dev/null +++ b/v2/pkg/engine/plan/static_cost.go @@ -0,0 +1,501 @@ +package plan + +// StaticCostDefaults contains default cost values when no specific costs are configured +var StaticCostDefaults = CostDefaults{ + FieldCost: 1, + ArgumentCost: 0, + ScalarCost: 0, + EnumCost: 0, + ObjectCost: 1, + ListCost: 10, // The assumed maximum size of a list for fields that return lists. +} + +// CostDefaults defines default cost values for different GraphQL elements +type CostDefaults struct { + FieldCost int + ArgumentCost int + ScalarCost int + EnumCost int + ObjectCost int + ListCost int +} + +// FieldCostConfig defines cost configuration for a specific field +// Includes @listSize directive fields for list cost calculation +type FieldCostConfig struct { + Weight int + + // ArgumentWeights maps argument name to its weight/cost + ArgumentWeights map[string]int + + // AssumedSize is the default assumed size when no slicing argument is provided (from @listSize) + // If 0, the global default list cost is used + AssumedSize int + + // SlicingArguments are argument names that control list size (e.g., "first", "last", "limit") + // The value of these arguments will be used as the multiplier (from @listSize) + SlicingArguments []string + + // SizedFields are field names that return the actual size of the list (from @listSize) + // These can be used for more accurate, actual cost estimation + SizedFields []string + + // RequireOneSlicingArgument if true, at least one slicing argument must be provided (from @listSize) + // If false and no slicing argument is provided, AssumedSize is used + RequireOneSlicingArgument bool +} + +// DataSourceCostConfig holds all cost configurations for a data source +type DataSourceCostConfig struct { + // Defaults overrides the global defaults for this data source + Defaults *CostDefaults + + // FieldConfig maps "TypeName.FieldName" to cost config + FieldConfig map[string]*FieldCostConfig + + // ScalarWeights maps scalar type name to weight + ScalarWeights map[string]int + + // EnumWeights maps enum type name to weight + EnumWeights map[string]int +} + +// NewDataSourceCostConfig creates a new cost config with defaults +func NewDataSourceCostConfig() *DataSourceCostConfig { + return &DataSourceCostConfig{ + FieldConfig: make(map[string]*FieldCostConfig), + ScalarWeights: make(map[string]int), + EnumWeights: make(map[string]int), + } +} + +// GetFieldCost returns the cost for a field, falling back to defaults +func (c *DataSourceCostConfig) GetFieldCost(typeName, fieldName string) int { + if c == nil { + return 0 + } + + key := typeName + "." + fieldName + if fc, ok := c.FieldConfig[key]; ok { + return fc.Weight + } + + if c.Defaults != nil { + return c.Defaults.FieldCost + } + return StaticCostDefaults.FieldCost +} + +// GetSlicingArguments returns the slicing argument names for a field +// These are arguments that control list size (e.g., "first", "last", "limit") +func (c *DataSourceCostConfig) GetSlicingArguments(typeName, fieldName string) []string { + if c == nil { + return nil + } + + key := typeName + "." + fieldName + if fc, ok := c.FieldConfig[key]; ok { + return fc.SlicingArguments + } + return nil +} + +// GetAssumedListSize returns the assumed list size for a field when no slicing argument is provided +func (c *DataSourceCostConfig) GetAssumedListSize(typeName, fieldName string) int { + if c == nil { + return 0 + } + + key := typeName + "." + fieldName + if fc, ok := c.FieldConfig[key]; ok { + return fc.AssumedSize + } + return 0 +} + +// GetArgumentCost returns the cost for an argument, falling back to defaults +func (c *DataSourceCostConfig) GetArgumentCost(typeName, fieldName, argName string) int { + if c == nil { + return 0 + } + + key := typeName + "." + fieldName + if fc, ok := c.FieldConfig[key]; ok && fc.ArgumentWeights != nil { + if weight, ok := fc.ArgumentWeights[argName]; ok { + return weight + } + } + + if c.Defaults != nil { + return c.Defaults.ArgumentCost + } + return StaticCostDefaults.ArgumentCost +} + +// GetScalarCost returns the cost for a scalar type +func (c *DataSourceCostConfig) GetScalarCost(scalarName string) int { + if c == nil { + return 0 + } + + if cost, ok := c.ScalarWeights[scalarName]; ok { + return cost + } + + if c.Defaults != nil { + return c.Defaults.ScalarCost + } + return StaticCostDefaults.ScalarCost +} + +// GetEnumCost returns the cost for an enum type +func (c *DataSourceCostConfig) GetEnumCost(enumName string) int { + if c == nil { + return 0 + } + + if cost, ok := c.EnumWeights[enumName]; ok { + return cost + } + + if c.Defaults != nil { + return c.Defaults.EnumCost + } + return StaticCostDefaults.EnumCost +} + +// GetListCost returns the default list cost +func (c *DataSourceCostConfig) GetListCost() int { + if c == nil { + return 0 + } + + if c.Defaults != nil { + return c.Defaults.ListCost + } + return StaticCostDefaults.ListCost +} + +// GetObjectCost returns the default object cost +func (c *DataSourceCostConfig) GetObjectCost() int { + if c == nil { + return 0 + } + + if c.Defaults != nil { + return c.Defaults.ObjectCost + } + return StaticCostDefaults.ObjectCost +} + +// CostTreeNode represents a node in the cost calculation tree +type CostTreeNode struct { + // FieldRef is the AST field reference + FieldRef int + + // TypeName is the enclosing type name + TypeName string + + // FieldName is the field name + FieldName string + + // DataSourceHashes identifies which data sources this field is resolved from + // A field can be planned on multiple data sources in federation scenarios + DataSourceHashes []DSHash + + // FieldCost is the base cost of this field (aggregated from all data sources) + FieldCost int + + // ArgumentsCost is the total cost of all arguments (aggregated from all data sources) + ArgumentsCost int + + // TypeCost is the cost based on return type (scalar/enum/object) + TypeCost int + + // Multiplier is applied to child costs (e.g., from "first" or "limit" arguments) + Multiplier int + + // Children contains child field costs + Children []*CostTreeNode + + // Parent points to the parent node + Parent *CostTreeNode + + // isListType and arguments are stored temporarily for deferred cost calculation + isListType bool + arguments []CostFieldArgument +} + +// TotalCost calculates the total cost of this node and all descendants +func (n *CostTreeNode) TotalCost() int { + if n == nil { + return 0 + } + + // Base cost for this field + cost := n.FieldCost + n.ArgumentsCost + n.TypeCost + + // Sum children costs + var childrenCost int + for _, child := range n.Children { + childrenCost += child.TotalCost() + } + + // Apply multiplier to children cost + multiplier := n.Multiplier + if multiplier == 0 { + multiplier = 1 + } + cost += childrenCost * multiplier + + return cost +} + +// CostTree represents the complete cost tree for a query +type CostTree struct { + Root *CostTreeNode + Total int +} + +// Calculate computes the total cost and checks against max +func (t *CostTree) Calculate() { + if t.Root != nil { + t.Total = t.Root.TotalCost() + } +} + +// CostCalculator manages cost calculation during AST traversal +type CostCalculator struct { + // stack maintains the current path in the cost tree + stack []*CostTreeNode + + // tree is the complete cost tree being built + tree *CostTree + + // costConfigs maps data source hash to its cost configuration + costConfigs map[DSHash]*DataSourceCostConfig + + // defaultConfig is used when no data source specific config exists + defaultConfig *DataSourceCostConfig + + // enabled controls whether cost calculation is active + enabled bool +} + +// NewCostCalculator creates a new cost calculator +func NewCostCalculator() *CostCalculator { + tree := &CostTree{ + Root: &CostTreeNode{ + FieldName: "_root", + Multiplier: 1, + }, + } + c := CostCalculator{ + stack: make([]*CostTreeNode, 0, 16), + costConfigs: make(map[DSHash]*DataSourceCostConfig), + tree: tree, + enabled: false, + } + c.stack = append(c.stack, c.tree.Root) + + return &c +} + +// Enable activates cost calculation +func (c *CostCalculator) Enable() { + c.enabled = true +} + +// SetDataSourceCostConfig sets the cost config for a specific data source +func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSourceCostConfig) { + c.costConfigs[dsHash] = config +} + +// SetDefaultCostConfig sets the default cost config +func (c *CostCalculator) SetDefaultCostConfig(config *DataSourceCostConfig) { + c.defaultConfig = config +} + +// getCostConfig returns the cost config for a specific data source hash +func (c *CostCalculator) getCostConfig(dsHash DSHash) *DataSourceCostConfig { + if config, ok := c.costConfigs[dsHash]; ok { + return config + } + return c.getDefaultCostConfig() +} + +// getDefaultCostConfig returns the default cost config when no specific data source is available +func (c *CostCalculator) getDefaultCostConfig() *DataSourceCostConfig { + if c.defaultConfig != nil { + return c.defaultConfig + } + // Return a dummy config with defaults + return &DataSourceCostConfig{} +} + +// IsEnabled returns whether cost calculation is enabled +func (c *CostCalculator) IsEnabled() bool { + return c.enabled +} + +// CurrentNode returns the current node on the stack +func (c *CostCalculator) CurrentNode() *CostTreeNode { + if len(c.stack) == 0 { + return nil + } + return c.stack[len(c.stack)-1] +} + +// EnterField is called when entering a field during AST traversal. +// It creates a skeleton node and pushes it onto the stack. +// The actual cost calculation happens in LeaveField when fieldPlanners data is available. +func (c *CostCalculator) EnterField(fieldRef int, typeName, fieldName string, isListType bool, arguments []CostFieldArgument) { + if !c.enabled { + return + } + + // Create skeleton cost node - costs will be calculated in LeaveField + node := &CostTreeNode{ + FieldRef: fieldRef, + TypeName: typeName, + FieldName: fieldName, + Multiplier: 1, + isListType: isListType, + arguments: arguments, + } + + // Attach to parent + parent := c.CurrentNode() + if parent != nil { + node.Parent = parent + parent.Children = append(parent.Children, node) + } + + // Push onto stack + c.stack = append(c.stack, node) +} + +// LeaveField is called when leaving a field during AST traversal. +// This is where we calculate costs because fieldPlanners data is now available. +func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { + if !c.enabled { + return + } + + // Find the current node (should match fieldRef) + if len(c.stack) <= 1 { // Keep root on stack + return + } + + current := c.stack[len(c.stack)-1] + if current.FieldRef != fieldRef { + return + } + + // Now calculate costs with the data source information + current.DataSourceHashes = dsHashes + c.calculateNodeCosts(current) + + // Pop from stack + c.stack = c.stack[:len(c.stack)-1] +} + +// calculateNodeCosts fills in the cost values for a node based on its data sources +func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { + dsHashes := node.DataSourceHashes + typeName := node.TypeName + fieldName := node.FieldName + arguments := node.arguments + isListType := node.isListType + + // Aggregate costs from all data sources this field is planned on + // We sum the costs because each data source will be queried + for _, dsHash := range dsHashes { + config := c.getCostConfig(dsHash) + + node.FieldCost += config.GetFieldCost(typeName, fieldName) + + // Calculate argument costs for this data source + for _, arg := range arguments { + node.ArgumentsCost += config.GetArgumentCost(typeName, fieldName, arg.Name) + } + + // Calculate multiplier from @listSize directive + c.calculateListMultiplier(node, config, typeName, fieldName, arguments) + + // Add list cost if this is a list type (only once, take highest) + if isListType { + listCost := config.GetListCost() + if listCost > node.TypeCost { + node.TypeCost = listCost + } + } + } + + // If no data sources, use default config + if len(dsHashes) == 0 { + config := c.getDefaultCostConfig() + node.FieldCost = config.GetFieldCost(typeName, fieldName) + + for _, arg := range arguments { + node.ArgumentsCost += config.GetArgumentCost(typeName, fieldName, arg.Name) + } + + c.calculateListMultiplier(node, config, typeName, fieldName, arguments) + + if isListType { + node.TypeCost = config.GetListCost() + } + } +} + +// calculateListMultiplier calculates the list multiplier based on @listSize directive +func (c *CostCalculator) calculateListMultiplier(node *CostTreeNode, config *DataSourceCostConfig, typeName, fieldName string, arguments []CostFieldArgument) { + slicingArguments := config.GetSlicingArguments(typeName, fieldName) + assumedSize := config.GetAssumedListSize(typeName, fieldName) + + // If no list size config, nothing to do + if len(slicingArguments) == 0 && assumedSize == 0 { + return + } + + // Check if any slicing argument is provided + slicingArgFound := false + for _, arg := range arguments { + for _, slicingArg := range slicingArguments { + if arg.Name == slicingArg && arg.IntValue > 0 { + // Use the highest multiplier + if arg.IntValue > node.Multiplier { + node.Multiplier = arg.IntValue + } + slicingArgFound = true + } + } + } + + // If no slicing argument found, use assumed size + if !slicingArgFound && assumedSize > 0 { + if assumedSize > node.Multiplier { + node.Multiplier = assumedSize + } + } +} + +// GetTree returns the cost tree +func (c *CostCalculator) GetTree() *CostTree { + c.tree.Calculate() + return c.tree +} + +// GetTotalCost returns the calculated total cost +func (c *CostCalculator) GetTotalCost() int { + c.tree.Calculate() + return c.tree.Total +} + +// CostFieldArgument represents a parsed field argument for cost calculation +type CostFieldArgument struct { + Name string + IntValue int + // Add other value types as needed +} diff --git a/v2/pkg/engine/plan/static_cost_test.go b/v2/pkg/engine/plan/static_cost_test.go new file mode 100644 index 0000000000..17f0184929 --- /dev/null +++ b/v2/pkg/engine/plan/static_cost_test.go @@ -0,0 +1,319 @@ +package plan + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Test DSHash values +const ( + testDSHash1 DSHash = 1001 + testDSHash2 DSHash = 1002 +) + +func TestCostDefaults(t *testing.T) { + // Test that defaults are set correctly + assert.Equal(t, 1, StaticCostDefaults.FieldCost) + assert.Equal(t, 0, StaticCostDefaults.ArgumentCost) + assert.Equal(t, 0, StaticCostDefaults.ScalarCost) + assert.Equal(t, 0, StaticCostDefaults.EnumCost) + assert.Equal(t, 1, StaticCostDefaults.ObjectCost) + assert.Equal(t, 10, StaticCostDefaults.ListCost) +} + +func TestNewDataSourceCostConfig(t *testing.T) { + config := NewDataSourceCostConfig() + + assert.NotNil(t, config.FieldConfig) + assert.NotNil(t, config.ScalarWeights) + assert.NotNil(t, config.EnumWeights) +} + +func TestDataSourceCostConfig_GetFieldCost(t *testing.T) { + config := NewDataSourceCostConfig() + + // Test default cost + cost := config.GetFieldCost("Query", "users") + assert.Equal(t, StaticCostDefaults.FieldCost, cost) + + // Test custom cost + config.FieldConfig["Query.users"] = &FieldCostConfig{ + Weight: 5, + } + cost = config.GetFieldCost("Query", "users") + assert.Equal(t, 5, cost) + + // Test with custom defaults + config.Defaults = &CostDefaults{ + FieldCost: 2, + } + cost = config.GetFieldCost("Query", "posts") + assert.Equal(t, 2, cost) +} + +func TestDataSourceCostConfig_GetSlicingArguments(t *testing.T) { + config := NewDataSourceCostConfig() + + // Test no list size config + args := config.GetSlicingArguments("Query", "users") + assert.Nil(t, args) + + // Test with list size config + config.FieldConfig["Query.users"] = &FieldCostConfig{ + Weight: 1, + AssumedSize: 100, + SlicingArguments: []string{"first", "last"}, + RequireOneSlicingArgument: true, + } + + // Test GetSlicingArguments + args = config.GetSlicingArguments("Query", "users") + assert.Equal(t, []string{"first", "last"}, args) + + // Test GetAssumedListSize + assumed := config.GetAssumedListSize("Query", "users") + assert.Equal(t, 100, assumed) +} + +func TestCostTreeNode_TotalCost(t *testing.T) { + // Build a simple tree: + // root (cost: 1) + // └── users (cost: 1, multiplier: 10 from "first" arg) + // └── name (cost: 1) + // └── email (cost: 1) + + root := &CostTreeNode{ + FieldName: "_root", + Multiplier: 1, + } + + users := &CostTreeNode{ + FieldName: "users", + FieldCost: 1, + Multiplier: 10, // "first: 10" + Parent: root, + } + root.Children = append(root.Children, users) + + name := &CostTreeNode{ + FieldName: "name", + FieldCost: 1, + Parent: users, + } + users.Children = append(users.Children, name) + + email := &CostTreeNode{ + FieldName: "email", + FieldCost: 1, + Parent: users, + } + users.Children = append(users.Children, email) + + // Calculate: root cost = users cost + (children cost * multiplier) + // users: 1 + (1 + 1) * 10 = 1 + 20 = 21 + // root: 0 + 21 * 1 = 21 + total := root.TotalCost() + assert.Equal(t, 21, total) +} + +func TestCostCalculator_BasicFlow(t *testing.T) { + calc := NewCostCalculator() + calc.Enable() + + config := NewDataSourceCostConfig() + config.FieldConfig["Query.users"] = &FieldCostConfig{ + Weight: 2, + SlicingArguments: []string{"first"}, + } + calc.SetDataSourceCostConfig(testDSHash1, config) + + // Simulate entering and leaving fields (two-phase: Enter creates skeleton, Leave calculates costs) + calc.EnterField(1, "Query", "users", true, []CostFieldArgument{ + {Name: "first", IntValue: 10}, + }) + calc.EnterField(2, "User", "name", false, nil) + calc.LeaveField(2, []DSHash{testDSHash1}) + calc.EnterField(3, "User", "email", false, nil) + calc.LeaveField(3, []DSHash{testDSHash1}) + calc.LeaveField(1, []DSHash{testDSHash1}) + + // Get results + tree := calc.GetTree() + assert.NotNil(t, tree) + assert.True(t, tree.Total > 0) + + totalCost := calc.GetTotalCost() + // users: 2 (field) + 10 (list) + (1 + 1) * 10 = 12 + 20 = 32 + assert.Equal(t, 32, totalCost) +} + +func TestCostCalculator_Disabled(t *testing.T) { + calc := NewCostCalculator() + // Don't enable + + calc.EnterField(1, "Query", "users", true, nil) + calc.LeaveField(1, []DSHash{testDSHash1}) + + // Should return 0 when disabled + assert.Equal(t, 0, calc.GetTotalCost()) +} + +func TestCostCalculator_MultipleDataSources(t *testing.T) { + calc := NewCostCalculator() + calc.Enable() + + // Configure two different data sources with different costs + config1 := NewDataSourceCostConfig() + config1.FieldConfig["User.name"] = &FieldCostConfig{ + Weight: 2, + } + calc.SetDataSourceCostConfig(testDSHash1, config1) + + config2 := NewDataSourceCostConfig() + config2.FieldConfig["User.name"] = &FieldCostConfig{ + Weight: 3, + } + calc.SetDataSourceCostConfig(testDSHash2, config2) + + // Field planned on both data sources - costs should be aggregated + calc.EnterField(1, "User", "name", false, nil) + calc.LeaveField(1, []DSHash{testDSHash1, testDSHash2}) + + totalCost := calc.GetTotalCost() + // Weight from subgraph1 (2) + cost from subgraph2 (3) = 5 + assert.Equal(t, 5, totalCost) +} + +func TestCostCalculator_NoDataSource(t *testing.T) { + calc := NewCostCalculator() + calc.Enable() + + // Set default config + defaultConfig := NewDataSourceCostConfig() + defaultConfig.Defaults = &CostDefaults{ + FieldCost: 2, + } + calc.SetDefaultCostConfig(defaultConfig) + + // Field with no data source - should use default config + calc.EnterField(1, "Query", "unknown", false, nil) + calc.LeaveField(1, nil) + + totalCost := calc.GetTotalCost() + assert.Equal(t, 2, totalCost) +} + +func TestCostTree_Calculate(t *testing.T) { + tree := &CostTree{ + Root: &CostTreeNode{ + FieldName: "_root", + Multiplier: 1, + Children: []*CostTreeNode{ + { + FieldName: "field1", + FieldCost: 5, + }, + }, + }, + } + + tree.Calculate() + + assert.Equal(t, 5, tree.Total) +} + +func TestNilCostConfig(t *testing.T) { + var config *DataSourceCostConfig + + // All methods should handle nil gracefully + assert.Equal(t, 0, config.GetFieldCost("Type", "field")) + assert.Equal(t, 0, config.GetArgumentCost("Type", "field", "arg")) + assert.Equal(t, 0, config.GetScalarCost("String")) + assert.Equal(t, 0, config.GetEnumCost("Status")) + assert.Equal(t, 0, config.GetListCost()) + assert.Equal(t, 0, config.GetObjectCost()) + + assert.Nil(t, config.GetSlicingArguments("Type", "field")) + assert.Equal(t, 0, config.GetAssumedListSize("Type", "field")) +} + +func TestCostCalculator_TwoPhaseFlow(t *testing.T) { + // Test that the two-phase flow works correctly: + // EnterField creates skeleton, LeaveField fills in costs + calc := NewCostCalculator() + calc.Enable() + + config := NewDataSourceCostConfig() + config.FieldConfig["Query.users"] = &FieldCostConfig{ + Weight: 5, + } + calc.SetDataSourceCostConfig(testDSHash1, config) + + // Enter creates skeleton node + calc.EnterField(1, "Query", "users", false, nil) + + // At this point, the node exists but has no cost calculated yet + currentNode := calc.CurrentNode() + assert.NotNil(t, currentNode) + assert.Equal(t, "users", currentNode.FieldName) + assert.Equal(t, 0, currentNode.FieldCost) // Weight not yet calculated + + // Leave fills in DS info and calculates cost + calc.LeaveField(1, []DSHash{testDSHash1}) + + // Now the cost should be calculated + totalCost := calc.GetTotalCost() + assert.Equal(t, 5, totalCost) +} + +func TestCostCalculator_ListSizeAssumedSize(t *testing.T) { + // Test that assumed size is used when no slicing argument is provided + calc := NewCostCalculator() + calc.Enable() + + config := NewDataSourceCostConfig() + config.FieldConfig["Query.users"] = &FieldCostConfig{ + Weight: 1, + AssumedSize: 50, // Assume 50 items if no slicing arg + SlicingArguments: []string{"first", "last"}, + } + calc.SetDataSourceCostConfig(testDSHash1, config) + + // Enter field with no slicing arguments + calc.EnterField(1, "Query", "users", true, nil) + + // Enter child field + calc.EnterField(2, "User", "name", false, nil) + calc.LeaveField(2, []DSHash{testDSHash1}) + + calc.LeaveField(1, []DSHash{testDSHash1}) + + // multiplier should be 50 (assumed size) + tree := calc.GetTree() + assert.Equal(t, 50, tree.Root.Children[0].Multiplier) +} + +func TestCostCalculator_ListSizeSlicingArg(t *testing.T) { + // Test that slicing argument overrides assumed size + calc := NewCostCalculator() + calc.Enable() + + config := NewDataSourceCostConfig() + config.FieldConfig["Query.users"] = &FieldCostConfig{ + Weight: 1, + AssumedSize: 50, // This should NOT be used + SlicingArguments: []string{"first", "last"}, + } + calc.SetDataSourceCostConfig(testDSHash1, config) + + // Enter field with "first: 10" argument + calc.EnterField(1, "Query", "users", true, []CostFieldArgument{ + {Name: "first", IntValue: 10}, + }) + calc.LeaveField(1, []DSHash{testDSHash1}) + + // multiplier should be 10 (from slicing arg), not 50 + tree := calc.GetTree() + assert.Equal(t, 10, tree.Root.Children[0].Multiplier) +} diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 1f8a469d05..61aae887d4 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -57,13 +57,18 @@ type Visitor struct { pathCache map[astvisitor.VisitorKind]map[int]string // plannerFields maps plannerID to fieldRefs planned on this planner. + // It is available just before the LeaveField. plannerFields map[int][]int // fieldPlanners maps fieldRef to the plannerIDs where it was planned on. + // It is available just before the LeaveField. fieldPlanners map[int][]int // fieldEnclosingTypeNames maps fieldRef to the enclosing type name. fieldEnclosingTypeNames map[int]string + + // costCalculator calculates IBM static costs during AST traversal + costCalculator *CostCalculator } type indirectInterfaceField struct { @@ -399,6 +404,9 @@ func (v *Visitor) EnterField(ref int) { *v.currentFields[len(v.currentFields)-1].fields = append(*v.currentFields[len(v.currentFields)-1].fields, v.currentField) v.mapFieldConfig(ref) + + // Enter cost calculation for this field (skeleton node, actual costs calculated in LeaveField) + v.enterFieldCost(ref) } func (v *Visitor) mapFieldConfig(ref int) { @@ -411,6 +419,97 @@ func (v *Visitor) mapFieldConfig(ref int) { v.fieldConfigs[ref] = fieldConfig } +// enterFieldCost creates a skeleton cost node when entering a field. +// Actual cost calculation is deferred to leaveFieldCost when fieldPlanners data is available. +func (v *Visitor) enterFieldCost(ref int) { + if v.costCalculator == nil || !v.costCalculator.IsEnabled() { + return + } + + typeName := v.Walker.EnclosingTypeDefinition.NameString(v.Definition) + fieldName := v.Operation.FieldNameUnsafeString(ref) + + // Check if the field returns a list type + fieldDefinition, ok := v.Walker.FieldDefinition(ref) + if !ok { + return + } + fieldDefinitionTypeRef := v.Definition.FieldDefinitionType(fieldDefinition) + isListType := v.Definition.TypeIsList(fieldDefinitionTypeRef) + + // Extract arguments for cost calculation + arguments := v.costFieldArguments(ref) + + // Create skeleton node - dsHashes will be filled in leaveFieldCost + v.costCalculator.EnterField(ref, typeName, fieldName, isListType, arguments) +} + +// leaveFieldCost calculates costs and pops from the cost stack. +// Called in LeaveField because fieldPlanners is populated by AllowVisitor on LeaveField. +func (v *Visitor) leaveFieldCost(ref int) { + if v.costCalculator == nil || !v.costCalculator.IsEnabled() { + return + } + + // Now fieldPlanners is populated, get the data source hashes + dsHashes := v.getFieldDataSourceHashes(ref) + + v.costCalculator.LeaveField(ref, dsHashes) +} + +// getFieldDataSourceHashes returns all data source hashes for the field. +// A field can be planned on multiple data sources in federation scenarios. +func (v *Visitor) getFieldDataSourceHashes(ref int) []DSHash { + plannerIDs, ok := v.fieldPlanners[ref] + if !ok || len(plannerIDs) == 0 { + return nil + } + + dsHashes := make([]DSHash, 0, len(plannerIDs)) + for _, plannerID := range plannerIDs { + if plannerID >= 0 && plannerID < len(v.planners) { + dsHash := v.planners[plannerID].DataSourceConfiguration().Hash() + dsHashes = append(dsHashes, dsHash) + } + } + return dsHashes +} + +// costFieldArguments extracts arguments from a field for cost calculation +func (v *Visitor) costFieldArguments(ref int) []CostFieldArgument { + argRefs := v.Operation.FieldArguments(ref) + if len(argRefs) == 0 { + return nil + } + + arguments := make([]CostFieldArgument, 0, len(argRefs)) + for _, argRef := range argRefs { + argName := v.Operation.ArgumentNameString(argRef) + argValue := v.Operation.ArgumentValue(argRef) + + arg := CostFieldArgument{ + Name: argName, + } + + // Extract integer value if present (for multipliers like "first", "limit") + if argValue.Kind == ast.ValueKindInteger { + arg.IntValue = int(v.Operation.IntValueAsInt(argValue.Ref)) + } + + arguments = append(arguments, arg) + } + + return arguments +} + +// GetTotalCost returns the total calculated cost for the query +func (v *Visitor) GetTotalCost() int { + if v.costCalculator == nil { + return 0 + } + return v.costCalculator.GetTotalCost() +} + func (v *Visitor) resolveFieldInfo(ref, typeRef int, onTypeNames [][]byte) *resolve.FieldInfo { if v.Config.DisableIncludeInfo { return nil @@ -620,6 +719,10 @@ func (v *Visitor) LeaveField(ref int) { return } + // Calculate costs and pop from cost stack + // This is done in LeaveField because fieldPlanners is populated by AllowVisitor on LeaveField + v.leaveFieldCost(ref) + if v.currentFields[len(v.currentFields)-1].popOnField == ref { v.currentFields = v.currentFields[:len(v.currentFields)-1] } From b4d2404847a939970703958a37a40faf993e0665 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:49:35 +0200 Subject: [PATCH 02/43] simplify structures --- v2/pkg/engine/plan/static_cost.go | 166 +++++++++---------------- v2/pkg/engine/plan/static_cost_test.go | 36 ++---- v2/pkg/engine/plan/visitor.go | 12 +- 3 files changed, 73 insertions(+), 141 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 9c397dde42..47331d7ea5 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -47,12 +47,11 @@ type FieldCostConfig struct { // DataSourceCostConfig holds all cost configurations for a data source type DataSourceCostConfig struct { - // Defaults overrides the global defaults for this data source - Defaults *CostDefaults - // FieldConfig maps "TypeName.FieldName" to cost config FieldConfig map[string]*FieldCostConfig + // Object Weights include all object, scalar and enum definitions. + // ScalarWeights maps scalar type name to weight ScalarWeights map[string]int @@ -69,6 +68,15 @@ func NewDataSourceCostConfig() *DataSourceCostConfig { } } +func (c *DataSourceCostConfig) GetFieldCostConfig(typeName, fieldName string) *FieldCostConfig { + if c == nil { + return nil + } + + key := typeName + "." + fieldName + return c.FieldConfig[key] +} + // GetFieldCost returns the cost for a field, falling back to defaults func (c *DataSourceCostConfig) GetFieldCost(typeName, fieldName string) int { if c == nil { @@ -80,9 +88,6 @@ func (c *DataSourceCostConfig) GetFieldCost(typeName, fieldName string) int { return fc.Weight } - if c.Defaults != nil { - return c.Defaults.FieldCost - } return StaticCostDefaults.FieldCost } @@ -126,9 +131,6 @@ func (c *DataSourceCostConfig) GetArgumentCost(typeName, fieldName, argName stri } } - if c.Defaults != nil { - return c.Defaults.ArgumentCost - } return StaticCostDefaults.ArgumentCost } @@ -142,9 +144,6 @@ func (c *DataSourceCostConfig) GetScalarCost(scalarName string) int { return cost } - if c.Defaults != nil { - return c.Defaults.ScalarCost - } return StaticCostDefaults.ScalarCost } @@ -158,37 +157,24 @@ func (c *DataSourceCostConfig) GetEnumCost(enumName string) int { return cost } - if c.Defaults != nil { - return c.Defaults.EnumCost - } return StaticCostDefaults.EnumCost } -// GetListCost returns the default list cost -func (c *DataSourceCostConfig) GetListCost() int { - if c == nil { - return 0 - } - - if c.Defaults != nil { - return c.Defaults.ListCost - } - return StaticCostDefaults.ListCost -} - // GetObjectCost returns the default object cost func (c *DataSourceCostConfig) GetObjectCost() int { if c == nil { return 0 } - if c.Defaults != nil { - return c.Defaults.ObjectCost - } return StaticCostDefaults.ObjectCost } +func (c *DataSourceCostConfig) GetDefaultListCost() int { + return 10 +} + // CostTreeNode represents a node in the cost calculation tree +// Based on IBM GraphQL Cost Specification: https://ibm.github.io/graphql-specs/cost-spec.html type CostTreeNode struct { // FieldRef is the AST field reference FieldRef int @@ -200,48 +186,45 @@ type CostTreeNode struct { FieldName string // DataSourceHashes identifies which data sources this field is resolved from - // A field can be planned on multiple data sources in federation scenarios DataSourceHashes []DSHash - // FieldCost is the base cost of this field (aggregated from all data sources) + // FieldCost is the weight of this field from @cost directive FieldCost int - // ArgumentsCost is the total cost of all arguments (aggregated from all data sources) + // ArgumentsCost is the sum of argument weights and input fields used on each directive ArgumentsCost int - // TypeCost is the cost based on return type (scalar/enum/object) - TypeCost int + DirectivesCost int - // Multiplier is applied to child costs (e.g., from "first" or "limit" arguments) + // Multiplier is the list size multiplier from @listSize directive + // Applied to children costs for list fields Multiplier int // Children contains child field costs Children []*CostTreeNode - // Parent points to the parent node - Parent *CostTreeNode - // isListType and arguments are stored temporarily for deferred cost calculation isListType bool - arguments []CostFieldArgument + arguments map[string]int } // TotalCost calculates the total cost of this node and all descendants +// Per IBM spec: total = field_weight + argument_weights + (children_total * multiplier) func (n *CostTreeNode) TotalCost() int { if n == nil { return 0 } - // Base cost for this field - cost := n.FieldCost + n.ArgumentsCost + n.TypeCost + // TODO: negative sum should be rounded up to zero + cost := n.FieldCost + n.ArgumentsCost + n.DirectivesCost - // Sum children costs + // Sum children (fields) costs var childrenCost int for _, child := range n.Children { childrenCost += child.TotalCost() } - // Apply multiplier to children cost + // Apply multiplier to children cost (for list fields) multiplier := n.Multiplier if multiplier == 0 { multiplier = 1 @@ -349,7 +332,7 @@ func (c *CostCalculator) CurrentNode() *CostTreeNode { // EnterField is called when entering a field during AST traversal. // It creates a skeleton node and pushes it onto the stack. // The actual cost calculation happens in LeaveField when fieldPlanners data is available. -func (c *CostCalculator) EnterField(fieldRef int, typeName, fieldName string, isListType bool, arguments []CostFieldArgument) { +func (c *CostCalculator) EnterField(fieldRef int, typeName, fieldName string, isListType bool, arguments map[string]int) { if !c.enabled { return } @@ -367,7 +350,6 @@ func (c *CostCalculator) EnterField(fieldRef int, typeName, fieldName string, is // Attach to parent parent := c.CurrentNode() if parent != nil { - node.Parent = parent parent.Children = append(parent.Children, node) } @@ -401,86 +383,56 @@ func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { } // calculateNodeCosts fills in the cost values for a node based on its data sources +// calculateNodeCosts implements IBM GraphQL Cost Specification +// See: https://ibm.github.io/graphql-specs/cost-spec.html#sec-Field-Cost func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { - dsHashes := node.DataSourceHashes typeName := node.TypeName fieldName := node.FieldName - arguments := node.arguments - isListType := node.isListType - // Aggregate costs from all data sources this field is planned on - // We sum the costs because each data source will be queried - for _, dsHash := range dsHashes { - config := c.getCostConfig(dsHash) - - node.FieldCost += config.GetFieldCost(typeName, fieldName) - - // Calculate argument costs for this data source - for _, arg := range arguments { - node.ArgumentsCost += config.GetArgumentCost(typeName, fieldName, arg.Name) - } - - // Calculate multiplier from @listSize directive - c.calculateListMultiplier(node, config, typeName, fieldName, arguments) - - // Add list cost if this is a list type (only once, take highest) - if isListType { - listCost := config.GetListCost() - if listCost > node.TypeCost { - node.TypeCost = listCost - } - } + // Get the cost config (use first data source config, or default) + var config *DataSourceCostConfig + if len(node.DataSourceHashes) > 0 { + config = c.getCostConfig(node.DataSourceHashes[0]) + } else { + config = c.getDefaultCostConfig() } - // If no data sources, use default config - if len(dsHashes) == 0 { - config := c.getDefaultCostConfig() - node.FieldCost = config.GetFieldCost(typeName, fieldName) + node.FieldCost = config.GetFieldCost(typeName, fieldName) - for _, arg := range arguments { - node.ArgumentsCost += config.GetArgumentCost(typeName, fieldName, arg.Name) - } - - c.calculateListMultiplier(node, config, typeName, fieldName, arguments) - - if isListType { - node.TypeCost = config.GetListCost() - } + for argName := range node.arguments { + node.ArgumentsCost += config.GetArgumentCost(typeName, fieldName, argName) + // TODO: arguments should include costs of input object fields } -} -// calculateListMultiplier calculates the list multiplier based on @listSize directive -func (c *CostCalculator) calculateListMultiplier(node *CostTreeNode, config *DataSourceCostConfig, typeName, fieldName string, arguments []CostFieldArgument) { - slicingArguments := config.GetSlicingArguments(typeName, fieldName) - assumedSize := config.GetAssumedListSize(typeName, fieldName) + // TODO: Directives Cost should includes the weights of all its arguments + + // TODO: arguments, directives and fields of input object are mutually recursive, + // we should recurse on them and sum all of possible values. - // If no list size config, nothing to do - if len(slicingArguments) == 0 && assumedSize == 0 { + // Compute multiplier + if !node.isListType { + node.Multiplier = 1 return } - // Check if any slicing argument is provided - slicingArgFound := false - for _, arg := range arguments { - for _, slicingArg := range slicingArguments { - if arg.Name == slicingArg && arg.IntValue > 0 { - // Use the highest multiplier - if arg.IntValue > node.Multiplier { - node.Multiplier = arg.IntValue - } - slicingArgFound = true + fieldCostConfig := config.GetFieldCostConfig(typeName, fieldName) + node.Multiplier = 0 + for _, slicingArg := range fieldCostConfig.SlicingArguments { + if argValue, ok := node.arguments[slicingArg]; ok && argValue > 0 { + if argValue > node.Multiplier { + node.Multiplier = argValue } } } - - // If no slicing argument found, use assumed size - if !slicingArgFound && assumedSize > 0 { - if assumedSize > node.Multiplier { - node.Multiplier = assumedSize - } + if node.Multiplier == 0 && fieldCostConfig.AssumedSize > 0 { + node.Multiplier = fieldCostConfig.AssumedSize + return } + node.Multiplier = config.GetDefaultListCost() + } + // GetTree returns the cost tree func (c *CostCalculator) GetTree() *CostTree { c.tree.Calculate() diff --git a/v2/pkg/engine/plan/static_cost_test.go b/v2/pkg/engine/plan/static_cost_test.go index 17f0184929..f77d4b7c8c 100644 --- a/v2/pkg/engine/plan/static_cost_test.go +++ b/v2/pkg/engine/plan/static_cost_test.go @@ -44,12 +44,9 @@ func TestDataSourceCostConfig_GetFieldCost(t *testing.T) { cost = config.GetFieldCost("Query", "users") assert.Equal(t, 5, cost) - // Test with custom defaults - config.Defaults = &CostDefaults{ - FieldCost: 2, - } + // Test with defaults cost = config.GetFieldCost("Query", "posts") - assert.Equal(t, 2, cost) + assert.Equal(t, 1, cost) } func TestDataSourceCostConfig_GetSlicingArguments(t *testing.T) { @@ -92,21 +89,18 @@ func TestCostTreeNode_TotalCost(t *testing.T) { FieldName: "users", FieldCost: 1, Multiplier: 10, // "first: 10" - Parent: root, } root.Children = append(root.Children, users) name := &CostTreeNode{ FieldName: "name", FieldCost: 1, - Parent: users, } users.Children = append(users.Children, name) email := &CostTreeNode{ FieldName: "email", FieldCost: 1, - Parent: users, } users.Children = append(users.Children, email) @@ -129,9 +123,7 @@ func TestCostCalculator_BasicFlow(t *testing.T) { calc.SetDataSourceCostConfig(testDSHash1, config) // Simulate entering and leaving fields (two-phase: Enter creates skeleton, Leave calculates costs) - calc.EnterField(1, "Query", "users", true, []CostFieldArgument{ - {Name: "first", IntValue: 10}, - }) + calc.EnterField(1, "Query", "users", true, map[string]int{"first":10}) calc.EnterField(2, "User", "name", false, nil) calc.LeaveField(2, []DSHash{testDSHash1}) calc.EnterField(3, "User", "email", false, nil) @@ -144,8 +136,8 @@ func TestCostCalculator_BasicFlow(t *testing.T) { assert.True(t, tree.Total > 0) totalCost := calc.GetTotalCost() - // users: 2 (field) + 10 (list) + (1 + 1) * 10 = 12 + 20 = 32 - assert.Equal(t, 32, totalCost) + // Per IBM spec: users weight=2 + (name(1) + email(1)) * 10 = 2 + 20 = 22 + assert.Equal(t, 22, totalCost) } func TestCostCalculator_Disabled(t *testing.T) { @@ -163,7 +155,7 @@ func TestCostCalculator_MultipleDataSources(t *testing.T) { calc := NewCostCalculator() calc.Enable() - // Configure two different data sources with different costs + // Configure two different data sources with different weights config1 := NewDataSourceCostConfig() config1.FieldConfig["User.name"] = &FieldCostConfig{ Weight: 2, @@ -176,13 +168,13 @@ func TestCostCalculator_MultipleDataSources(t *testing.T) { } calc.SetDataSourceCostConfig(testDSHash2, config2) - // Field planned on both data sources - costs should be aggregated + // Field planned on multiple data sources - per IBM spec, use first data source's weight calc.EnterField(1, "User", "name", false, nil) calc.LeaveField(1, []DSHash{testDSHash1, testDSHash2}) totalCost := calc.GetTotalCost() - // Weight from subgraph1 (2) + cost from subgraph2 (3) = 5 - assert.Equal(t, 5, totalCost) + // Per IBM spec: field is resolved once, using first data source weight = 2 + assert.Equal(t, 2, totalCost) } func TestCostCalculator_NoDataSource(t *testing.T) { @@ -191,9 +183,6 @@ func TestCostCalculator_NoDataSource(t *testing.T) { // Set default config defaultConfig := NewDataSourceCostConfig() - defaultConfig.Defaults = &CostDefaults{ - FieldCost: 2, - } calc.SetDefaultCostConfig(defaultConfig) // Field with no data source - should use default config @@ -201,7 +190,7 @@ func TestCostCalculator_NoDataSource(t *testing.T) { calc.LeaveField(1, nil) totalCost := calc.GetTotalCost() - assert.Equal(t, 2, totalCost) + assert.Equal(t, 1, totalCost) } func TestCostTree_Calculate(t *testing.T) { @@ -231,7 +220,6 @@ func TestNilCostConfig(t *testing.T) { assert.Equal(t, 0, config.GetArgumentCost("Type", "field", "arg")) assert.Equal(t, 0, config.GetScalarCost("String")) assert.Equal(t, 0, config.GetEnumCost("Status")) - assert.Equal(t, 0, config.GetListCost()) assert.Equal(t, 0, config.GetObjectCost()) assert.Nil(t, config.GetSlicingArguments("Type", "field")) @@ -308,9 +296,7 @@ func TestCostCalculator_ListSizeSlicingArg(t *testing.T) { calc.SetDataSourceCostConfig(testDSHash1, config) // Enter field with "first: 10" argument - calc.EnterField(1, "Query", "users", true, []CostFieldArgument{ - {Name: "first", IntValue: 10}, - }) + calc.EnterField(1, "Query", "users", true, map[string]int{"first":10}) calc.LeaveField(1, []DSHash{testDSHash1}) // multiplier should be 10 (from slicing arg), not 50 diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 61aae887d4..d3f7736fc7 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -476,27 +476,21 @@ func (v *Visitor) getFieldDataSourceHashes(ref int) []DSHash { } // costFieldArguments extracts arguments from a field for cost calculation -func (v *Visitor) costFieldArguments(ref int) []CostFieldArgument { +func (v *Visitor) costFieldArguments(ref int) map[string]int { argRefs := v.Operation.FieldArguments(ref) if len(argRefs) == 0 { return nil } - arguments := make([]CostFieldArgument, 0, len(argRefs)) + arguments := make(map[string]int, len(argRefs)) for _, argRef := range argRefs { argName := v.Operation.ArgumentNameString(argRef) argValue := v.Operation.ArgumentValue(argRef) - arg := CostFieldArgument{ - Name: argName, - } - // Extract integer value if present (for multipliers like "first", "limit") if argValue.Kind == ast.ValueKindInteger { - arg.IntValue = int(v.Operation.IntValueAsInt(argValue.Ref)) + arguments[argName] = int(v.Operation.IntValueAsInt(argValue.Ref)) } - - arguments = append(arguments, arg) } return arguments From b84582a588450c473fbd7e334e41153ac3ef6512 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Thu, 18 Dec 2025 18:00:06 +0200 Subject: [PATCH 03/43] remove leaveCost --- execution/engine/execution_engine_test.go | 4 +- v2/pkg/engine/plan/planner.go | 1 + v2/pkg/engine/plan/static_cost.go | 4 ++ v2/pkg/engine/plan/static_cost_test.go | 1 + v2/pkg/engine/plan/visitor.go | 45 ++++++++++++++++------- 5 files changed, 39 insertions(+), 16 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 133115f71a..8ba95ac7ce 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -263,11 +263,11 @@ func TestExecutionEngine_Execute(t *testing.T) { engineConf.plannerConfig.Debug = plan.DebugConfiguration{ PrintOperationTransformations: true, PrintPlanningPaths: true, - PrintNodeSuggestions: true, + // PrintNodeSuggestions: true, PrintQueryPlans: true, ConfigurationVisitor: true, PlanningVisitor: true, - DatasourceVisitor: true, + // DatasourceVisitor: true, } ctx, cancel := context.WithCancel(context.Background()) diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index 0cc3e46fdb..08624e8084 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -62,6 +62,7 @@ func NewPlanner(config Configuration) (*Planner, error) { // Initialize cost calculator and configure from data sources costCalc := NewCostCalculator() + costCalc.Enable() for _, ds := range config.DataSources { if costConfig := ds.GetCostConfig(); costConfig != nil { costCalc.SetDataSourceCostConfig(ds.Hash(), costConfig) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 47331d7ea5..d3d9aa8ce3 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -416,6 +416,10 @@ func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { } fieldCostConfig := config.GetFieldCostConfig(typeName, fieldName) + if fieldCostConfig == nil { + node.Multiplier = 1 + return + } node.Multiplier = 0 for _, slicingArg := range fieldCostConfig.SlicingArguments { if argValue, ok := node.arguments[slicingArg]; ok && argValue > 0 { diff --git a/v2/pkg/engine/plan/static_cost_test.go b/v2/pkg/engine/plan/static_cost_test.go index f77d4b7c8c..b521401316 100644 --- a/v2/pkg/engine/plan/static_cost_test.go +++ b/v2/pkg/engine/plan/static_cost_test.go @@ -280,6 +280,7 @@ func TestCostCalculator_ListSizeAssumedSize(t *testing.T) { // multiplier should be 50 (assumed size) tree := calc.GetTree() assert.Equal(t, 50, tree.Root.Children[0].Multiplier) + assert.Equal(t, 51, calc.GetTotalCost()) } func TestCostCalculator_ListSizeSlicingArg(t *testing.T) { diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index d3f7736fc7..76480c1aed 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -440,23 +440,12 @@ func (v *Visitor) enterFieldCost(ref int) { // Extract arguments for cost calculation arguments := v.costFieldArguments(ref) + // directives := v.costFieldDirectives(ref) + // Create skeleton node - dsHashes will be filled in leaveFieldCost v.costCalculator.EnterField(ref, typeName, fieldName, isListType, arguments) } -// leaveFieldCost calculates costs and pops from the cost stack. -// Called in LeaveField because fieldPlanners is populated by AllowVisitor on LeaveField. -func (v *Visitor) leaveFieldCost(ref int) { - if v.costCalculator == nil || !v.costCalculator.IsEnabled() { - return - } - - // Now fieldPlanners is populated, get the data source hashes - dsHashes := v.getFieldDataSourceHashes(ref) - - v.costCalculator.LeaveField(ref, dsHashes) -} - // getFieldDataSourceHashes returns all data source hashes for the field. // A field can be planned on multiple data sources in federation scenarios. func (v *Visitor) getFieldDataSourceHashes(ref int) []DSHash { @@ -475,6 +464,7 @@ func (v *Visitor) getFieldDataSourceHashes(ref int) []DSHash { return dsHashes } +// costFieldArguments extracts arguments from a field for cost calculation // costFieldArguments extracts arguments from a field for cost calculation func (v *Visitor) costFieldArguments(ref int) map[string]int { argRefs := v.Operation.FieldArguments(ref) @@ -487,15 +477,39 @@ func (v *Visitor) costFieldArguments(ref int) map[string]int { argName := v.Operation.ArgumentNameString(argRef) argValue := v.Operation.ArgumentValue(argRef) + fmt.Printf("costFieldArguments: argName=%s, argValue=%v\n", argName, argValue) // Extract integer value if present (for multipliers like "first", "limit") if argValue.Kind == ast.ValueKindInteger { arguments[argName] = int(v.Operation.IntValueAsInt(argValue.Ref)) } + if argValue.Kind == ast.ValueKindVariable { + } } return arguments } +// func (v *Visitor) costFieldDirectives(ref int) map[string]int { +// refs := v.Operation.FieldDirectives(ref) +// if len(refs) == 0 { +// return nil +// } +// +// arguments := make(map[string]int, len(refs)) +// for _, dirRef := range refs { +// dirName := v.Operation.DirectiveName(dirRef) +// dirArgsRef := v.Operation.DirectiveArgumentSet(dirRef) +// +// fmt.Printf("costFieldDirectives: dirName=%s, dirArgsRef=%v\n", dirName, dirArgsRef) +// // Extract integer value if present (for multipliers like "first", "limit") +// if dirArgsRef.Kind == ast.ValueKindInteger { +// arguments[dirName] = int(v.Operation.IntValueAsInt(dirArgsRef.Ref)) +// } +// } +// +// return arguments +// } + // GetTotalCost returns the total calculated cost for the query func (v *Visitor) GetTotalCost() int { if v.costCalculator == nil { @@ -715,7 +729,10 @@ func (v *Visitor) LeaveField(ref int) { // Calculate costs and pop from cost stack // This is done in LeaveField because fieldPlanners is populated by AllowVisitor on LeaveField - v.leaveFieldCost(ref) + if v.costCalculator != nil && v.costCalculator.IsEnabled() { + dsHashes := v.getFieldDataSourceHashes(ref) + v.costCalculator.LeaveField(ref, dsHashes) + } if v.currentFields[len(v.currentFields)-1].popOnField == ref { v.currentFields = v.currentFields[:len(v.currentFields)-1] From 41b4baac973f0eb4bf6f28119580af65e746f74d Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 23 Dec 2025 17:54:00 +0200 Subject: [PATCH 04/43] start refactoring the code --- v2/pkg/ast/ast_type.go | 2 + v2/pkg/engine/plan/node_selection_visitor.go | 2 +- v2/pkg/engine/plan/path_builder_visitor.go | 2 +- v2/pkg/engine/plan/static_cost.go | 291 +++++++++---------- v2/pkg/engine/plan/static_cost_test.go | 41 ++- v2/pkg/engine/plan/visitor.go | 55 +++- 6 files changed, 201 insertions(+), 192 deletions(-) diff --git a/v2/pkg/ast/ast_type.go b/v2/pkg/ast/ast_type.go index 6c22f33fe3..92aff65a7b 100644 --- a/v2/pkg/ast/ast_type.go +++ b/v2/pkg/ast/ast_type.go @@ -250,6 +250,7 @@ func (d *Document) ResolveTypeNameString(ref int) string { return unsafebytes.BytesToString(d.ResolveTypeNameBytes(ref)) } +// ResolveUnderlyingType unwraps the ref type until it finds the named type. func (d *Document) ResolveUnderlyingType(ref int) (typeRef int) { typeRef = ref graphqlType := d.Types[ref] @@ -261,6 +262,7 @@ func (d *Document) ResolveUnderlyingType(ref int) (typeRef int) { return } +// ResolveListOrNameType unwraps the ref type until it finds the named or list type. func (d *Document) ResolveListOrNameType(ref int) (typeRef int) { typeRef = ref graphqlType := d.Types[ref] diff --git a/v2/pkg/engine/plan/node_selection_visitor.go b/v2/pkg/engine/plan/node_selection_visitor.go index 8f9f80fb48..bbcffd3c43 100644 --- a/v2/pkg/engine/plan/node_selection_visitor.go +++ b/v2/pkg/engine/plan/node_selection_visitor.go @@ -34,7 +34,7 @@ type nodeSelectionVisitor struct { visitedFieldsRequiresChecks map[fieldIndexKey]struct{} // visitedFieldsRequiresChecks is a map[fieldIndexKey] of already processed fields which we check for presence of @requires directive visitedFieldsKeyChecks map[fieldIndexKey]struct{} // visitedFieldsKeyChecks is a map[fieldIndexKey] of already processed fields which we check for @key requirements - visitedFieldsAbstractChecks map[int]struct{} // visitedFieldsAbstractChecks is a map[FieldRef] of already processed fields which we check for abstract type, e.g. union or interface + visitedFieldsAbstractChecks map[int]struct{} // visitedFieldsAbstractChecks is a map[fieldRef] of already processed fields which we check for abstract type, e.g. union or interface fieldDependsOn map[fieldIndexKey][]int // fieldDependsOn is a map[fieldIndexKey][]fieldRef - holds list of field refs which are required by a field ref, e.g. field should be planned only after required fields were planned fieldRefDependsOn map[int][]int // fieldRefDependsOn is a map[fieldRef][]fieldRef - holds list of field refs which are required by a field ref, it is a second index without datasource hash fieldRequirementsConfigs map[fieldIndexKey][]FederationFieldConfiguration // fieldRequirementsConfigs is a map[fieldIndexKey]FederationFieldConfiguration - holds a list of required configuratuibs for a field ref to later built representation variables diff --git a/v2/pkg/engine/plan/path_builder_visitor.go b/v2/pkg/engine/plan/path_builder_visitor.go index 50f74d0ae3..db26ea67ca 100644 --- a/v2/pkg/engine/plan/path_builder_visitor.go +++ b/v2/pkg/engine/plan/path_builder_visitor.go @@ -43,7 +43,7 @@ type pathBuilderVisitor struct { addedPathTracker []pathConfiguration // addedPathTracker is a list of paths which were added addedPathTrackerIndex map[string][]int // addedPathTrackerIndex is a map of path to index in addedPathTracker - fieldDependenciesForPlanners map[int][]int // fieldDependenciesForPlanners is a map[FieldRef][]plannerIdx holds list of planner ids which depends on a field ref. Used for @key dependencies + fieldDependenciesForPlanners map[int][]int // fieldDependenciesForPlanners is a map[fieldRef][]plannerIdx holds list of planner ids which depends on a field ref. Used for @key dependencies fieldsPlannedOn map[int][]int // fieldsPlannedOn is a map[fieldRef][]plannerIdx holds list of planner ids which planned a field ref secondaryRun bool // secondaryRun is a flag to indicate that we're running the pathBuilderVisitor not the first time diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index d3d9aa8ce3..9e9e6c362e 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -1,70 +1,77 @@ package plan // StaticCostDefaults contains default cost values when no specific costs are configured -var StaticCostDefaults = CostDefaults{ - FieldCost: 1, - ArgumentCost: 0, - ScalarCost: 0, - EnumCost: 0, - ObjectCost: 1, - ListCost: 10, // The assumed maximum size of a list for fields that return lists. +var StaticCostDefaults = WeightDefaults{ + Field: 1, + Scalar: 0, + Enum: 0, + Object: 1, + List: 10, // The assumed maximum size of a list for fields that return lists. } -// CostDefaults defines default cost values for different GraphQL elements -type CostDefaults struct { - FieldCost int - ArgumentCost int - ScalarCost int - EnumCost int - ObjectCost int - ListCost int +// WeightDefaults defines default cost values for different GraphQL elements +type WeightDefaults struct { + Field int + Scalar int + Enum int + Object int + List int } -// FieldCostConfig defines cost configuration for a specific field -// Includes @listSize directive fields for list cost calculation +// FieldCostConfig defines cost configuration for a specific field of an object or input object. +// Includes @listSize directive fields for objects. type FieldCostConfig struct { Weight int - // ArgumentWeights maps argument name to its weight/cost + // ArgumentWeights maps an argument name to its weight. + // Location: ARGUMENT_DEFINITION ArgumentWeights map[string]int - // AssumedSize is the default assumed size when no slicing argument is provided (from @listSize) - // If 0, the global default list cost is used + // Fields below are defined only on FIELD_DEFINITION from the @listSize directive. + + // AssumedSize is the default assumed size when no slicing argument is provided. + // If 0, the global default list cost is used. AssumedSize int // SlicingArguments are argument names that control list size (e.g., "first", "last", "limit") - // The value of these arguments will be used as the multiplier (from @listSize) + // The value of these arguments will be used as the multiplier. SlicingArguments []string - // SizedFields are field names that return the actual size of the list (from @listSize) - // These can be used for more accurate, actual cost estimation + // SizedFields are field names that return the actual size of the list. + // These can be used for more accurate, actual cost estimation. SizedFields []string - // RequireOneSlicingArgument if true, at least one slicing argument must be provided (from @listSize) - // If false and no slicing argument is provided, AssumedSize is used + // RequireOneSlicingArgument if true, at least one slicing argument must be provided. + // If false and no slicing argument is provided, AssumedSize is used. RequireOneSlicingArgument bool } -// DataSourceCostConfig holds all cost configurations for a data source +// DataSourceCostConfig holds all cost configurations for a data source. +// This data is passed from the composition. type DataSourceCostConfig struct { - // FieldConfig maps "TypeName.FieldName" to cost config - FieldConfig map[string]*FieldCostConfig - - // Object Weights include all object, scalar and enum definitions. + // Fields maps field coordinate to its cost config. + // Location: FIELD_DEFINITION, INPUT_FIELD_DEFINITION + Fields map[FieldCoordinate]*FieldCostConfig - // ScalarWeights maps scalar type name to weight - ScalarWeights map[string]int + // Types maps TypeName to the weight of the object, scalar or enum definition. + // If TypeName is not present, the default value for Enums and Scalars is 0, otherwise 1. + // Weight assigned to the field or argument definitions overrides the weight of type definition. + // Location: ENUM, OBJECT, SCALAR + Types map[string]int - // EnumWeights maps enum type name to weight - EnumWeights map[string]int + // Arguments on directives is a special case. They use a special kind of coordinate: + // directive name + argument name. That should be the key mapped to the weight. + // + // Directives can be used on [input] object fields and arguments of fields. This creates + // mutual recursion between them that complicates cost calculation. + // We avoid them intentionally in the first iteration. } // NewDataSourceCostConfig creates a new cost config with defaults func NewDataSourceCostConfig() *DataSourceCostConfig { return &DataSourceCostConfig{ - FieldConfig: make(map[string]*FieldCostConfig), - ScalarWeights: make(map[string]int), - EnumWeights: make(map[string]int), + Fields: make(map[FieldCoordinate]*FieldCostConfig), + Types: make(map[string]int), } } @@ -72,101 +79,40 @@ func (c *DataSourceCostConfig) GetFieldCostConfig(typeName, fieldName string) *F if c == nil { return nil } - - key := typeName + "." + fieldName - return c.FieldConfig[key] -} - -// GetFieldCost returns the cost for a field, falling back to defaults -func (c *DataSourceCostConfig) GetFieldCost(typeName, fieldName string) int { - if c == nil { - return 0 - } - - key := typeName + "." + fieldName - if fc, ok := c.FieldConfig[key]; ok { - return fc.Weight - } - - return StaticCostDefaults.FieldCost -} - -// GetSlicingArguments returns the slicing argument names for a field -// These are arguments that control list size (e.g., "first", "last", "limit") -func (c *DataSourceCostConfig) GetSlicingArguments(typeName, fieldName string) []string { - if c == nil { - return nil - } - - key := typeName + "." + fieldName - if fc, ok := c.FieldConfig[key]; ok { - return fc.SlicingArguments - } - return nil -} - -// GetAssumedListSize returns the assumed list size for a field when no slicing argument is provided -func (c *DataSourceCostConfig) GetAssumedListSize(typeName, fieldName string) int { - if c == nil { - return 0 - } - - key := typeName + "." + fieldName - if fc, ok := c.FieldConfig[key]; ok { - return fc.AssumedSize - } - return 0 + return c.Fields[FieldCoordinate{typeName, fieldName}] } -// GetArgumentCost returns the cost for an argument, falling back to defaults -func (c *DataSourceCostConfig) GetArgumentCost(typeName, fieldName, argName string) int { +// ScalarWeight returns the cost for a scalar type +func (c *DataSourceCostConfig) ScalarWeight(scalarName string) int { if c == nil { return 0 } - - key := typeName + "." + fieldName - if fc, ok := c.FieldConfig[key]; ok && fc.ArgumentWeights != nil { - if weight, ok := fc.ArgumentWeights[argName]; ok { - return weight - } + if cost, ok := c.Types[scalarName]; ok { + return cost } - - return StaticCostDefaults.ArgumentCost + return StaticCostDefaults.Scalar } -// GetScalarCost returns the cost for a scalar type -func (c *DataSourceCostConfig) GetScalarCost(scalarName string) int { +// EnumWeight returns the cost for an enum type +func (c *DataSourceCostConfig) EnumWeight(enumName string) int { if c == nil { return 0 } - - if cost, ok := c.ScalarWeights[scalarName]; ok { + if cost, ok := c.Types[enumName]; ok { return cost } - - return StaticCostDefaults.ScalarCost + return StaticCostDefaults.Enum } -// GetEnumCost returns the cost for an enum type -func (c *DataSourceCostConfig) GetEnumCost(enumName string) int { +// ObjectWeight returns the default object cost +func (c *DataSourceCostConfig) ObjectWeight(name string) int { if c == nil { return 0 } - - if cost, ok := c.EnumWeights[enumName]; ok { + if cost, ok := c.Types[name]; ok { return cost } - - return StaticCostDefaults.EnumCost -} - -// GetObjectCost returns the default object cost -func (c *DataSourceCostConfig) GetObjectCost() int { - if c == nil { - return 0 - } - - return StaticCostDefaults.ObjectCost + return StaticCostDefaults.Object } func (c *DataSourceCostConfig) GetDefaultListCost() int { @@ -176,17 +122,14 @@ func (c *DataSourceCostConfig) GetDefaultListCost() int { // CostTreeNode represents a node in the cost calculation tree // Based on IBM GraphQL Cost Specification: https://ibm.github.io/graphql-specs/cost-spec.html type CostTreeNode struct { - // FieldRef is the AST field reference - FieldRef int + // fieldRef is the AST field reference + fieldRef int - // TypeName is the enclosing type name - TypeName string + // Enclosing type name and field name + fieldCoord FieldCoordinate - // FieldName is the field name - FieldName string - - // DataSourceHashes identifies which data sources this field is resolved from - DataSourceHashes []DSHash + // dataSourceHashes identifies which data sources this field is resolved from + dataSourceHashes []DSHash // FieldCost is the weight of this field from @cost directive FieldCost int @@ -200,12 +143,42 @@ type CostTreeNode struct { // Applied to children costs for list fields Multiplier int - // Children contains child field costs + // Children contain child field costs Children []*CostTreeNode - // isListType and arguments are stored temporarily for deferred cost calculation + // The data below is stored for deferred cost calculation. + + // What is the name of an unwrapped (named) type that is returned by this field? + fieldTypeName string + isListType bool - arguments map[string]int + + // arguments contain the values of arguments passed to the field + arguments map[string]ArgumentInfo +} + +type ArgumentInfo struct { + intValue int + + // The name of an unwrapped type. + typeName string + + // isInputObject is true for an input object passed to the argument, + // otherwise the argument is Scalar or Enum. + isInputObject bool + + // If argument is passed an input object, we want to gather counts + // for all the field coordinates with non-null values used in the argument. + // + // For example, for + // "input A { x: Int, rec: A! }" + // following value is passed: + // { x: 1, rec: { x: 2, rec: { x: 3 } } }, + // then coordCounts will be: + // { {"A", "rec"}: 2, {"A", "x"}: 3 } + // + coordCounts map[FieldCoordinate]int + isScalar bool } // TotalCost calculates the total cost of this node and all descendants @@ -269,7 +242,7 @@ type CostCalculator struct { func NewCostCalculator() *CostCalculator { tree := &CostTree{ Root: &CostTreeNode{ - FieldName: "_root", + fieldCoord: FieldCoordinate{"_none", "_root"}, Multiplier: 1, }, } @@ -332,19 +305,20 @@ func (c *CostCalculator) CurrentNode() *CostTreeNode { // EnterField is called when entering a field during AST traversal. // It creates a skeleton node and pushes it onto the stack. // The actual cost calculation happens in LeaveField when fieldPlanners data is available. -func (c *CostCalculator) EnterField(fieldRef int, typeName, fieldName string, isListType bool, arguments map[string]int) { +func (c *CostCalculator) EnterField(fieldRef int, coord FieldCoordinate, namedTypeName string, + isListType bool, arguments map[string]ArgumentInfo) { if !c.enabled { return } - // Create skeleton cost node - costs will be calculated in LeaveField + // Create skeleton cost node. Costs will be calculated in LeaveField node := &CostTreeNode{ - FieldRef: fieldRef, - TypeName: typeName, - FieldName: fieldName, - Multiplier: 1, - isListType: isListType, - arguments: arguments, + fieldRef: fieldRef, + fieldCoord: coord, + Multiplier: 1, + fieldTypeName: namedTypeName, + isListType: isListType, + arguments: arguments, } // Attach to parent @@ -370,12 +344,12 @@ func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { } current := c.stack[len(c.stack)-1] - if current.FieldRef != fieldRef { + if current.fieldRef != fieldRef { return } // Now calculate costs with the data source information - current.DataSourceHashes = dsHashes + current.dataSourceHashes = dsHashes c.calculateNodeCosts(current) // Pop from stack @@ -386,57 +360,58 @@ func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { // calculateNodeCosts implements IBM GraphQL Cost Specification // See: https://ibm.github.io/graphql-specs/cost-spec.html#sec-Field-Cost func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { - typeName := node.TypeName - fieldName := node.FieldName - // Get the cost config (use first data source config, or default) var config *DataSourceCostConfig - if len(node.DataSourceHashes) > 0 { - config = c.getCostConfig(node.DataSourceHashes[0]) + if len(node.dataSourceHashes) > 0 { + config = c.getCostConfig(node.dataSourceHashes[0]) } else { config = c.getDefaultCostConfig() } - node.FieldCost = config.GetFieldCost(typeName, fieldName) + fieldConfig := config.Fields[node.fieldCoord] + if fieldConfig != nil { + node.FieldCost = fieldConfig.Weight + } else { + // use the weight of the type returned by this field + if typeWeight, ok := config.Types[node.fieldTypeName]; ok { + node.FieldCost = typeWeight + } + } + // TODO: check how we fill node.arguments for argName := range node.arguments { - node.ArgumentsCost += config.GetArgumentCost(typeName, fieldName, argName) + weight, ok := fieldConfig.ArgumentWeights[argName] + if ok { + node.ArgumentsCost += weight + } // TODO: arguments should include costs of input object fields } - // TODO: Directives Cost should includes the weights of all its arguments - - // TODO: arguments, directives and fields of input object are mutually recursive, - // we should recurse on them and sum all of possible values. - // Compute multiplier if !node.isListType { node.Multiplier = 1 return } - fieldCostConfig := config.GetFieldCostConfig(typeName, fieldName) - if fieldCostConfig == nil { + if fieldConfig == nil { node.Multiplier = 1 return } node.Multiplier = 0 - for _, slicingArg := range fieldCostConfig.SlicingArguments { - if argValue, ok := node.arguments[slicingArg]; ok && argValue > 0 { - if argValue > node.Multiplier { - node.Multiplier = argValue - } + for _, slicingArg := range fieldConfig.SlicingArguments { + argInfo, ok := node.arguments[slicingArg] + if ok && argInfo.isScalar && argInfo.intValue > node.Multiplier{ + node.Multiplier = argInfo.intValue } } - if node.Multiplier == 0 && fieldCostConfig.AssumedSize > 0 { - node.Multiplier = fieldCostConfig.AssumedSize + if node.Multiplier == 0 && fieldConfig.AssumedSize > 0 { + node.Multiplier = fieldConfig.AssumedSize return } - node.Multiplier = config.GetDefaultListCost() + node.Multiplier = StaticCostDefaults.List } - // GetTree returns the cost tree func (c *CostCalculator) GetTree() *CostTree { c.tree.Calculate() diff --git a/v2/pkg/engine/plan/static_cost_test.go b/v2/pkg/engine/plan/static_cost_test.go index b521401316..9a9b4a16e4 100644 --- a/v2/pkg/engine/plan/static_cost_test.go +++ b/v2/pkg/engine/plan/static_cost_test.go @@ -14,18 +14,17 @@ const ( func TestCostDefaults(t *testing.T) { // Test that defaults are set correctly - assert.Equal(t, 1, StaticCostDefaults.FieldCost) - assert.Equal(t, 0, StaticCostDefaults.ArgumentCost) - assert.Equal(t, 0, StaticCostDefaults.ScalarCost) - assert.Equal(t, 0, StaticCostDefaults.EnumCost) - assert.Equal(t, 1, StaticCostDefaults.ObjectCost) - assert.Equal(t, 10, StaticCostDefaults.ListCost) + assert.Equal(t, 1, StaticCostDefaults.Field) + assert.Equal(t, 0, StaticCostDefaults.Scalar) + assert.Equal(t, 0, StaticCostDefaults.Enum) + assert.Equal(t, 1, StaticCostDefaults.Object) + assert.Equal(t, 10, StaticCostDefaults.List) } func TestNewDataSourceCostConfig(t *testing.T) { config := NewDataSourceCostConfig() - assert.NotNil(t, config.FieldConfig) + assert.NotNil(t, config.Fields) assert.NotNil(t, config.ScalarWeights) assert.NotNil(t, config.EnumWeights) } @@ -35,10 +34,10 @@ func TestDataSourceCostConfig_GetFieldCost(t *testing.T) { // Test default cost cost := config.GetFieldCost("Query", "users") - assert.Equal(t, StaticCostDefaults.FieldCost, cost) + assert.Equal(t, StaticCostDefaults.Field, cost) // Test custom cost - config.FieldConfig["Query.users"] = &FieldCostConfig{ + config.Fields["Query.users"] = &FieldCostConfig{ Weight: 5, } cost = config.GetFieldCost("Query", "users") @@ -57,7 +56,7 @@ func TestDataSourceCostConfig_GetSlicingArguments(t *testing.T) { assert.Nil(t, args) // Test with list size config - config.FieldConfig["Query.users"] = &FieldCostConfig{ + config.Fields["Query.users"] = &FieldCostConfig{ Weight: 1, AssumedSize: 100, SlicingArguments: []string{"first", "last"}, @@ -116,14 +115,14 @@ func TestCostCalculator_BasicFlow(t *testing.T) { calc.Enable() config := NewDataSourceCostConfig() - config.FieldConfig["Query.users"] = &FieldCostConfig{ + config.Fields["Query.users"] = &FieldCostConfig{ Weight: 2, SlicingArguments: []string{"first"}, } calc.SetDataSourceCostConfig(testDSHash1, config) // Simulate entering and leaving fields (two-phase: Enter creates skeleton, Leave calculates costs) - calc.EnterField(1, "Query", "users", true, map[string]int{"first":10}) + calc.EnterField(1, "Query", "users", true, map[string]int{"first": 10}) calc.EnterField(2, "User", "name", false, nil) calc.LeaveField(2, []DSHash{testDSHash1}) calc.EnterField(3, "User", "email", false, nil) @@ -157,13 +156,13 @@ func TestCostCalculator_MultipleDataSources(t *testing.T) { // Configure two different data sources with different weights config1 := NewDataSourceCostConfig() - config1.FieldConfig["User.name"] = &FieldCostConfig{ + config1.Fields["User.name"] = &FieldCostConfig{ Weight: 2, } calc.SetDataSourceCostConfig(testDSHash1, config1) config2 := NewDataSourceCostConfig() - config2.FieldConfig["User.name"] = &FieldCostConfig{ + config2.Fields["User.name"] = &FieldCostConfig{ Weight: 3, } calc.SetDataSourceCostConfig(testDSHash2, config2) @@ -218,9 +217,9 @@ func TestNilCostConfig(t *testing.T) { // All methods should handle nil gracefully assert.Equal(t, 0, config.GetFieldCost("Type", "field")) assert.Equal(t, 0, config.GetArgumentCost("Type", "field", "arg")) - assert.Equal(t, 0, config.GetScalarCost("String")) - assert.Equal(t, 0, config.GetEnumCost("Status")) - assert.Equal(t, 0, config.GetObjectCost()) + assert.Equal(t, 0, config.ScalarWeight("String")) + assert.Equal(t, 0, config.EnumWeight("Status")) + assert.Equal(t, 0, config.ObjectWeight()) assert.Nil(t, config.GetSlicingArguments("Type", "field")) assert.Equal(t, 0, config.GetAssumedListSize("Type", "field")) @@ -233,7 +232,7 @@ func TestCostCalculator_TwoPhaseFlow(t *testing.T) { calc.Enable() config := NewDataSourceCostConfig() - config.FieldConfig["Query.users"] = &FieldCostConfig{ + config.Fields["Query.users"] = &FieldCostConfig{ Weight: 5, } calc.SetDataSourceCostConfig(testDSHash1, config) @@ -261,7 +260,7 @@ func TestCostCalculator_ListSizeAssumedSize(t *testing.T) { calc.Enable() config := NewDataSourceCostConfig() - config.FieldConfig["Query.users"] = &FieldCostConfig{ + config.Fields["Query.users"] = &FieldCostConfig{ Weight: 1, AssumedSize: 50, // Assume 50 items if no slicing arg SlicingArguments: []string{"first", "last"}, @@ -289,7 +288,7 @@ func TestCostCalculator_ListSizeSlicingArg(t *testing.T) { calc.Enable() config := NewDataSourceCostConfig() - config.FieldConfig["Query.users"] = &FieldCostConfig{ + config.Fields["Query.users"] = &FieldCostConfig{ Weight: 1, AssumedSize: 50, // This should NOT be used SlicingArguments: []string{"first", "last"}, @@ -297,7 +296,7 @@ func TestCostCalculator_ListSizeSlicingArg(t *testing.T) { calc.SetDataSourceCostConfig(testDSHash1, config) // Enter field with "first: 10" argument - calc.EnterField(1, "Query", "users", true, map[string]int{"first":10}) + calc.EnterField(1, "Query", "users", true, map[string]int{"first": 10}) calc.LeaveField(1, []DSHash{testDSHash1}) // multiplier should be 10 (from slicing arg), not 50 diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 76480c1aed..ea51b616c5 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -428,22 +428,22 @@ func (v *Visitor) enterFieldCost(ref int) { typeName := v.Walker.EnclosingTypeDefinition.NameString(v.Definition) fieldName := v.Operation.FieldNameUnsafeString(ref) + coord := FieldCoordinate{typeName, fieldName} - // Check if the field returns a list type fieldDefinition, ok := v.Walker.FieldDefinition(ref) if !ok { return } fieldDefinitionTypeRef := v.Definition.FieldDefinitionType(fieldDefinition) isListType := v.Definition.TypeIsList(fieldDefinitionTypeRef) + namedTypeName := v.Definition.ResolveTypeNameString(fieldDefinitionTypeRef) - // Extract arguments for cost calculation arguments := v.costFieldArguments(ref) // directives := v.costFieldDirectives(ref) // Create skeleton node - dsHashes will be filled in leaveFieldCost - v.costCalculator.EnterField(ref, typeName, fieldName, isListType, arguments) + v.costCalculator.EnterField(ref, coord, namedTypeName, isListType, arguments) } // getFieldDataSourceHashes returns all data source hashes for the field. @@ -465,25 +465,58 @@ func (v *Visitor) getFieldDataSourceHashes(ref int) []DSHash { } // costFieldArguments extracts arguments from a field for cost calculation -// costFieldArguments extracts arguments from a field for cost calculation -func (v *Visitor) costFieldArguments(ref int) map[string]int { +func (v *Visitor) costFieldArguments(ref int) map[string]ArgumentInfo { argRefs := v.Operation.FieldArguments(ref) if len(argRefs) == 0 { return nil } - arguments := make(map[string]int, len(argRefs)) + arguments := make(map[string]ArgumentInfo, len(argRefs)) for _, argRef := range argRefs { argName := v.Operation.ArgumentNameString(argRef) argValue := v.Operation.ArgumentValue(argRef) + argInfo := ArgumentInfo{} fmt.Printf("costFieldArguments: argName=%s, argValue=%v\n", argName, argValue) - // Extract integer value if present (for multipliers like "first", "limit") - if argValue.Kind == ast.ValueKindInteger { - arguments[argName] = int(v.Operation.IntValueAsInt(argValue.Ref)) - } - if argValue.Kind == ast.ValueKindVariable { + val, err := v.Operation.PrintValueBytes(argValue, nil) + if err != nil { + panic(err) + } + fmt.Printf("value = %s\n", val) + switch argValue.Kind { + case ast.ValueKindBoolean, ast.ValueKindEnum, ast.ValueKindString, ast.ValueKindFloat: + argInfo.isScalar = true + case ast.ValueKindNull: + continue + case ast.ValueKindInteger: + // Extract integer value if present (for multipliers like "first", "limit") + argInfo.intValue = int(v.Operation.IntValueAsInt(argValue.Ref)) + argInfo.isScalar = true + case ast.ValueKindVariable: + // TODO: we need to analyze variables that contains input object fields. + // If these fields has weight attached, use them for calculation. + // Variables are not inlined at this stage, so we need to inspect them via AST. + argInfo.isInputObject = true + variableValue := v.Operation.VariableValueNameString(argValue.Ref) + if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { + continue // omit optional argument when variable is not defined + } + variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinition, v.Operation.VariableValueNameBytes(argValue.Ref)) + if !exists { + break + } + // variableTypeRef := v.Operation.VariableDefinitions[variableDefinition].Type + argInfo.typeName = v.Operation.ResolveTypeNameString(v.Operation.VariableDefinitions[variableDefinition].Type) + + case ast.ValueKindList: + // should we do something? is it possible at all? + continue + default: + fmt.Printf("unhandled case: %v\n", argValue.Kind) + continue } + + arguments[argName] = argInfo } return arguments From cb39b19819844dcade5d93c2f6eb7b6c54cfb4dd Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Wed, 24 Dec 2025 11:12:03 +0200 Subject: [PATCH 05/43] compute for all data sources --- v2/pkg/engine/plan/static_cost.go | 167 ++++--- v2/pkg/engine/plan/static_cost_test.go | 610 ++++++++++++------------- v2/pkg/engine/plan/visitor.go | 21 +- 3 files changed, 398 insertions(+), 400 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 9e9e6c362e..b50d7e918b 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -1,5 +1,7 @@ package plan +import "fmt" + // StaticCostDefaults contains default cost values when no specific costs are configured var StaticCostDefaults = WeightDefaults{ Field: 1, @@ -75,13 +77,6 @@ func NewDataSourceCostConfig() *DataSourceCostConfig { } } -func (c *DataSourceCostConfig) GetFieldCostConfig(typeName, fieldName string) *FieldCostConfig { - if c == nil { - return nil - } - return c.Fields[FieldCoordinate{typeName, fieldName}] -} - // ScalarWeight returns the cost for a scalar type func (c *DataSourceCostConfig) ScalarWeight(scalarName string) int { if c == nil { @@ -115,10 +110,6 @@ func (c *DataSourceCostConfig) ObjectWeight(name string) int { return StaticCostDefaults.Object } -func (c *DataSourceCostConfig) GetDefaultListCost() int { - return 10 -} - // CostTreeNode represents a node in the cost calculation tree // Based on IBM GraphQL Cost Specification: https://ibm.github.io/graphql-specs/cost-spec.html type CostTreeNode struct { @@ -163,10 +154,6 @@ type ArgumentInfo struct { // The name of an unwrapped type. typeName string - // isInputObject is true for an input object passed to the argument, - // otherwise the argument is Scalar or Enum. - isInputObject bool - // If argument is passed an input object, we want to gather counts // for all the field coordinates with non-null values used in the argument. // @@ -178,7 +165,12 @@ type ArgumentInfo struct { // { {"A", "rec"}: 2, {"A", "x"}: 3 } // coordCounts map[FieldCoordinate]int - isScalar bool + + // isInputObject is true for an input object passed to the argument, + // otherwise the argument is Scalar or Enum. + isInputObject bool + + isScalar bool } // TotalCost calculates the total cost of this node and all descendants @@ -267,27 +259,27 @@ func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSour c.costConfigs[dsHash] = config } -// SetDefaultCostConfig sets the default cost config -func (c *CostCalculator) SetDefaultCostConfig(config *DataSourceCostConfig) { - c.defaultConfig = config -} - -// getCostConfig returns the cost config for a specific data source hash -func (c *CostCalculator) getCostConfig(dsHash DSHash) *DataSourceCostConfig { - if config, ok := c.costConfigs[dsHash]; ok { - return config - } - return c.getDefaultCostConfig() -} - -// getDefaultCostConfig returns the default cost config when no specific data source is available -func (c *CostCalculator) getDefaultCostConfig() *DataSourceCostConfig { - if c.defaultConfig != nil { - return c.defaultConfig - } - // Return a dummy config with defaults - return &DataSourceCostConfig{} -} +// // SetDefaultCostConfig sets the default cost config +// func (c *CostCalculator) SetDefaultCostConfig(config *DataSourceCostConfig) { +// c.defaultConfig = config +// } + +// // getCostConfig returns the cost config for a specific data source hash +// func (c *CostCalculator) getCostConfig(dsHash DSHash) *DataSourceCostConfig { +// if config, ok := c.costConfigs[dsHash]; ok { +// return config +// } +// return c.getDefaultCostConfig() +// } + +// // getDefaultCostConfig returns the default cost config when no specific data source is available +// func (c *CostCalculator) getDefaultCostConfig() *DataSourceCostConfig { +// if c.defaultConfig != nil { +// return c.defaultConfig +// } +// // Return a dummy config with defaults +// return &DataSourceCostConfig{} +// } // IsEnabled returns whether cost calculation is enabled func (c *CostCalculator) IsEnabled() bool { @@ -356,59 +348,69 @@ func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { c.stack = c.stack[:len(c.stack)-1] } -// calculateNodeCosts fills in the cost values for a node based on its data sources -// calculateNodeCosts implements IBM GraphQL Cost Specification +// calculateNodeCosts fills in the cost values for a node based on its data sources. +// It implements IBM GraphQL Cost Specification. // See: https://ibm.github.io/graphql-specs/cost-spec.html#sec-Field-Cost func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { - // Get the cost config (use first data source config, or default) - var config *DataSourceCostConfig - if len(node.dataSourceHashes) > 0 { - config = c.getCostConfig(node.dataSourceHashes[0]) - } else { - config = c.getDefaultCostConfig() + // For every data source we get different weights. + // For this node we sum weights of the field and its arguments. + // For the multiplier we pick the maximum. + if len(node.dataSourceHashes) <= 0 { + // no data source is responsible for this field + return } - fieldConfig := config.Fields[node.fieldCoord] - if fieldConfig != nil { - node.FieldCost = fieldConfig.Weight - } else { - // use the weight of the type returned by this field - if typeWeight, ok := config.Types[node.fieldTypeName]; ok { - node.FieldCost = typeWeight + node.Multiplier = 0 + + for _, dsHash := range node.dataSourceHashes { + config, ok := c.costConfigs[dsHash] + if !ok { + fmt.Printf("WARNING: no cost config for data source %v\n", dsHash) + continue } - } - // TODO: check how we fill node.arguments - for argName := range node.arguments { - weight, ok := fieldConfig.ArgumentWeights[argName] - if ok { - node.ArgumentsCost += weight + fieldConfig := config.Fields[node.fieldCoord] + if fieldConfig != nil { + node.FieldCost = fieldConfig.Weight + } else { + // use the weight of the type returned by this field + if typeWeight, ok := config.Types[node.fieldTypeName]; ok { + node.FieldCost = typeWeight + } } - // TODO: arguments should include costs of input object fields - } - // Compute multiplier - if !node.isListType { - node.Multiplier = 1 - return - } + for argName := range node.arguments { + weight, ok := fieldConfig.ArgumentWeights[argName] + if ok { + node.ArgumentsCost += weight + } + // TODO: arguments should include costs of input object fields + } - if fieldConfig == nil { - node.Multiplier = 1 - return - } - node.Multiplier = 0 - for _, slicingArg := range fieldConfig.SlicingArguments { - argInfo, ok := node.arguments[slicingArg] - if ok && argInfo.isScalar && argInfo.intValue > node.Multiplier{ - node.Multiplier = argInfo.intValue + // Compute multiplier as the maximum of data sources. + if !node.isListType { + node.Multiplier = 1 + continue + } + + if fieldConfig == nil { + node.Multiplier = 1 + continue + } + multiplier := -1 + for _, slicingArg := range fieldConfig.SlicingArguments { + argInfo, ok := node.arguments[slicingArg] + if ok && argInfo.isScalar && argInfo.intValue > 0 && argInfo.intValue > multiplier { + multiplier = argInfo.intValue + } + } + if multiplier == -1 && fieldConfig.AssumedSize > 0 { + multiplier = fieldConfig.AssumedSize + } + if multiplier > node.Multiplier { + node.Multiplier = multiplier } } - if node.Multiplier == 0 && fieldConfig.AssumedSize > 0 { - node.Multiplier = fieldConfig.AssumedSize - return - } - node.Multiplier = StaticCostDefaults.List } @@ -423,10 +425,3 @@ func (c *CostCalculator) GetTotalCost() int { c.tree.Calculate() return c.tree.Total } - -// CostFieldArgument represents a parsed field argument for cost calculation -type CostFieldArgument struct { - Name string - IntValue int - // Add other value types as needed -} diff --git a/v2/pkg/engine/plan/static_cost_test.go b/v2/pkg/engine/plan/static_cost_test.go index 9a9b4a16e4..f096b60733 100644 --- a/v2/pkg/engine/plan/static_cost_test.go +++ b/v2/pkg/engine/plan/static_cost_test.go @@ -1,305 +1,305 @@ -package plan - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -// Test DSHash values -const ( - testDSHash1 DSHash = 1001 - testDSHash2 DSHash = 1002 -) - -func TestCostDefaults(t *testing.T) { - // Test that defaults are set correctly - assert.Equal(t, 1, StaticCostDefaults.Field) - assert.Equal(t, 0, StaticCostDefaults.Scalar) - assert.Equal(t, 0, StaticCostDefaults.Enum) - assert.Equal(t, 1, StaticCostDefaults.Object) - assert.Equal(t, 10, StaticCostDefaults.List) -} - -func TestNewDataSourceCostConfig(t *testing.T) { - config := NewDataSourceCostConfig() - - assert.NotNil(t, config.Fields) - assert.NotNil(t, config.ScalarWeights) - assert.NotNil(t, config.EnumWeights) -} - -func TestDataSourceCostConfig_GetFieldCost(t *testing.T) { - config := NewDataSourceCostConfig() - - // Test default cost - cost := config.GetFieldCost("Query", "users") - assert.Equal(t, StaticCostDefaults.Field, cost) - - // Test custom cost - config.Fields["Query.users"] = &FieldCostConfig{ - Weight: 5, - } - cost = config.GetFieldCost("Query", "users") - assert.Equal(t, 5, cost) - - // Test with defaults - cost = config.GetFieldCost("Query", "posts") - assert.Equal(t, 1, cost) -} - -func TestDataSourceCostConfig_GetSlicingArguments(t *testing.T) { - config := NewDataSourceCostConfig() - - // Test no list size config - args := config.GetSlicingArguments("Query", "users") - assert.Nil(t, args) - - // Test with list size config - config.Fields["Query.users"] = &FieldCostConfig{ - Weight: 1, - AssumedSize: 100, - SlicingArguments: []string{"first", "last"}, - RequireOneSlicingArgument: true, - } - - // Test GetSlicingArguments - args = config.GetSlicingArguments("Query", "users") - assert.Equal(t, []string{"first", "last"}, args) - - // Test GetAssumedListSize - assumed := config.GetAssumedListSize("Query", "users") - assert.Equal(t, 100, assumed) -} - -func TestCostTreeNode_TotalCost(t *testing.T) { - // Build a simple tree: - // root (cost: 1) - // └── users (cost: 1, multiplier: 10 from "first" arg) - // └── name (cost: 1) - // └── email (cost: 1) - - root := &CostTreeNode{ - FieldName: "_root", - Multiplier: 1, - } - - users := &CostTreeNode{ - FieldName: "users", - FieldCost: 1, - Multiplier: 10, // "first: 10" - } - root.Children = append(root.Children, users) - - name := &CostTreeNode{ - FieldName: "name", - FieldCost: 1, - } - users.Children = append(users.Children, name) - - email := &CostTreeNode{ - FieldName: "email", - FieldCost: 1, - } - users.Children = append(users.Children, email) - - // Calculate: root cost = users cost + (children cost * multiplier) - // users: 1 + (1 + 1) * 10 = 1 + 20 = 21 - // root: 0 + 21 * 1 = 21 - total := root.TotalCost() - assert.Equal(t, 21, total) -} - -func TestCostCalculator_BasicFlow(t *testing.T) { - calc := NewCostCalculator() - calc.Enable() - - config := NewDataSourceCostConfig() - config.Fields["Query.users"] = &FieldCostConfig{ - Weight: 2, - SlicingArguments: []string{"first"}, - } - calc.SetDataSourceCostConfig(testDSHash1, config) - - // Simulate entering and leaving fields (two-phase: Enter creates skeleton, Leave calculates costs) - calc.EnterField(1, "Query", "users", true, map[string]int{"first": 10}) - calc.EnterField(2, "User", "name", false, nil) - calc.LeaveField(2, []DSHash{testDSHash1}) - calc.EnterField(3, "User", "email", false, nil) - calc.LeaveField(3, []DSHash{testDSHash1}) - calc.LeaveField(1, []DSHash{testDSHash1}) - - // Get results - tree := calc.GetTree() - assert.NotNil(t, tree) - assert.True(t, tree.Total > 0) - - totalCost := calc.GetTotalCost() - // Per IBM spec: users weight=2 + (name(1) + email(1)) * 10 = 2 + 20 = 22 - assert.Equal(t, 22, totalCost) -} - -func TestCostCalculator_Disabled(t *testing.T) { - calc := NewCostCalculator() - // Don't enable - - calc.EnterField(1, "Query", "users", true, nil) - calc.LeaveField(1, []DSHash{testDSHash1}) - - // Should return 0 when disabled - assert.Equal(t, 0, calc.GetTotalCost()) -} - -func TestCostCalculator_MultipleDataSources(t *testing.T) { - calc := NewCostCalculator() - calc.Enable() - - // Configure two different data sources with different weights - config1 := NewDataSourceCostConfig() - config1.Fields["User.name"] = &FieldCostConfig{ - Weight: 2, - } - calc.SetDataSourceCostConfig(testDSHash1, config1) - - config2 := NewDataSourceCostConfig() - config2.Fields["User.name"] = &FieldCostConfig{ - Weight: 3, - } - calc.SetDataSourceCostConfig(testDSHash2, config2) - - // Field planned on multiple data sources - per IBM spec, use first data source's weight - calc.EnterField(1, "User", "name", false, nil) - calc.LeaveField(1, []DSHash{testDSHash1, testDSHash2}) - - totalCost := calc.GetTotalCost() - // Per IBM spec: field is resolved once, using first data source weight = 2 - assert.Equal(t, 2, totalCost) -} - -func TestCostCalculator_NoDataSource(t *testing.T) { - calc := NewCostCalculator() - calc.Enable() - - // Set default config - defaultConfig := NewDataSourceCostConfig() - calc.SetDefaultCostConfig(defaultConfig) - - // Field with no data source - should use default config - calc.EnterField(1, "Query", "unknown", false, nil) - calc.LeaveField(1, nil) - - totalCost := calc.GetTotalCost() - assert.Equal(t, 1, totalCost) -} - -func TestCostTree_Calculate(t *testing.T) { - tree := &CostTree{ - Root: &CostTreeNode{ - FieldName: "_root", - Multiplier: 1, - Children: []*CostTreeNode{ - { - FieldName: "field1", - FieldCost: 5, - }, - }, - }, - } - - tree.Calculate() - - assert.Equal(t, 5, tree.Total) -} - -func TestNilCostConfig(t *testing.T) { - var config *DataSourceCostConfig - - // All methods should handle nil gracefully - assert.Equal(t, 0, config.GetFieldCost("Type", "field")) - assert.Equal(t, 0, config.GetArgumentCost("Type", "field", "arg")) - assert.Equal(t, 0, config.ScalarWeight("String")) - assert.Equal(t, 0, config.EnumWeight("Status")) - assert.Equal(t, 0, config.ObjectWeight()) - - assert.Nil(t, config.GetSlicingArguments("Type", "field")) - assert.Equal(t, 0, config.GetAssumedListSize("Type", "field")) -} - -func TestCostCalculator_TwoPhaseFlow(t *testing.T) { - // Test that the two-phase flow works correctly: - // EnterField creates skeleton, LeaveField fills in costs - calc := NewCostCalculator() - calc.Enable() - - config := NewDataSourceCostConfig() - config.Fields["Query.users"] = &FieldCostConfig{ - Weight: 5, - } - calc.SetDataSourceCostConfig(testDSHash1, config) - - // Enter creates skeleton node - calc.EnterField(1, "Query", "users", false, nil) - - // At this point, the node exists but has no cost calculated yet - currentNode := calc.CurrentNode() - assert.NotNil(t, currentNode) - assert.Equal(t, "users", currentNode.FieldName) - assert.Equal(t, 0, currentNode.FieldCost) // Weight not yet calculated - - // Leave fills in DS info and calculates cost - calc.LeaveField(1, []DSHash{testDSHash1}) - - // Now the cost should be calculated - totalCost := calc.GetTotalCost() - assert.Equal(t, 5, totalCost) -} - -func TestCostCalculator_ListSizeAssumedSize(t *testing.T) { - // Test that assumed size is used when no slicing argument is provided - calc := NewCostCalculator() - calc.Enable() - - config := NewDataSourceCostConfig() - config.Fields["Query.users"] = &FieldCostConfig{ - Weight: 1, - AssumedSize: 50, // Assume 50 items if no slicing arg - SlicingArguments: []string{"first", "last"}, - } - calc.SetDataSourceCostConfig(testDSHash1, config) - - // Enter field with no slicing arguments - calc.EnterField(1, "Query", "users", true, nil) - - // Enter child field - calc.EnterField(2, "User", "name", false, nil) - calc.LeaveField(2, []DSHash{testDSHash1}) - - calc.LeaveField(1, []DSHash{testDSHash1}) - - // multiplier should be 50 (assumed size) - tree := calc.GetTree() - assert.Equal(t, 50, tree.Root.Children[0].Multiplier) - assert.Equal(t, 51, calc.GetTotalCost()) -} - -func TestCostCalculator_ListSizeSlicingArg(t *testing.T) { - // Test that slicing argument overrides assumed size - calc := NewCostCalculator() - calc.Enable() - - config := NewDataSourceCostConfig() - config.Fields["Query.users"] = &FieldCostConfig{ - Weight: 1, - AssumedSize: 50, // This should NOT be used - SlicingArguments: []string{"first", "last"}, - } - calc.SetDataSourceCostConfig(testDSHash1, config) - - // Enter field with "first: 10" argument - calc.EnterField(1, "Query", "users", true, map[string]int{"first": 10}) - calc.LeaveField(1, []DSHash{testDSHash1}) - - // multiplier should be 10 (from slicing arg), not 50 - tree := calc.GetTree() - assert.Equal(t, 10, tree.Root.Children[0].Multiplier) -} +// package plan +// +// import ( +// "testing" +// +// "github.com/stretchr/testify/assert" +// ) +// +// // Test DSHash values +// const ( +// testDSHash1 DSHash = 1001 +// testDSHash2 DSHash = 1002 +// ) +// +// func TestCostDefaults(t *testing.T) { +// // Test that defaults are set correctly +// assert.Equal(t, 1, StaticCostDefaults.Field) +// assert.Equal(t, 0, StaticCostDefaults.Scalar) +// assert.Equal(t, 0, StaticCostDefaults.Enum) +// assert.Equal(t, 1, StaticCostDefaults.Object) +// assert.Equal(t, 10, StaticCostDefaults.List) +// } +// +// func TestNewDataSourceCostConfig(t *testing.T) { +// config := NewDataSourceCostConfig() +// +// assert.NotNil(t, config.Fields) +// assert.NotNil(t, config.ScalarWeights) +// assert.NotNil(t, config.EnumWeights) +// } +// +// func TestDataSourceCostConfig_GetFieldCost(t *testing.T) { +// config := NewDataSourceCostConfig() +// +// // Test default cost +// cost := config.GetFieldCost("Query", "users") +// assert.Equal(t, StaticCostDefaults.Field, cost) +// +// // Test custom cost +// config.Fields["Query.users"] = &FieldCostConfig{ +// Weight: 5, +// } +// cost = config.GetFieldCost("Query", "users") +// assert.Equal(t, 5, cost) +// +// // Test with defaults +// cost = config.GetFieldCost("Query", "posts") +// assert.Equal(t, 1, cost) +// } +// +// func TestDataSourceCostConfig_GetSlicingArguments(t *testing.T) { +// config := NewDataSourceCostConfig() +// +// // Test no list size config +// args := config.GetSlicingArguments("Query", "users") +// assert.Nil(t, args) +// +// // Test with list size config +// config.Fields["Query.users"] = &FieldCostConfig{ +// Weight: 1, +// AssumedSize: 100, +// SlicingArguments: []string{"first", "last"}, +// RequireOneSlicingArgument: true, +// } +// +// // Test GetSlicingArguments +// args = config.GetSlicingArguments("Query", "users") +// assert.Equal(t, []string{"first", "last"}, args) +// +// // Test GetAssumedListSize +// assumed := config.GetAssumedListSize("Query", "users") +// assert.Equal(t, 100, assumed) +// } +// +// func TestCostTreeNode_TotalCost(t *testing.T) { +// // Build a simple tree: +// // root (cost: 1) +// // └── users (cost: 1, multiplier: 10 from "first" arg) +// // └── name (cost: 1) +// // └── email (cost: 1) +// +// root := &CostTreeNode{ +// FieldName: "_root", +// Multiplier: 1, +// } +// +// users := &CostTreeNode{ +// FieldName: "users", +// FieldCost: 1, +// Multiplier: 10, // "first: 10" +// } +// root.Children = append(root.Children, users) +// +// name := &CostTreeNode{ +// FieldName: "name", +// FieldCost: 1, +// } +// users.Children = append(users.Children, name) +// +// email := &CostTreeNode{ +// FieldName: "email", +// FieldCost: 1, +// } +// users.Children = append(users.Children, email) +// +// // Calculate: root cost = users cost + (children cost * multiplier) +// // users: 1 + (1 + 1) * 10 = 1 + 20 = 21 +// // root: 0 + 21 * 1 = 21 +// total := root.TotalCost() +// assert.Equal(t, 21, total) +// } +// +// func TestCostCalculator_BasicFlow(t *testing.T) { +// calc := NewCostCalculator() +// calc.Enable() +// +// config := NewDataSourceCostConfig() +// config.Fields["Query.users"] = &FieldCostConfig{ +// Weight: 2, +// SlicingArguments: []string{"first"}, +// } +// calc.SetDataSourceCostConfig(testDSHash1, config) +// +// // Simulate entering and leaving fields (two-phase: Enter creates skeleton, Leave calculates costs) +// calc.EnterField(1, "Query", "users", true, map[string]int{"first": 10}) +// calc.EnterField(2, "User", "name", false, nil) +// calc.LeaveField(2, []DSHash{testDSHash1}) +// calc.EnterField(3, "User", "email", false, nil) +// calc.LeaveField(3, []DSHash{testDSHash1}) +// calc.LeaveField(1, []DSHash{testDSHash1}) +// +// // Get results +// tree := calc.GetTree() +// assert.NotNil(t, tree) +// assert.True(t, tree.Total > 0) +// +// totalCost := calc.GetTotalCost() +// // Per IBM spec: users weight=2 + (name(1) + email(1)) * 10 = 2 + 20 = 22 +// assert.Equal(t, 22, totalCost) +// } +// +// func TestCostCalculator_Disabled(t *testing.T) { +// calc := NewCostCalculator() +// // Don't enable +// +// calc.EnterField(1, "Query", "users", true, nil) +// calc.LeaveField(1, []DSHash{testDSHash1}) +// +// // Should return 0 when disabled +// assert.Equal(t, 0, calc.GetTotalCost()) +// } +// +// func TestCostCalculator_MultipleDataSources(t *testing.T) { +// calc := NewCostCalculator() +// calc.Enable() +// +// // Configure two different data sources with different weights +// config1 := NewDataSourceCostConfig() +// config1.Fields["User.name"] = &FieldCostConfig{ +// Weight: 2, +// } +// calc.SetDataSourceCostConfig(testDSHash1, config1) +// +// config2 := NewDataSourceCostConfig() +// config2.Fields["User.name"] = &FieldCostConfig{ +// Weight: 3, +// } +// calc.SetDataSourceCostConfig(testDSHash2, config2) +// +// // Field planned on multiple data sources - per IBM spec, use first data source's weight +// calc.EnterField(1, "User", "name", false, nil) +// calc.LeaveField(1, []DSHash{testDSHash1, testDSHash2}) +// +// totalCost := calc.GetTotalCost() +// // Per IBM spec: field is resolved once, using first data source weight = 2 +// assert.Equal(t, 2, totalCost) +// } +// +// func TestCostCalculator_NoDataSource(t *testing.T) { +// calc := NewCostCalculator() +// calc.Enable() +// +// // Set default config +// defaultConfig := NewDataSourceCostConfig() +// calc.SetDefaultCostConfig(defaultConfig) +// +// // Field with no data source - should use default config +// calc.EnterField(1, "Query", "unknown", false, nil) +// calc.LeaveField(1, nil) +// +// totalCost := calc.GetTotalCost() +// assert.Equal(t, 1, totalCost) +// } +// +// func TestCostTree_Calculate(t *testing.T) { +// tree := &CostTree{ +// Root: &CostTreeNode{ +// FieldName: "_root", +// Multiplier: 1, +// Children: []*CostTreeNode{ +// { +// FieldName: "field1", +// FieldCost: 5, +// }, +// }, +// }, +// } +// +// tree.Calculate() +// +// assert.Equal(t, 5, tree.Total) +// } +// +// func TestNilCostConfig(t *testing.T) { +// var config *DataSourceCostConfig +// +// // All methods should handle nil gracefully +// assert.Equal(t, 0, config.GetFieldCost("Type", "field")) +// assert.Equal(t, 0, config.GetArgumentCost("Type", "field", "arg")) +// assert.Equal(t, 0, config.ScalarWeight("String")) +// assert.Equal(t, 0, config.EnumWeight("Status")) +// assert.Equal(t, 0, config.ObjectWeight()) +// +// assert.Nil(t, config.GetSlicingArguments("Type", "field")) +// assert.Equal(t, 0, config.GetAssumedListSize("Type", "field")) +// } +// +// func TestCostCalculator_TwoPhaseFlow(t *testing.T) { +// // Test that the two-phase flow works correctly: +// // EnterField creates skeleton, LeaveField fills in costs +// calc := NewCostCalculator() +// calc.Enable() +// +// config := NewDataSourceCostConfig() +// config.Fields["Query.users"] = &FieldCostConfig{ +// Weight: 5, +// } +// calc.SetDataSourceCostConfig(testDSHash1, config) +// +// // Enter creates skeleton node +// calc.EnterField(1, "Query", "users", false, nil) +// +// // At this point, the node exists but has no cost calculated yet +// currentNode := calc.CurrentNode() +// assert.NotNil(t, currentNode) +// assert.Equal(t, "users", currentNode.FieldName) +// assert.Equal(t, 0, currentNode.FieldCost) // Weight not yet calculated +// +// // Leave fills in DS info and calculates cost +// calc.LeaveField(1, []DSHash{testDSHash1}) +// +// // Now the cost should be calculated +// totalCost := calc.GetTotalCost() +// assert.Equal(t, 5, totalCost) +// } +// +// func TestCostCalculator_ListSizeAssumedSize(t *testing.T) { +// // Test that assumed size is used when no slicing argument is provided +// calc := NewCostCalculator() +// calc.Enable() +// +// config := NewDataSourceCostConfig() +// config.Fields["Query.users"] = &FieldCostConfig{ +// Weight: 1, +// AssumedSize: 50, // Assume 50 items if no slicing arg +// SlicingArguments: []string{"first", "last"}, +// } +// calc.SetDataSourceCostConfig(testDSHash1, config) +// +// // Enter field with no slicing arguments +// calc.EnterField(1, "Query", "users", true, nil) +// +// // Enter child field +// calc.EnterField(2, "User", "name", false, nil) +// calc.LeaveField(2, []DSHash{testDSHash1}) +// +// calc.LeaveField(1, []DSHash{testDSHash1}) +// +// // multiplier should be 50 (assumed size) +// tree := calc.GetTree() +// assert.Equal(t, 50, tree.Root.Children[0].Multiplier) +// assert.Equal(t, 51, calc.GetTotalCost()) +// } +// +// func TestCostCalculator_ListSizeSlicingArg(t *testing.T) { +// // Test that slicing argument overrides assumed size +// calc := NewCostCalculator() +// calc.Enable() +// +// config := NewDataSourceCostConfig() +// config.Fields["Query.users"] = &FieldCostConfig{ +// Weight: 1, +// AssumedSize: 50, // This should NOT be used +// SlicingArguments: []string{"first", "last"}, +// } +// calc.SetDataSourceCostConfig(testDSHash1, config) +// +// // Enter field with "first: 10" argument +// calc.EnterField(1, "Query", "users", true, map[string]int{"first": 10}) +// calc.LeaveField(1, []DSHash{testDSHash1}) +// +// // multiplier should be 10 (from slicing arg), not 50 +// tree := calc.GetTree() +// assert.Equal(t, 10, tree.Root.Children[0].Multiplier) +// } diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index ea51b616c5..87fb84eeff 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -438,7 +438,7 @@ func (v *Visitor) enterFieldCost(ref int) { isListType := v.Definition.TypeIsList(fieldDefinitionTypeRef) namedTypeName := v.Definition.ResolveTypeNameString(fieldDefinitionTypeRef) - arguments := v.costFieldArguments(ref) + arguments := v.extractFieldArguments(ref) // directives := v.costFieldDirectives(ref) @@ -464,8 +464,8 @@ func (v *Visitor) getFieldDataSourceHashes(ref int) []DSHash { return dsHashes } -// costFieldArguments extracts arguments from a field for cost calculation -func (v *Visitor) costFieldArguments(ref int) map[string]ArgumentInfo { +// extractFieldArguments extracts arguments from a field for cost calculation +func (v *Visitor) extractFieldArguments(ref int) map[string]ArgumentInfo { argRefs := v.Operation.FieldArguments(ref) if len(argRefs) == 0 { return nil @@ -477,7 +477,7 @@ func (v *Visitor) costFieldArguments(ref int) map[string]ArgumentInfo { argValue := v.Operation.ArgumentValue(argRef) argInfo := ArgumentInfo{} - fmt.Printf("costFieldArguments: argName=%s, argValue=%v\n", argName, argValue) + fmt.Printf("extractFieldArguments: argName=%s, argValue=%v\n", argName, argValue) val, err := v.Operation.PrintValueBytes(argValue, nil) if err != nil { panic(err) @@ -486,16 +486,15 @@ func (v *Visitor) costFieldArguments(ref int) map[string]ArgumentInfo { switch argValue.Kind { case ast.ValueKindBoolean, ast.ValueKindEnum, ast.ValueKindString, ast.ValueKindFloat: argInfo.isScalar = true + argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) case ast.ValueKindNull: continue case ast.ValueKindInteger: // Extract integer value if present (for multipliers like "first", "limit") argInfo.intValue = int(v.Operation.IntValueAsInt(argValue.Ref)) argInfo.isScalar = true + argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) case ast.ValueKindVariable: - // TODO: we need to analyze variables that contains input object fields. - // If these fields has weight attached, use them for calculation. - // Variables are not inlined at this stage, so we need to inspect them via AST. argInfo.isInputObject = true variableValue := v.Operation.VariableValueNameString(argValue.Ref) if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { @@ -507,10 +506,14 @@ func (v *Visitor) costFieldArguments(ref int) map[string]ArgumentInfo { } // variableTypeRef := v.Operation.VariableDefinitions[variableDefinition].Type argInfo.typeName = v.Operation.ResolveTypeNameString(v.Operation.VariableDefinitions[variableDefinition].Type) + // TODO: we need to analyze variables that contains input object fields. + // If these fields has weight attached, use them for calculation. + // Variables are not inlined at this stage, so we need to inspect them via AST. case ast.ValueKindList: - // should we do something? is it possible at all? - continue + unwrappedTypeRef := v.Operation.ResolveUnderlyingType(argValue.Ref) + argInfo.typeName = v.Operation.TypeNameString(unwrappedTypeRef) + // how to figure out if the unwrapped is scalar? default: fmt.Printf("unhandled case: %v\n", argValue.Kind) continue From 12033d0190c459775fa7c5e6f29e2f51acca7efc Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Wed, 24 Dec 2025 13:17:56 +0200 Subject: [PATCH 06/43] push the cost through the stack --- execution/engine/execution_engine.go | 3 ++ execution/engine/execution_engine_test.go | 43 ++++++++++++++++------- v2/pkg/engine/plan/configuration.go | 2 ++ v2/pkg/engine/plan/plan.go | 20 +++++++++++ v2/pkg/engine/plan/planner.go | 22 ++++++++---- v2/pkg/engine/plan/static_cost.go | 21 ----------- v2/pkg/engine/plan/static_cost_test.go | 4 ++- v2/pkg/engine/plan/visitor.go | 12 ++----- 8 files changed, 76 insertions(+), 51 deletions(-) diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index 53b4f2c5e8..f442888a25 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -68,6 +68,8 @@ type ExecutionEngine struct { resolver *resolve.Resolver executionPlanCache *lru.Cache apolloCompatibilityFlags apollocompatibility.Flags + // Holds the plan after Execute(). Used in testing. + lastPlan plan.Plan } type WebsocketBeforeStartHook interface { @@ -214,6 +216,7 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques if report.HasErrors() { return report } + e.lastPlan = cachedPlan if execContext.resolveContext.TracingOptions.Enable && !execContext.resolveContext.TracingOptions.ExcludePlannerStats { planningTime := resolve.GetDurationNanoSinceTraceStart(execContext.resolveContext.Context()) - tracePlanStart diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 8ba95ac7ce..ea349e053b 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -201,17 +201,19 @@ func TestWithAdditionalHttpHeaders(t *testing.T) { } type ExecutionEngineTestCase struct { - schema *graphql.Schema - operation func(t *testing.T) graphql.Request - dataSources []plan.DataSource - fields plan.FieldConfigurations - engineOptions []ExecutionOptions + schema *graphql.Schema + operation func(t *testing.T) graphql.Request + dataSources []plan.DataSource + fields plan.FieldConfigurations + engineOptions []ExecutionOptions + customResolveMap map[string]resolve.CustomResolve + skipReason string + indentJSON bool + expectedResponse string expectedJSONResponse string expectedFixture string - customResolveMap map[string]resolve.CustomResolve - skipReason string - indentJSON bool + expectedStaticCost int } type _executionTestOptions struct { @@ -220,6 +222,7 @@ type _executionTestOptions struct { apolloRouterCompatibilitySubrequestHTTPError bool propagateFetchReasons bool validateRequiredExternalFields bool + computeStaticCost bool } type executionTestOptions func(*_executionTestOptions) @@ -244,6 +247,12 @@ func validateRequiredExternalFields() executionTestOptions { } } +func computeStaticCost() executionTestOptions { + return func(options *_executionTestOptions) { + options.computeStaticCost = true + } +} + func TestExecutionEngine_Execute(t *testing.T) { run := func(testCase ExecutionEngineTestCase, withError bool, expectedErrorMessage string, options ...executionTestOptions) func(t *testing.T) { t.Helper() @@ -264,9 +273,9 @@ func TestExecutionEngine_Execute(t *testing.T) { PrintOperationTransformations: true, PrintPlanningPaths: true, // PrintNodeSuggestions: true, - PrintQueryPlans: true, - ConfigurationVisitor: true, - PlanningVisitor: true, + PrintQueryPlans: true, + ConfigurationVisitor: true, + PlanningVisitor: true, // DatasourceVisitor: true, } @@ -278,13 +287,15 @@ func TestExecutionEngine_Execute(t *testing.T) { } engineConf.plannerConfig.BuildFetchReasons = opts.propagateFetchReasons engineConf.plannerConfig.ValidateRequiredExternalFields = opts.validateRequiredExternalFields - engine, err := NewExecutionEngine(ctx, abstractlogger.Noop{}, engineConf, resolve.ResolverOptions{ + engineConf.plannerConfig.ComputeStaticCost = opts.computeStaticCost + resolveOpts := resolve.ResolverOptions{ MaxConcurrency: 1024, ResolvableOptions: opts.resolvableOptions, ApolloRouterCompatibilitySubrequestHTTPError: opts.apolloRouterCompatibilitySubrequestHTTPError, PropagateFetchReasons: opts.propagateFetchReasons, ValidateRequiredExternalFields: opts.validateRequiredExternalFields, - }) + } + engine, err := NewExecutionEngine(ctx, abstractlogger.Noop{}, engineConf, resolveOpts) require.NoError(t, err) operation := testCase.operation(t) @@ -314,6 +325,12 @@ func TestExecutionEngine_Execute(t *testing.T) { assert.Equal(t, testCase.expectedResponse, actualResponse) } + if testCase.expectedStaticCost != 0 { + lastPlan := engine.lastPlan + assert.NotNil(t, lastPlan) + assert.Equal(t, testCase.expectedStaticCost, lastPlan.GetStaticCost()) + } + if withError { require.Error(t, err) if expectedErrorMessage != "" { diff --git a/v2/pkg/engine/plan/configuration.go b/v2/pkg/engine/plan/configuration.go index dafcb021c5..ead65c0596 100644 --- a/v2/pkg/engine/plan/configuration.go +++ b/v2/pkg/engine/plan/configuration.go @@ -46,6 +46,8 @@ type Configuration struct { // entity. // This option requires BuildFetchReasons set to true. ValidateRequiredExternalFields bool + + ComputeStaticCost bool } type DebugConfiguration struct { diff --git a/v2/pkg/engine/plan/plan.go b/v2/pkg/engine/plan/plan.go index 15f97769f0..1252c2e01c 100644 --- a/v2/pkg/engine/plan/plan.go +++ b/v2/pkg/engine/plan/plan.go @@ -14,11 +14,22 @@ const ( type Plan interface { PlanKind() Kind SetFlushInterval(interval int64) + GetStaticCost() int + SetStaticCost(cost int) } type SynchronousResponsePlan struct { Response *resolve.GraphQLResponse FlushInterval int64 + StaticCost int +} + +func (s *SynchronousResponsePlan) GetStaticCost() int { + return s.StaticCost +} + +func (s *SynchronousResponsePlan) SetStaticCost(cost int) { + s.StaticCost = cost } func (s *SynchronousResponsePlan) SetFlushInterval(interval int64) { @@ -32,6 +43,15 @@ func (*SynchronousResponsePlan) PlanKind() Kind { type SubscriptionResponsePlan struct { Response *resolve.GraphQLSubscription FlushInterval int64 + StaticCost int +} + +func (s *SubscriptionResponsePlan) GetStaticCost() int { + return s.StaticCost +} + +func (s *SubscriptionResponsePlan) SetStaticCost(cost int) { + s.StaticCost = cost } func (s *SubscriptionResponsePlan) SetFlushInterval(interval int64) { diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index 08624e8084..c0edda2f34 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -61,12 +61,13 @@ func NewPlanner(config Configuration) (*Planner, error) { planningWalker := astvisitor.NewWalkerWithID(48, "PlanningWalker") // Initialize cost calculator and configure from data sources - costCalc := NewCostCalculator() - costCalc.Enable() - for _, ds := range config.DataSources { - if costConfig := ds.GetCostConfig(); costConfig != nil { - costCalc.SetDataSourceCostConfig(ds.Hash(), costConfig) - costCalc.Enable() + var costCalc *CostCalculator + if config.ComputeStaticCost { + costCalc = NewCostCalculator() + for _, ds := range config.DataSources { + if costConfig := ds.GetCostConfig(); costConfig != nil { + costCalc.SetDataSourceCostConfig(ds.Hash(), costConfig) + } } } @@ -167,6 +168,10 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string p.planningVisitor.fieldRefDependsOnFieldRefs = selectionsConfig.fieldRefDependsOn p.planningVisitor.fieldDependencyKind = selectionsConfig.fieldDependencyKind p.planningVisitor.fieldRefDependants = inverseMap(selectionsConfig.fieldRefDependsOn) + // if p.config.ComputeStaticCost { + // p.planningVisitor.costCalculator = NewCostCalculator() + // p.planningVisitor.costCalculator.Enable() + // } p.planningWalker.ResetVisitors() p.planningWalker.SetVisitorFilter(p.planningVisitor) @@ -213,6 +218,11 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string return } + if p.config.ComputeStaticCost { + cost := p.planningVisitor.costCalculator.GetTotalCost() + p.planningVisitor.plan.SetStaticCost(cost) + } + return p.planningVisitor.plan } diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index b50d7e918b..e620eab34b 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -225,9 +225,6 @@ type CostCalculator struct { // defaultConfig is used when no data source specific config exists defaultConfig *DataSourceCostConfig - - // enabled controls whether cost calculation is active - enabled bool } // NewCostCalculator creates a new cost calculator @@ -242,18 +239,12 @@ func NewCostCalculator() *CostCalculator { stack: make([]*CostTreeNode, 0, 16), costConfigs: make(map[DSHash]*DataSourceCostConfig), tree: tree, - enabled: false, } c.stack = append(c.stack, c.tree.Root) return &c } -// Enable activates cost calculation -func (c *CostCalculator) Enable() { - c.enabled = true -} - // SetDataSourceCostConfig sets the cost config for a specific data source func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSourceCostConfig) { c.costConfigs[dsHash] = config @@ -281,11 +272,6 @@ func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSour // return &DataSourceCostConfig{} // } -// IsEnabled returns whether cost calculation is enabled -func (c *CostCalculator) IsEnabled() bool { - return c.enabled -} - // CurrentNode returns the current node on the stack func (c *CostCalculator) CurrentNode() *CostTreeNode { if len(c.stack) == 0 { @@ -299,9 +285,6 @@ func (c *CostCalculator) CurrentNode() *CostTreeNode { // The actual cost calculation happens in LeaveField when fieldPlanners data is available. func (c *CostCalculator) EnterField(fieldRef int, coord FieldCoordinate, namedTypeName string, isListType bool, arguments map[string]ArgumentInfo) { - if !c.enabled { - return - } // Create skeleton cost node. Costs will be calculated in LeaveField node := &CostTreeNode{ @@ -326,10 +309,6 @@ func (c *CostCalculator) EnterField(fieldRef int, coord FieldCoordinate, namedTy // LeaveField is called when leaving a field during AST traversal. // This is where we calculate costs because fieldPlanners data is now available. func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { - if !c.enabled { - return - } - // Find the current node (should match fieldRef) if len(c.stack) <= 1 { // Keep root on stack return diff --git a/v2/pkg/engine/plan/static_cost_test.go b/v2/pkg/engine/plan/static_cost_test.go index f096b60733..ab4fd7a173 100644 --- a/v2/pkg/engine/plan/static_cost_test.go +++ b/v2/pkg/engine/plan/static_cost_test.go @@ -1,4 +1,6 @@ -// package plan +package plan + +// Tests below are commented out during development // // import ( // "testing" diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 87fb84eeff..8205797df0 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -422,7 +422,7 @@ func (v *Visitor) mapFieldConfig(ref int) { // enterFieldCost creates a skeleton cost node when entering a field. // Actual cost calculation is deferred to leaveFieldCost when fieldPlanners data is available. func (v *Visitor) enterFieldCost(ref int) { - if v.costCalculator == nil || !v.costCalculator.IsEnabled() { + if v.costCalculator == nil { return } @@ -546,14 +546,6 @@ func (v *Visitor) extractFieldArguments(ref int) map[string]ArgumentInfo { // return arguments // } -// GetTotalCost returns the total calculated cost for the query -func (v *Visitor) GetTotalCost() int { - if v.costCalculator == nil { - return 0 - } - return v.costCalculator.GetTotalCost() -} - func (v *Visitor) resolveFieldInfo(ref, typeRef int, onTypeNames [][]byte) *resolve.FieldInfo { if v.Config.DisableIncludeInfo { return nil @@ -765,7 +757,7 @@ func (v *Visitor) LeaveField(ref int) { // Calculate costs and pop from cost stack // This is done in LeaveField because fieldPlanners is populated by AllowVisitor on LeaveField - if v.costCalculator != nil && v.costCalculator.IsEnabled() { + if v.costCalculator != nil { dsHashes := v.getFieldDataSourceHashes(ref) v.costCalculator.LeaveField(ref, dsHashes) } From cdb7f930e7a065bcd060d4e8516c4ad9ffc8d395 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Thu, 8 Jan 2026 19:31:02 +0200 Subject: [PATCH 07/43] handle basic cost calculation and capture abstract types --- execution/engine/execution_engine_test.go | 231 +++++++++++++++- v2/pkg/engine/plan/planner.go | 6 +- v2/pkg/engine/plan/static_cost.go | 152 ++++------- v2/pkg/engine/plan/static_cost_test.go | 307 ---------------------- v2/pkg/engine/plan/visitor.go | 67 +++-- 5 files changed, 314 insertions(+), 449 deletions(-) delete mode 100644 v2/pkg/engine/plan/static_cost_test.go diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index ea349e053b..03ddc0b1cd 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -317,6 +317,15 @@ func TestExecutionEngine_Execute(t *testing.T) { return } + if withError { + require.Error(t, err) + if expectedErrorMessage != "" { + assert.Contains(t, err.Error(), expectedErrorMessage) + } + } else { + require.NoError(t, err) + } + if testCase.expectedJSONResponse != "" { assert.JSONEq(t, testCase.expectedJSONResponse, actualResponse) } @@ -331,14 +340,6 @@ func TestExecutionEngine_Execute(t *testing.T) { assert.Equal(t, testCase.expectedStaticCost, lastPlan.GetStaticCost()) } - if withError { - require.Error(t, err) - if expectedErrorMessage != "" { - assert.Contains(t, err.Error(), expectedErrorMessage) - } - } else { - require.NoError(t, err) - } } } @@ -889,7 +890,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, }, }, ChildNodes: []plan.TypeField{ @@ -948,7 +949,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, // Only for this field propagate the fetch reasons, // even if a user has asked for the interface in the query. FetchReasonFields: []string{"name"}, @@ -1008,7 +1009,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, FetchReasonFields: []string{"name"}, // implementing is marked }, }, @@ -1079,7 +1080,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, }, }, ChildNodes: []plan.TypeField{ @@ -1152,7 +1153,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, { TypeName: "Droid", - FieldNames: []string{"name", "primaryFunctions", "friends"}, + FieldNames: []string{"name", "primaryFunction", "friends"}, FetchReasonFields: []string{"name"}, // implementing is marked }, }, @@ -5544,6 +5545,210 @@ func TestExecutionEngine_Execute(t *testing.T) { }, withFetchReasons(), validateRequiredExternalFields())) }) }) + + t.Run("static cost computation", func(t *testing.T) { + rootNodes := []plan.TypeField{ + { + TypeName: "Query", + FieldNames: []string{"hero", "droid"}, + }, + { + TypeName: "Human", + FieldNames: []string{"name", "height", "friends"}, + }, + { + TypeName: "Droid", + FieldNames: []string{"name", "primaryFunction", "friends"}, + }, + } + childNodes := []plan.TypeField{ + { + TypeName: "Character", + FieldNames: []string{"name", "friends"}, + }, + } + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", + }, + SchemaConfiguration: mustSchemaConfig( + t, + nil, + string(graphql.StarwarsSchema(t).RawSchema()), + ), + }) + costConfig := &plan.DataSourceCostConfig{ + Fields: map[plan.FieldCoordinate]*plan.FieldCostConfig{ + {TypeName: "Query", FieldName: "hero"}: {Weight: 2}, + {TypeName: "Human", FieldName: "name"}: {Weight: 7}, + {TypeName: "Human", FieldName: "height"}: {Weight: 3}, + {TypeName: "Droid", FieldName: "name"}: {Weight: 17}, + }, + Types: map[string]int{ + "Human": 13, + }, + } + + t.Run("droid simple fields", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: ` + { + droid(id: "R2D2") { + name + primaryFunction + } + } + `, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, + "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: costConfig, + }, + customConfig, + ), + }, + fields: []plan.FieldConfiguration{ + { + TypeName: "Query", + FieldName: "droid", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "id", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, + }, + }, + }, + }, + expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + expectedStaticCost: 18, // droid (1) + droid.name (17) + }, + computeStaticCost(), + )) + + t.Run("hero with interfaces", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + name + ... on Human { + height + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, + "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker","height":"12"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: costConfig, + }, + customConfig, + ), + }, + fields: []plan.FieldConfiguration{}, + expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker","height":"12"}}}`, + expectedStaticCost: 5, // hero (2) + hero.height (3) + // But should be: hero (2) + hero.height (3) + droid.name (17=max(7, 17)) + }, + computeStaticCost(), + )) + + t.Run("query with list field and slicing argument", func(t *testing.T) { + // Schema: Query.users(first: Int) returns [User] + // Cost calculation (IBM spec): + // - Query.users: weight = 2 (custom) + // - "first: 5" slicing argument -> multiplier = 5 + // - User.id: weight = 0 + // - User.name: weight = 0 + // Total = 2 + (0 + 0) * 5 = 2 + + schema, err := graphql.NewSchemaFromString(` + type Query { + users(first: Int): [User!] + } + type User { + id: ID! + name: String! + } + `) + require.NoError(t, err) + + dsCfg, err := plan.NewDataSourceConfiguration[staticdatasource.Configuration]( + "users-ds", + &staticdatasource.Factory[staticdatasource.Configuration]{}, + &plan.DataSourceMetadata{ + RootNodes: []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"users"}}, + }, + ChildNodes: []plan.TypeField{ + {TypeName: "User", FieldNames: []string{"id", "name"}}, + }, + CostConfig: &plan.DataSourceCostConfig{ + Fields: map[plan.FieldCoordinate]*plan.FieldCostConfig{ + {TypeName: "Query", FieldName: "users"}: { + Weight: 2, + SlicingArguments: []string{"first"}, + AssumedSize: 10, + }, + {TypeName: "User", FieldName: "id"}: {Weight: 0}, + {TypeName: "User", FieldName: "name"}: {Weight: 0}, + }, + Types: map[string]int{}, + }, + }, + staticdatasource.Configuration{ + Data: `{"users": [{"id": "1", "name": "Alice"}, {"id": "2", "name": "Bob"}]}`, + }, + ) + require.NoError(t, err) + + t.Run("run", runWithoutError(ExecutionEngineTestCase{ + schema: schema, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ users(first: 5) { id name } }`, + } + }, + dataSources: []plan.DataSource{dsCfg}, + expectedJSONResponse: `{"data":{"users":[{"id":"1","name":"Alice"},{"id":"2","name":"Bob"}]}}`, + expectedStaticCost: 2, + }, computeStaticCost())) + }) + }) } func testNetHttpClient(t *testing.T, testCase roundTripperTestCase) *http.Client { diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index c0edda2f34..713dceab25 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -168,10 +168,6 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string p.planningVisitor.fieldRefDependsOnFieldRefs = selectionsConfig.fieldRefDependsOn p.planningVisitor.fieldDependencyKind = selectionsConfig.fieldDependencyKind p.planningVisitor.fieldRefDependants = inverseMap(selectionsConfig.fieldRefDependsOn) - // if p.config.ComputeStaticCost { - // p.planningVisitor.costCalculator = NewCostCalculator() - // p.planningVisitor.costCalculator.Enable() - // } p.planningWalker.ResetVisitors() p.planningWalker.SetVisitorFilter(p.planningVisitor) @@ -212,7 +208,7 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string } } - // create raw execution plan + // create a raw execution plan p.planningWalker.Walk(operation, definition, report) if report.HasErrors() { return diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index e620eab34b..3f391067ae 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -4,20 +4,18 @@ import "fmt" // StaticCostDefaults contains default cost values when no specific costs are configured var StaticCostDefaults = WeightDefaults{ - Field: 1, - Scalar: 0, - Enum: 0, - Object: 1, - List: 10, // The assumed maximum size of a list for fields that return lists. + Field: 1, + EnumScalar: 0, + Object: 1, + List: 10, // The assumed maximum size of a list for fields that return lists. } // WeightDefaults defines default cost values for different GraphQL elements type WeightDefaults struct { - Field int - Scalar int - Enum int - Object int - List int + Field int + EnumScalar int + Object int + List int } // FieldCostConfig defines cost configuration for a specific field of an object or input object. @@ -51,7 +49,7 @@ type FieldCostConfig struct { // DataSourceCostConfig holds all cost configurations for a data source. // This data is passed from the composition. type DataSourceCostConfig struct { - // Fields maps field coordinate to its cost config. + // Fields maps field coordinate to its cost config. Cannot be on fields of interfaces. // Location: FIELD_DEFINITION, INPUT_FIELD_DEFINITION Fields map[FieldCoordinate]*FieldCostConfig @@ -65,7 +63,7 @@ type DataSourceCostConfig struct { // directive name + argument name. That should be the key mapped to the weight. // // Directives can be used on [input] object fields and arguments of fields. This creates - // mutual recursion between them that complicates cost calculation. + // mutual recursion between them; it complicates cost calculation. // We avoid them intentionally in the first iteration. } @@ -77,26 +75,15 @@ func NewDataSourceCostConfig() *DataSourceCostConfig { } } -// ScalarWeight returns the cost for a scalar type -func (c *DataSourceCostConfig) ScalarWeight(scalarName string) int { - if c == nil { - return 0 - } - if cost, ok := c.Types[scalarName]; ok { - return cost - } - return StaticCostDefaults.Scalar -} - -// EnumWeight returns the cost for an enum type -func (c *DataSourceCostConfig) EnumWeight(enumName string) int { +// EnumScalarWeight returns the cost for an enum or scalar types +func (c *DataSourceCostConfig) EnumScalarWeight(enumName string) int { if c == nil { return 0 } if cost, ok := c.Types[enumName]; ok { return cost } - return StaticCostDefaults.Enum + return StaticCostDefaults.EnumScalar } // ObjectWeight returns the default object cost @@ -119,33 +106,40 @@ type CostTreeNode struct { // Enclosing type name and field name fieldCoord FieldCoordinate - // dataSourceHashes identifies which data sources this field is resolved from + // dataSourceHashes identifies which data sources resolve this field. dataSourceHashes []DSHash // FieldCost is the weight of this field from @cost directive FieldCost int - // ArgumentsCost is the sum of argument weights and input fields used on each directive + // ArgumentsCost is the sum of argument weights and input fields used on this field. ArgumentsCost int + // Weights on directives ignored for now. DirectivesCost int - // Multiplier is the list size multiplier from @listSize directive + // multiplier is the list size multiplier from @listSize directive // Applied to children costs for list fields - Multiplier int + multiplier int // Children contain child field costs Children []*CostTreeNode // The data below is stored for deferred cost calculation. - - // What is the name of an unwrapped (named) type that is returned by this field? + // We populate these fields in EnterField and use them as a source of truth in LeaveField. + // + // fieldTypeName contains the name of an unwrapped (named) type that is returned by this field. fieldTypeName string - isListType bool + // implementTypeNames contains the names of all types that implement this interface/union field. + implementingTypeNames []string - // arguments contain the values of arguments passed to the field + // arguments contain the values of arguments passed to the field. arguments map[string]ArgumentInfo + + isListType bool + isSimpleType bool + isAbstractType bool } type ArgumentInfo struct { @@ -180,9 +174,6 @@ func (n *CostTreeNode) TotalCost() int { return 0 } - // TODO: negative sum should be rounded up to zero - cost := n.FieldCost + n.ArgumentsCost + n.DirectivesCost - // Sum children (fields) costs var childrenCost int for _, child := range n.Children { @@ -190,11 +181,12 @@ func (n *CostTreeNode) TotalCost() int { } // Apply multiplier to children cost (for list fields) - multiplier := n.Multiplier + multiplier := n.multiplier if multiplier == 0 { multiplier = 1 } - cost += childrenCost * multiplier + // TODO: negative sum should be rounded up to zero + cost := n.ArgumentsCost + n.DirectivesCost + (n.FieldCost+childrenCost)*multiplier return cost } @@ -232,7 +224,7 @@ func NewCostCalculator() *CostCalculator { tree := &CostTree{ Root: &CostTreeNode{ fieldCoord: FieldCoordinate{"_none", "_root"}, - Multiplier: 1, + multiplier: 1, }, } c := CostCalculator{ @@ -250,28 +242,6 @@ func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSour c.costConfigs[dsHash] = config } -// // SetDefaultCostConfig sets the default cost config -// func (c *CostCalculator) SetDefaultCostConfig(config *DataSourceCostConfig) { -// c.defaultConfig = config -// } - -// // getCostConfig returns the cost config for a specific data source hash -// func (c *CostCalculator) getCostConfig(dsHash DSHash) *DataSourceCostConfig { -// if config, ok := c.costConfigs[dsHash]; ok { -// return config -// } -// return c.getDefaultCostConfig() -// } - -// // getDefaultCostConfig returns the default cost config when no specific data source is available -// func (c *CostCalculator) getDefaultCostConfig() *DataSourceCostConfig { -// if c.defaultConfig != nil { -// return c.defaultConfig -// } -// // Return a dummy config with defaults -// return &DataSourceCostConfig{} -// } - // CurrentNode returns the current node on the stack func (c *CostCalculator) CurrentNode() *CostTreeNode { if len(c.stack) == 0 { @@ -283,26 +253,13 @@ func (c *CostCalculator) CurrentNode() *CostTreeNode { // EnterField is called when entering a field during AST traversal. // It creates a skeleton node and pushes it onto the stack. // The actual cost calculation happens in LeaveField when fieldPlanners data is available. -func (c *CostCalculator) EnterField(fieldRef int, coord FieldCoordinate, namedTypeName string, - isListType bool, arguments map[string]ArgumentInfo) { - - // Create skeleton cost node. Costs will be calculated in LeaveField - node := &CostTreeNode{ - fieldRef: fieldRef, - fieldCoord: coord, - Multiplier: 1, - fieldTypeName: namedTypeName, - isListType: isListType, - arguments: arguments, - } - +func (c *CostCalculator) EnterField(node *CostTreeNode) { // Attach to parent parent := c.CurrentNode() if parent != nil { parent.Children = append(parent.Children, node) } - // Push onto stack c.stack = append(c.stack, node) } @@ -319,11 +276,9 @@ func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { return } - // Now calculate costs with the data source information current.dataSourceHashes = dsHashes c.calculateNodeCosts(current) - // Pop from stack c.stack = c.stack[:len(c.stack)-1] } @@ -339,7 +294,7 @@ func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { return } - node.Multiplier = 0 + node.multiplier = 0 for _, dsHash := range node.dataSourceHashes { config, ok := c.costConfigs[dsHash] @@ -348,34 +303,35 @@ func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { continue } + // TODO: handle abstract types + fieldConfig := config.Fields[node.fieldCoord] if fieldConfig != nil { - node.FieldCost = fieldConfig.Weight + node.FieldCost += fieldConfig.Weight + for argName := range node.arguments { + weight, ok := fieldConfig.ArgumentWeights[argName] + if ok { + node.ArgumentsCost += weight + } + // What to do if the argument definition itself does not have weight attached, + // but the type of the argument does have weight attached to it? + // TODO: arguments should include costs of input object fields + } } else { // use the weight of the type returned by this field - if typeWeight, ok := config.Types[node.fieldTypeName]; ok { - node.FieldCost = typeWeight + if node.isSimpleType { + node.FieldCost += config.EnumScalarWeight(node.fieldTypeName) + } else { + node.FieldCost += config.ObjectWeight(node.fieldTypeName) } } - for argName := range node.arguments { - weight, ok := fieldConfig.ArgumentWeights[argName] - if ok { - node.ArgumentsCost += weight - } - // TODO: arguments should include costs of input object fields - } - // Compute multiplier as the maximum of data sources. - if !node.isListType { - node.Multiplier = 1 + if !node.isListType || fieldConfig == nil { + node.multiplier = 1 continue } - if fieldConfig == nil { - node.Multiplier = 1 - continue - } multiplier := -1 for _, slicingArg := range fieldConfig.SlicingArguments { argInfo, ok := node.arguments[slicingArg] @@ -386,8 +342,8 @@ func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { if multiplier == -1 && fieldConfig.AssumedSize > 0 { multiplier = fieldConfig.AssumedSize } - if multiplier > node.Multiplier { - node.Multiplier = multiplier + if multiplier > node.multiplier { + node.multiplier = multiplier } } diff --git a/v2/pkg/engine/plan/static_cost_test.go b/v2/pkg/engine/plan/static_cost_test.go deleted file mode 100644 index ab4fd7a173..0000000000 --- a/v2/pkg/engine/plan/static_cost_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package plan - -// Tests below are commented out during development -// -// import ( -// "testing" -// -// "github.com/stretchr/testify/assert" -// ) -// -// // Test DSHash values -// const ( -// testDSHash1 DSHash = 1001 -// testDSHash2 DSHash = 1002 -// ) -// -// func TestCostDefaults(t *testing.T) { -// // Test that defaults are set correctly -// assert.Equal(t, 1, StaticCostDefaults.Field) -// assert.Equal(t, 0, StaticCostDefaults.Scalar) -// assert.Equal(t, 0, StaticCostDefaults.Enum) -// assert.Equal(t, 1, StaticCostDefaults.Object) -// assert.Equal(t, 10, StaticCostDefaults.List) -// } -// -// func TestNewDataSourceCostConfig(t *testing.T) { -// config := NewDataSourceCostConfig() -// -// assert.NotNil(t, config.Fields) -// assert.NotNil(t, config.ScalarWeights) -// assert.NotNil(t, config.EnumWeights) -// } -// -// func TestDataSourceCostConfig_GetFieldCost(t *testing.T) { -// config := NewDataSourceCostConfig() -// -// // Test default cost -// cost := config.GetFieldCost("Query", "users") -// assert.Equal(t, StaticCostDefaults.Field, cost) -// -// // Test custom cost -// config.Fields["Query.users"] = &FieldCostConfig{ -// Weight: 5, -// } -// cost = config.GetFieldCost("Query", "users") -// assert.Equal(t, 5, cost) -// -// // Test with defaults -// cost = config.GetFieldCost("Query", "posts") -// assert.Equal(t, 1, cost) -// } -// -// func TestDataSourceCostConfig_GetSlicingArguments(t *testing.T) { -// config := NewDataSourceCostConfig() -// -// // Test no list size config -// args := config.GetSlicingArguments("Query", "users") -// assert.Nil(t, args) -// -// // Test with list size config -// config.Fields["Query.users"] = &FieldCostConfig{ -// Weight: 1, -// AssumedSize: 100, -// SlicingArguments: []string{"first", "last"}, -// RequireOneSlicingArgument: true, -// } -// -// // Test GetSlicingArguments -// args = config.GetSlicingArguments("Query", "users") -// assert.Equal(t, []string{"first", "last"}, args) -// -// // Test GetAssumedListSize -// assumed := config.GetAssumedListSize("Query", "users") -// assert.Equal(t, 100, assumed) -// } -// -// func TestCostTreeNode_TotalCost(t *testing.T) { -// // Build a simple tree: -// // root (cost: 1) -// // └── users (cost: 1, multiplier: 10 from "first" arg) -// // └── name (cost: 1) -// // └── email (cost: 1) -// -// root := &CostTreeNode{ -// FieldName: "_root", -// Multiplier: 1, -// } -// -// users := &CostTreeNode{ -// FieldName: "users", -// FieldCost: 1, -// Multiplier: 10, // "first: 10" -// } -// root.Children = append(root.Children, users) -// -// name := &CostTreeNode{ -// FieldName: "name", -// FieldCost: 1, -// } -// users.Children = append(users.Children, name) -// -// email := &CostTreeNode{ -// FieldName: "email", -// FieldCost: 1, -// } -// users.Children = append(users.Children, email) -// -// // Calculate: root cost = users cost + (children cost * multiplier) -// // users: 1 + (1 + 1) * 10 = 1 + 20 = 21 -// // root: 0 + 21 * 1 = 21 -// total := root.TotalCost() -// assert.Equal(t, 21, total) -// } -// -// func TestCostCalculator_BasicFlow(t *testing.T) { -// calc := NewCostCalculator() -// calc.Enable() -// -// config := NewDataSourceCostConfig() -// config.Fields["Query.users"] = &FieldCostConfig{ -// Weight: 2, -// SlicingArguments: []string{"first"}, -// } -// calc.SetDataSourceCostConfig(testDSHash1, config) -// -// // Simulate entering and leaving fields (two-phase: Enter creates skeleton, Leave calculates costs) -// calc.EnterField(1, "Query", "users", true, map[string]int{"first": 10}) -// calc.EnterField(2, "User", "name", false, nil) -// calc.LeaveField(2, []DSHash{testDSHash1}) -// calc.EnterField(3, "User", "email", false, nil) -// calc.LeaveField(3, []DSHash{testDSHash1}) -// calc.LeaveField(1, []DSHash{testDSHash1}) -// -// // Get results -// tree := calc.GetTree() -// assert.NotNil(t, tree) -// assert.True(t, tree.Total > 0) -// -// totalCost := calc.GetTotalCost() -// // Per IBM spec: users weight=2 + (name(1) + email(1)) * 10 = 2 + 20 = 22 -// assert.Equal(t, 22, totalCost) -// } -// -// func TestCostCalculator_Disabled(t *testing.T) { -// calc := NewCostCalculator() -// // Don't enable -// -// calc.EnterField(1, "Query", "users", true, nil) -// calc.LeaveField(1, []DSHash{testDSHash1}) -// -// // Should return 0 when disabled -// assert.Equal(t, 0, calc.GetTotalCost()) -// } -// -// func TestCostCalculator_MultipleDataSources(t *testing.T) { -// calc := NewCostCalculator() -// calc.Enable() -// -// // Configure two different data sources with different weights -// config1 := NewDataSourceCostConfig() -// config1.Fields["User.name"] = &FieldCostConfig{ -// Weight: 2, -// } -// calc.SetDataSourceCostConfig(testDSHash1, config1) -// -// config2 := NewDataSourceCostConfig() -// config2.Fields["User.name"] = &FieldCostConfig{ -// Weight: 3, -// } -// calc.SetDataSourceCostConfig(testDSHash2, config2) -// -// // Field planned on multiple data sources - per IBM spec, use first data source's weight -// calc.EnterField(1, "User", "name", false, nil) -// calc.LeaveField(1, []DSHash{testDSHash1, testDSHash2}) -// -// totalCost := calc.GetTotalCost() -// // Per IBM spec: field is resolved once, using first data source weight = 2 -// assert.Equal(t, 2, totalCost) -// } -// -// func TestCostCalculator_NoDataSource(t *testing.T) { -// calc := NewCostCalculator() -// calc.Enable() -// -// // Set default config -// defaultConfig := NewDataSourceCostConfig() -// calc.SetDefaultCostConfig(defaultConfig) -// -// // Field with no data source - should use default config -// calc.EnterField(1, "Query", "unknown", false, nil) -// calc.LeaveField(1, nil) -// -// totalCost := calc.GetTotalCost() -// assert.Equal(t, 1, totalCost) -// } -// -// func TestCostTree_Calculate(t *testing.T) { -// tree := &CostTree{ -// Root: &CostTreeNode{ -// FieldName: "_root", -// Multiplier: 1, -// Children: []*CostTreeNode{ -// { -// FieldName: "field1", -// FieldCost: 5, -// }, -// }, -// }, -// } -// -// tree.Calculate() -// -// assert.Equal(t, 5, tree.Total) -// } -// -// func TestNilCostConfig(t *testing.T) { -// var config *DataSourceCostConfig -// -// // All methods should handle nil gracefully -// assert.Equal(t, 0, config.GetFieldCost("Type", "field")) -// assert.Equal(t, 0, config.GetArgumentCost("Type", "field", "arg")) -// assert.Equal(t, 0, config.ScalarWeight("String")) -// assert.Equal(t, 0, config.EnumWeight("Status")) -// assert.Equal(t, 0, config.ObjectWeight()) -// -// assert.Nil(t, config.GetSlicingArguments("Type", "field")) -// assert.Equal(t, 0, config.GetAssumedListSize("Type", "field")) -// } -// -// func TestCostCalculator_TwoPhaseFlow(t *testing.T) { -// // Test that the two-phase flow works correctly: -// // EnterField creates skeleton, LeaveField fills in costs -// calc := NewCostCalculator() -// calc.Enable() -// -// config := NewDataSourceCostConfig() -// config.Fields["Query.users"] = &FieldCostConfig{ -// Weight: 5, -// } -// calc.SetDataSourceCostConfig(testDSHash1, config) -// -// // Enter creates skeleton node -// calc.EnterField(1, "Query", "users", false, nil) -// -// // At this point, the node exists but has no cost calculated yet -// currentNode := calc.CurrentNode() -// assert.NotNil(t, currentNode) -// assert.Equal(t, "users", currentNode.FieldName) -// assert.Equal(t, 0, currentNode.FieldCost) // Weight not yet calculated -// -// // Leave fills in DS info and calculates cost -// calc.LeaveField(1, []DSHash{testDSHash1}) -// -// // Now the cost should be calculated -// totalCost := calc.GetTotalCost() -// assert.Equal(t, 5, totalCost) -// } -// -// func TestCostCalculator_ListSizeAssumedSize(t *testing.T) { -// // Test that assumed size is used when no slicing argument is provided -// calc := NewCostCalculator() -// calc.Enable() -// -// config := NewDataSourceCostConfig() -// config.Fields["Query.users"] = &FieldCostConfig{ -// Weight: 1, -// AssumedSize: 50, // Assume 50 items if no slicing arg -// SlicingArguments: []string{"first", "last"}, -// } -// calc.SetDataSourceCostConfig(testDSHash1, config) -// -// // Enter field with no slicing arguments -// calc.EnterField(1, "Query", "users", true, nil) -// -// // Enter child field -// calc.EnterField(2, "User", "name", false, nil) -// calc.LeaveField(2, []DSHash{testDSHash1}) -// -// calc.LeaveField(1, []DSHash{testDSHash1}) -// -// // multiplier should be 50 (assumed size) -// tree := calc.GetTree() -// assert.Equal(t, 50, tree.Root.Children[0].Multiplier) -// assert.Equal(t, 51, calc.GetTotalCost()) -// } -// -// func TestCostCalculator_ListSizeSlicingArg(t *testing.T) { -// // Test that slicing argument overrides assumed size -// calc := NewCostCalculator() -// calc.Enable() -// -// config := NewDataSourceCostConfig() -// config.Fields["Query.users"] = &FieldCostConfig{ -// Weight: 1, -// AssumedSize: 50, // This should NOT be used -// SlicingArguments: []string{"first", "last"}, -// } -// calc.SetDataSourceCostConfig(testDSHash1, config) -// -// // Enter field with "first: 10" argument -// calc.EnterField(1, "Query", "users", true, map[string]int{"first": 10}) -// calc.LeaveField(1, []DSHash{testDSHash1}) -// -// // multiplier should be 10 (from slicing arg), not 50 -// tree := calc.GetTree() -// assert.Equal(t, 10, tree.Root.Children[0].Multiplier) -// } diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 8205797df0..3e7d56d734 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -436,14 +436,49 @@ func (v *Visitor) enterFieldCost(ref int) { } fieldDefinitionTypeRef := v.Definition.FieldDefinitionType(fieldDefinition) isListType := v.Definition.TypeIsList(fieldDefinitionTypeRef) - namedTypeName := v.Definition.ResolveTypeNameString(fieldDefinitionTypeRef) + isSimpleType := v.Definition.TypeIsEnum(fieldDefinitionTypeRef, v.Definition) || v.Definition.TypeIsScalar(fieldDefinitionTypeRef, v.Definition) + unwrappedTypeName := v.Definition.ResolveTypeNameString(fieldDefinitionTypeRef) arguments := v.extractFieldArguments(ref) - // directives := v.costFieldDirectives(ref) + // Check and push through if the unwrapped type of this field is interface or union. + unwrappedTypeNode, exists := v.Definition.NodeByNameStr(unwrappedTypeName) + var implementingTypeNames []string + var isAbstractType bool + if exists { + if unwrappedTypeNode.Kind == ast.NodeKindInterfaceTypeDefinition { + impl, ok := v.Definition.InterfaceTypeDefinitionImplementedByObjectWithNames(unwrappedTypeNode.Ref) + if ok { + implementingTypeNames = append(implementingTypeNames, impl...) + isAbstractType = true + } + } + if unwrappedTypeNode.Kind == ast.NodeKindUnionTypeDefinition { + impl, ok := v.Definition.UnionTypeDefinitionMemberTypeNames(unwrappedTypeNode.Ref) + if ok { + implementingTypeNames = append(implementingTypeNames, impl...) + isAbstractType = true + } + } + } + + if len(implementingTypeNames) > 0 { + fmt.Printf("enterFieldCost: field %s.%s is interface or union, implementingTypeNames=%v\n", typeName, fieldName, implementingTypeNames) + } // Create skeleton node - dsHashes will be filled in leaveFieldCost - v.costCalculator.EnterField(ref, coord, namedTypeName, isListType, arguments) + node := CostTreeNode{ + fieldRef: ref, + fieldCoord: coord, + multiplier: 1, + fieldTypeName: unwrappedTypeName, + implementingTypeNames: implementingTypeNames, + isListType: isListType, + isSimpleType: isSimpleType, + isAbstractType: isAbstractType, + arguments: arguments, + } + v.costCalculator.EnterField(&node) } // getFieldDataSourceHashes returns all data source hashes for the field. @@ -488,6 +523,7 @@ func (v *Visitor) extractFieldArguments(ref int) map[string]ArgumentInfo { argInfo.isScalar = true argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) case ast.ValueKindNull: + // Ignore any nulls continue case ast.ValueKindInteger: // Extract integer value if present (for multipliers like "first", "limit") @@ -525,27 +561,6 @@ func (v *Visitor) extractFieldArguments(ref int) map[string]ArgumentInfo { return arguments } -// func (v *Visitor) costFieldDirectives(ref int) map[string]int { -// refs := v.Operation.FieldDirectives(ref) -// if len(refs) == 0 { -// return nil -// } -// -// arguments := make(map[string]int, len(refs)) -// for _, dirRef := range refs { -// dirName := v.Operation.DirectiveName(dirRef) -// dirArgsRef := v.Operation.DirectiveArgumentSet(dirRef) -// -// fmt.Printf("costFieldDirectives: dirName=%s, dirArgsRef=%v\n", dirName, dirArgsRef) -// // Extract integer value if present (for multipliers like "first", "limit") -// if dirArgsRef.Kind == ast.ValueKindInteger { -// arguments[dirName] = int(v.Operation.IntValueAsInt(dirArgsRef.Ref)) -// } -// } -// -// return arguments -// } - func (v *Visitor) resolveFieldInfo(ref, typeRef int, onTypeNames [][]byte) *resolve.FieldInfo { if v.Config.DisableIncludeInfo { return nil @@ -755,8 +770,8 @@ func (v *Visitor) LeaveField(ref int) { return } - // Calculate costs and pop from cost stack - // This is done in LeaveField because fieldPlanners is populated by AllowVisitor on LeaveField + // Calculate costs and pop from cost stack. + // This is done in LeaveField because fieldPlanners is available in LeaveField only. if v.costCalculator != nil { dsHashes := v.getFieldDataSourceHashes(ref) v.costCalculator.LeaveField(ref, dsHashes) From fb659af7c6e8b500748e0098d90ac43baba9bfd2 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Fri, 9 Jan 2026 18:26:55 +0200 Subject: [PATCH 08/43] support interfaces for 2 major cases --- execution/engine/execution_engine_test.go | 148 +++++++---------- v2/pkg/engine/plan/static_cost.go | 187 +++++++++++++++------- v2/pkg/engine/plan/visitor.go | 60 +++---- 3 files changed, 222 insertions(+), 173 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 03ddc0b1cd..14452439bd 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -5580,34 +5580,35 @@ func TestExecutionEngine_Execute(t *testing.T) { }) costConfig := &plan.DataSourceCostConfig{ Fields: map[plan.FieldCoordinate]*plan.FieldCostConfig{ - {TypeName: "Query", FieldName: "hero"}: {Weight: 2}, - {TypeName: "Human", FieldName: "name"}: {Weight: 7}, - {TypeName: "Human", FieldName: "height"}: {Weight: 3}, - {TypeName: "Droid", FieldName: "name"}: {Weight: 17}, + {TypeName: "Query", FieldName: "droid"}: { + ArgumentWeights: map[string]int{"id": 1}, + HasWeight: false, + }, + {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 3}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, }, Types: map[string]int{ "Human": 13, }, } - t.Run("droid simple fields", runWithoutError( + t.Run("droid with weighted plain fields and an argument", runWithoutError( ExecutionEngineTestCase{ schema: graphql.StarwarsSchema(t), operation: func(t *testing.T) graphql.Request { return graphql.Request{ - Query: ` - { - droid(id: "R2D2") { - name - primaryFunction - } + Query: `{ + droid(id: "R2D2") { + name + primaryFunction } - `, + }`, } }, dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, - "id", + mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ expectedHost: "example.com", @@ -5617,11 +5618,7 @@ func TestExecutionEngine_Execute(t *testing.T) { sendStatusCode: 200, }), ), - &plan.DataSourceMetadata{ - RootNodes: rootNodes, - ChildNodes: childNodes, - CostConfig: costConfig, - }, + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: costConfig}, customConfig, ), }, @@ -5639,7 +5636,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, }, expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, - expectedStaticCost: 18, // droid (1) + droid.name (17) + expectedStaticCost: 19, // Query.droid (1) + Query.droid.id (1) + droid.name (17) }, computeStaticCost(), )) @@ -5660,8 +5657,7 @@ func TestExecutionEngine_Execute(t *testing.T) { } }, dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, - "id", + mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ expectedHost: "example.com", @@ -5671,83 +5667,61 @@ func TestExecutionEngine_Execute(t *testing.T) { sendStatusCode: 200, }), ), - &plan.DataSourceMetadata{ - RootNodes: rootNodes, - ChildNodes: childNodes, - CostConfig: costConfig, - }, + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: costConfig}, customConfig, ), }, - fields: []plan.FieldConfiguration{}, expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker","height":"12"}}}`, - expectedStaticCost: 5, // hero (2) + hero.height (3) - // But should be: hero (2) + hero.height (3) + droid.name (17=max(7, 17)) + expectedStaticCost: 22, // Query.hero (2) + Human.height (3) + Droid.name (17=max(7, 17)) }, computeStaticCost(), )) - t.Run("query with list field and slicing argument", func(t *testing.T) { - // Schema: Query.users(first: Int) returns [User] - // Cost calculation (IBM spec): - // - Query.users: weight = 2 (custom) - // - "first: 5" slicing argument -> multiplier = 5 - // - User.id: weight = 0 - // - User.name: weight = 0 - // Total = 2 + (0 + 0) * 5 = 2 - - schema, err := graphql.NewSchemaFromString(` - type Query { - users(first: Int): [User!] - } - type User { - id: ID! - name: String! - } - `) - require.NoError(t, err) - - dsCfg, err := plan.NewDataSourceConfiguration[staticdatasource.Configuration]( - "users-ds", - &staticdatasource.Factory[staticdatasource.Configuration]{}, - &plan.DataSourceMetadata{ - RootNodes: []plan.TypeField{ - {TypeName: "Query", FieldNames: []string{"users"}}, - }, - ChildNodes: []plan.TypeField{ - {TypeName: "User", FieldNames: []string{"id", "name"}}, - }, - CostConfig: &plan.DataSourceCostConfig{ - Fields: map[plan.FieldCoordinate]*plan.FieldCostConfig{ - {TypeName: "Query", FieldName: "users"}: { - Weight: 2, - SlicingArguments: []string{"first"}, - AssumedSize: 10, - }, - {TypeName: "User", FieldName: "id"}: {Weight: 0}, - {TypeName: "User", FieldName: "name"}: {Weight: 0}, - }, - Types: map[string]int{}, - }, - }, - staticdatasource.Configuration{ - Data: `{"users": [{"id": "1", "name": "Alice"}, {"id": "2", "name": "Bob"}]}`, - }, - ) - require.NoError(t, err) - - t.Run("run", runWithoutError(ExecutionEngineTestCase{ - schema: schema, + t.Run("hero with 1 expected friend", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), operation: func(t *testing.T) graphql.Request { return graphql.Request{ - Query: `{ users(first: 5) { id name } }`, + Query: `{ + hero { + friends { + ...on Droid { + name + primaryFunction + } + ...on Human { + name + height + } + } + } + }`, } }, - dataSources: []plan.DataSource{dsCfg}, - expectedJSONResponse: `{"data":{"users":[{"id":"1","name":"Alice"},{"id":"2","name":"Bob"}]}}`, - expectedStaticCost: 2, - }, computeStaticCost())) - }) + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + {"__typename":"Human","name":"Luke Skywalker","height":"12"}, + {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} + ]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: costConfig}, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 42, // Query.hero(2)+Human(13=max(13,0))+Human.name(7)+Human.height(3)+Droid.name(17) + }, + computeStaticCost(), + )) + }) } diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 3f391067ae..1f14996306 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -21,8 +21,14 @@ type WeightDefaults struct { // FieldCostConfig defines cost configuration for a specific field of an object or input object. // Includes @listSize directive fields for objects. type FieldCostConfig struct { + + // Weight is the cost of this field definition. It could be negative or zero. + // Should be used only if HasWeight is true. Weight int + // Means that there was weight attached to the field definition. + HasWeight bool + // ArgumentWeights maps an argument name to its weight. // Location: ARGUMENT_DEFINITION ArgumentWeights map[string]int @@ -75,8 +81,8 @@ func NewDataSourceCostConfig() *DataSourceCostConfig { } } -// EnumScalarWeight returns the cost for an enum or scalar types -func (c *DataSourceCostConfig) EnumScalarWeight(enumName string) int { +// EnumScalarTypeWeight returns the cost for an enum or scalar types +func (c *DataSourceCostConfig) EnumScalarTypeWeight(enumName string) int { if c == nil { return 0 } @@ -86,8 +92,8 @@ func (c *DataSourceCostConfig) EnumScalarWeight(enumName string) int { return StaticCostDefaults.EnumScalar } -// ObjectWeight returns the default object cost -func (c *DataSourceCostConfig) ObjectWeight(name string) int { +// ObjectTypeWeight returns the default object cost +func (c *DataSourceCostConfig) ObjectTypeWeight(name string) int { if c == nil { return 0 } @@ -100,34 +106,34 @@ func (c *DataSourceCostConfig) ObjectWeight(name string) int { // CostTreeNode represents a node in the cost calculation tree // Based on IBM GraphQL Cost Specification: https://ibm.github.io/graphql-specs/cost-spec.html type CostTreeNode struct { - // fieldRef is the AST field reference - fieldRef int - - // Enclosing type name and field name - fieldCoord FieldCoordinate - // dataSourceHashes identifies which data sources resolve this field. dataSourceHashes []DSHash - // FieldCost is the weight of this field from @cost directive - FieldCost int + // fieldCost is the weight of this field or its returned type + fieldCost int - // ArgumentsCost is the sum of argument weights and input fields used on this field. - ArgumentsCost int + // argumentsCost is the sum of argument weights and input fields used on this field. + argumentsCost int // Weights on directives ignored for now. - DirectivesCost int + directivesCost int // multiplier is the list size multiplier from @listSize directive // Applied to children costs for list fields multiplier int - // Children contain child field costs - Children []*CostTreeNode + // children contain child field costs + children []*CostTreeNode // The data below is stored for deferred cost calculation. // We populate these fields in EnterField and use them as a source of truth in LeaveField. - // + + // fieldRef is the AST field reference + fieldRef int + + // Enclosing type name and field name + fieldCoord FieldCoordinate + // fieldTypeName contains the name of an unwrapped (named) type that is returned by this field. fieldTypeName string @@ -137,9 +143,28 @@ type CostTreeNode struct { // arguments contain the values of arguments passed to the field. arguments map[string]ArgumentInfo - isListType bool - isSimpleType bool - isAbstractType bool + isListType bool + isSimpleType bool + isAbstractType bool + isEnclosingTypeAbstract bool +} + +func (node *CostTreeNode) maxCostImplementingFieldConfig(config *DataSourceCostConfig, fieldName string) *FieldCostConfig { + var maxWeightConfig *FieldCostConfig + for _, implTypeName := range node.implementingTypeNames { + // Get the cost config for the field of an implementing type. + implFieldCoord := FieldCoordinate{implTypeName, fieldName} + fieldConfig := config.Fields[implFieldCoord] + + if fieldConfig != nil { + if fieldConfig.HasWeight && (maxWeightConfig == nil || fieldConfig.Weight > maxWeightConfig.Weight) { + fmt.Printf("found better maxWeightConfig for %v: %v\n", implFieldCoord, fieldConfig) + maxWeightConfig = fieldConfig + } + } + } + return maxWeightConfig + } type ArgumentInfo struct { @@ -164,29 +189,35 @@ type ArgumentInfo struct { // otherwise the argument is Scalar or Enum. isInputObject bool - isScalar bool + isSimple bool } // TotalCost calculates the total cost of this node and all descendants -// Per IBM spec: total = field_weight + argument_weights + (children_total * multiplier) -func (n *CostTreeNode) TotalCost() int { - if n == nil { +func (node *CostTreeNode) TotalCost() int { + if node == nil { return 0 } // Sum children (fields) costs var childrenCost int - for _, child := range n.Children { + for _, child := range node.children { childrenCost += child.TotalCost() } // Apply multiplier to children cost (for list fields) - multiplier := n.multiplier + multiplier := node.multiplier if multiplier == 0 { multiplier = 1 } - // TODO: negative sum should be rounded up to zero - cost := n.ArgumentsCost + n.DirectivesCost + (n.FieldCost+childrenCost)*multiplier + cost := node.argumentsCost + node.directivesCost + if cost < 0 { + // If arguments and directive weights decrease the field cost, floor it to zero. + cost = 0 + } + // Here we do not follow IBM spec. We multiply with field cost. + // If there is weight attached to the type that is returned (resolved) by the field, + // the more objects we request, the more expensive it should be. + cost += (node.fieldCost + childrenCost) * multiplier return cost } @@ -257,38 +288,41 @@ func (c *CostCalculator) EnterField(node *CostTreeNode) { // Attach to parent parent := c.CurrentNode() if parent != nil { - parent.Children = append(parent.Children, node) + parent.children = append(parent.children, node) } c.stack = append(c.stack, node) } -// LeaveField is called when leaving a field during AST traversal. -// This is where we calculate costs because fieldPlanners data is now available. +// LeaveField calculates the cose of the current node and pop from the cost stack. +// It is called when leaving a field during planning. func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { - // Find the current node (should match fieldRef) if len(c.stack) <= 1 { // Keep root on stack return } - current := c.stack[len(c.stack)-1] + // Find the current node (should match fieldRef) + lastIndex := len(c.stack) - 1 + current := c.stack[lastIndex] if current.fieldRef != fieldRef { return } current.dataSourceHashes = dsHashes - c.calculateNodeCosts(current) + parent := c.stack[lastIndex-1] + c.calculateNodeCosts(current, parent) - c.stack = c.stack[:len(c.stack)-1] + c.stack = c.stack[:lastIndex] } // calculateNodeCosts fills in the cost values for a node based on its data sources. // It implements IBM GraphQL Cost Specification. // See: https://ibm.github.io/graphql-specs/cost-spec.html#sec-Field-Cost -func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { - // For every data source we get different weights. - // For this node we sum weights of the field and its arguments. - // For the multiplier we pick the maximum. +// For this node we sum weights of the field or its returned type for all the data sources. +// Each data source can have its own cost configuration. If we plan field on two data sources, +// it means more work for the router: we should sum the costs. +// For the multiplier we pick the maximum. +func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { if len(node.dataSourceHashes) <= 0 { // no data source is responsible for this field return @@ -297,33 +331,72 @@ func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { node.multiplier = 0 for _, dsHash := range node.dataSourceHashes { - config, ok := c.costConfigs[dsHash] + dsCostConfig, ok := c.costConfigs[dsHash] if !ok { - fmt.Printf("WARNING: no cost config for data source %v\n", dsHash) + fmt.Printf("WARNING: no cost dsCostConfig for data source %v\n", dsHash) continue } - // TODO: handle abstract types + fieldConfig := dsCostConfig.Fields[node.fieldCoord] + // The cost directive is not allowed on fields in an interface. + // The cost of a field on an interface can be calculated based on the costs of + // the corresponding field on each concrete type implementing that interface, + // either directly or indirectly through other interfaces. + if fieldConfig != nil && node.isEnclosingTypeAbstract && parent.isAbstractType { + // Composition should not let interface fields have weights, so we assume that + // the enclosing type is concrete. + fmt.Printf("WARNING: cost directive on field %v of interface %v\n", node.fieldCoord, parent.fieldCoord) + } + if node.isEnclosingTypeAbstract && parent.isAbstractType { + fmt.Printf("WARNING: no dsCostConfig for %v, parent node: %v\n", node.fieldCoord, parent.fieldCoord) + // This field is part of the enclosing interface/union. We should look into + // implementing types and find the max-weighted field. + // Found fieldConfig can be used for all the calculations. + // Should we do the same when there is no weight on the field enclosed into the abstract type? + fieldConfig = parent.maxCostImplementingFieldConfig(dsCostConfig, node.fieldCoord.FieldName) + } + + if fieldConfig != nil && fieldConfig.HasWeight { + node.fieldCost += fieldConfig.Weight + } else { + fmt.Printf("WARNING: no weight for %v, parent node: %v\n", node.fieldCoord, parent) + switch { + case node.isSimpleType: + // use the weight of the type returned by this field + node.fieldCost += dsCostConfig.EnumScalarTypeWeight(node.fieldTypeName) + case node.isAbstractType: + // For the abstract field, find the max weight among all implementing types + maxWeight := 0 + for _, implTypeName := range node.implementingTypeNames { + weight := dsCostConfig.ObjectTypeWeight(implTypeName) + if weight > maxWeight { + maxWeight = weight + } + } + node.fieldCost += maxWeight + default: + node.fieldCost += dsCostConfig.ObjectTypeWeight(node.fieldTypeName) + } + } - fieldConfig := config.Fields[node.fieldCoord] if fieldConfig != nil { - node.FieldCost += fieldConfig.Weight - for argName := range node.arguments { + for argName, arg := range node.arguments { weight, ok := fieldConfig.ArgumentWeights[argName] if ok { - node.ArgumentsCost += weight + node.argumentsCost += weight + } else { + // Take into account the type of the argument. + // If the argument definition itself does not have weight attached, + // but the type of the argument does have weight attached to it. + if arg.isSimple { + node.argumentsCost += dsCostConfig.EnumScalarTypeWeight(arg.typeName) + } else { + node.argumentsCost += dsCostConfig.ObjectTypeWeight(arg.typeName) + } } - // What to do if the argument definition itself does not have weight attached, - // but the type of the argument does have weight attached to it? + // TODO: arguments should include costs of input object fields } - } else { - // use the weight of the type returned by this field - if node.isSimpleType { - node.FieldCost += config.EnumScalarWeight(node.fieldTypeName) - } else { - node.FieldCost += config.ObjectWeight(node.fieldTypeName) - } } // Compute multiplier as the maximum of data sources. @@ -335,7 +408,7 @@ func (c *CostCalculator) calculateNodeCosts(node *CostTreeNode) { multiplier := -1 for _, slicingArg := range fieldConfig.SlicingArguments { argInfo, ok := node.arguments[slicingArg] - if ok && argInfo.isScalar && argInfo.intValue > 0 && argInfo.intValue > multiplier { + if ok && argInfo.isSimple && argInfo.intValue > 0 && argInfo.intValue > multiplier { multiplier = argInfo.intValue } } diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 3e7d56d734..d4b58952d4 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -421,16 +421,15 @@ func (v *Visitor) mapFieldConfig(ref int) { // enterFieldCost creates a skeleton cost node when entering a field. // Actual cost calculation is deferred to leaveFieldCost when fieldPlanners data is available. -func (v *Visitor) enterFieldCost(ref int) { +func (v *Visitor) enterFieldCost(fieldRef int) { if v.costCalculator == nil { return } typeName := v.Walker.EnclosingTypeDefinition.NameString(v.Definition) - fieldName := v.Operation.FieldNameUnsafeString(ref) - coord := FieldCoordinate{typeName, fieldName} + fieldName := v.Operation.FieldNameUnsafeString(fieldRef) - fieldDefinition, ok := v.Walker.FieldDefinition(ref) + fieldDefinition, ok := v.Walker.FieldDefinition(fieldRef) if !ok { return } @@ -439,7 +438,7 @@ func (v *Visitor) enterFieldCost(ref int) { isSimpleType := v.Definition.TypeIsEnum(fieldDefinitionTypeRef, v.Definition) || v.Definition.TypeIsScalar(fieldDefinitionTypeRef, v.Definition) unwrappedTypeName := v.Definition.ResolveTypeNameString(fieldDefinitionTypeRef) - arguments := v.extractFieldArguments(ref) + arguments := v.extractFieldArguments(fieldRef) // Check and push through if the unwrapped type of this field is interface or union. unwrappedTypeNode, exists := v.Definition.NodeByNameStr(unwrappedTypeName) @@ -463,20 +462,24 @@ func (v *Visitor) enterFieldCost(ref int) { } if len(implementingTypeNames) > 0 { - fmt.Printf("enterFieldCost: field %s.%s is interface or union, implementingTypeNames=%v\n", typeName, fieldName, implementingTypeNames) + fmt.Printf("enterFieldCost: field %s.%s is interface or union, implementing types: %v\n", typeName, fieldName, implementingTypeNames) } - // Create skeleton node - dsHashes will be filled in leaveFieldCost + isEnclosingTypeAbstract := v.Walker.EnclosingTypeDefinition.Kind == ast.NodeKindInterfaceTypeDefinition || + v.Walker.EnclosingTypeDefinition.Kind == ast.NodeKindUnionTypeDefinition + fmt.Printf("EnclosingType Kind = %v for %s.%s\n", v.Walker.EnclosingTypeDefinition.Kind, typeName, fieldName) + // Create a skeleton node. dataSourceHashes will be filled in leaveFieldCost node := CostTreeNode{ - fieldRef: ref, - fieldCoord: coord, - multiplier: 1, - fieldTypeName: unwrappedTypeName, - implementingTypeNames: implementingTypeNames, - isListType: isListType, - isSimpleType: isSimpleType, - isAbstractType: isAbstractType, - arguments: arguments, + fieldRef: fieldRef, + fieldCoord: FieldCoordinate{typeName, fieldName}, + multiplier: 1, + fieldTypeName: unwrappedTypeName, + implementingTypeNames: implementingTypeNames, + isListType: isListType, + isSimpleType: isSimpleType, + isAbstractType: isAbstractType, + isEnclosingTypeAbstract: isEnclosingTypeAbstract, + arguments: arguments, } v.costCalculator.EnterField(&node) } @@ -500,8 +503,10 @@ func (v *Visitor) getFieldDataSourceHashes(ref int) []DSHash { } // extractFieldArguments extracts arguments from a field for cost calculation -func (v *Visitor) extractFieldArguments(ref int) map[string]ArgumentInfo { - argRefs := v.Operation.FieldArguments(ref) +// This implementation does not go deep for input objects yet. +// It should return unwrapped type names for arguments and that is it for now. +func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { + argRefs := v.Operation.FieldArguments(fieldRef) if len(argRefs) == 0 { return nil } @@ -520,7 +525,7 @@ func (v *Visitor) extractFieldArguments(ref int) map[string]ArgumentInfo { fmt.Printf("value = %s\n", val) switch argValue.Kind { case ast.ValueKindBoolean, ast.ValueKindEnum, ast.ValueKindString, ast.ValueKindFloat: - argInfo.isScalar = true + argInfo.isSimple = true argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) case ast.ValueKindNull: // Ignore any nulls @@ -528,20 +533,20 @@ func (v *Visitor) extractFieldArguments(ref int) map[string]ArgumentInfo { case ast.ValueKindInteger: // Extract integer value if present (for multipliers like "first", "limit") argInfo.intValue = int(v.Operation.IntValueAsInt(argValue.Ref)) - argInfo.isScalar = true + argInfo.isSimple = true argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) case ast.ValueKindVariable: argInfo.isInputObject = true variableValue := v.Operation.VariableValueNameString(argValue.Ref) if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { - continue // omit optional argument when variable is not defined + continue // omit optional argument when the variable is not defined } variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinition, v.Operation.VariableValueNameBytes(argValue.Ref)) if !exists { break } - // variableTypeRef := v.Operation.VariableDefinitions[variableDefinition].Type - argInfo.typeName = v.Operation.ResolveTypeNameString(v.Operation.VariableDefinitions[variableDefinition].Type) + variableTypeRef := v.Operation.VariableDefinitions[variableDefinition].Type + argInfo.typeName = v.Operation.ResolveTypeNameString(variableTypeRef) // TODO: we need to analyze variables that contains input object fields. // If these fields has weight attached, use them for calculation. // Variables are not inlined at this stage, so we need to inspect them via AST. @@ -549,9 +554,8 @@ func (v *Visitor) extractFieldArguments(ref int) map[string]ArgumentInfo { case ast.ValueKindList: unwrappedTypeRef := v.Operation.ResolveUnderlyingType(argValue.Ref) argInfo.typeName = v.Operation.TypeNameString(unwrappedTypeRef) - // how to figure out if the unwrapped is scalar? default: - fmt.Printf("unhandled case: %v\n", argValue.Kind) + fmt.Printf("unhandled argument type: %v\n", argValue.Kind) continue } @@ -770,11 +774,9 @@ func (v *Visitor) LeaveField(ref int) { return } - // Calculate costs and pop from cost stack. - // This is done in LeaveField because fieldPlanners is available in LeaveField only. + // This is done in LeaveField because fieldPlanners become available before LeaveField. if v.costCalculator != nil { - dsHashes := v.getFieldDataSourceHashes(ref) - v.costCalculator.LeaveField(ref, dsHashes) + v.costCalculator.LeaveField(ref, v.getFieldDataSourceHashes(ref)) } if v.currentFields[len(v.currentFields)-1].popOnField == ref { From a11d847f7f14d0a2926851a091dcc3658e4337e8 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:27:13 +0200 Subject: [PATCH 09/43] support listSize directive --- execution/engine/execution_engine_test.go | 333 +++++++++++++++++++--- v2/pkg/engine/plan/static_cost.go | 226 +++++++++------ v2/pkg/engine/plan/visitor.go | 39 ++- 3 files changed, 452 insertions(+), 146 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 14452439bd..6386e41436 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -5578,21 +5578,57 @@ func TestExecutionEngine_Execute(t *testing.T) { string(graphql.StarwarsSchema(t).RawSchema()), ), }) - costConfig := &plan.DataSourceCostConfig{ - Fields: map[plan.FieldCoordinate]*plan.FieldCostConfig{ - {TypeName: "Query", FieldName: "droid"}: { - ArgumentWeights: map[string]int{"id": 1}, - HasWeight: false, - }, - {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, - {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 3}, - {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, - {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, - }, - Types: map[string]int{ - "Human": 13, + + t.Run("droid with weighted plain fields", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + droid(id: "R2D2") { + name + primaryFunction + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + }}, + customConfig, + ), + }, + fields: []plan.FieldConfiguration{ + { + TypeName: "Query", FieldName: "droid", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "id", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, + }, + }, + }, + }, + expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + expectedStaticCost: 18, // Query.droid (1) + droid.name (17) }, - } + computeStaticCost(), + )) t.Run("droid with weighted plain fields and an argument", runWithoutError( ExecutionEngineTestCase{ @@ -5611,21 +5647,29 @@ func TestExecutionEngine_Execute(t *testing.T) { mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", - expectedPath: "/", - expectedBody: "", + expectedHost: "example.com", expectedPath: "/", expectedBody: "", sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, sendStatusCode: 200, }), ), - &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: costConfig}, + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "droid"}: { + ArgumentWeights: map[string]int{"id": 3}, + HasWeight: false, + }, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + }}, customConfig, ), }, fields: []plan.FieldConfiguration{ { - TypeName: "Query", - FieldName: "droid", + TypeName: "Query", FieldName: "droid", Arguments: []plan.ArgumentConfiguration{ { Name: "id", @@ -5636,12 +5680,12 @@ func TestExecutionEngine_Execute(t *testing.T) { }, }, expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, - expectedStaticCost: 19, // Query.droid (1) + Query.droid.id (1) + droid.name (17) + expectedStaticCost: 21, // Query.droid (1) + Query.droid.id (3) + droid.name (17) }, computeStaticCost(), )) - t.Run("hero with interfaces", runWithoutError( + t.Run("hero field has weight (returns interface) and with concrete fragment", runWithoutError( ExecutionEngineTestCase{ schema: graphql.StarwarsSchema(t), operation: func(t *testing.T) graphql.Request { @@ -5649,9 +5693,7 @@ func TestExecutionEngine_Execute(t *testing.T) { Query: `{ hero { name - ... on Human { - height - } + ... on Human { height } } }`, } @@ -5667,7 +5709,17 @@ func TestExecutionEngine_Execute(t *testing.T) { sendStatusCode: 200, }), ), - &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: costConfig}, + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 3}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + Types: map[string]int{ + "Human": 13, + }, + }}, customConfig, ), }, @@ -5677,7 +5729,47 @@ func TestExecutionEngine_Execute(t *testing.T) { computeStaticCost(), )) - t.Run("hero with 1 expected friend", runWithoutError( + t.Run("hero field has no weight (returns interface) and with concrete fragment", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { name } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + Types: map[string]int{ + "Human": 13, + "Droid": 11, + }, + }}, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker"}}}`, + expectedStaticCost: 30, // Query.Human (13) + Droid.name (17=max(7, 17)) + }, + computeStaticCost(), + )) + + t.Run("query hero without assumedSize on friends", runWithoutError( ExecutionEngineTestCase{ schema: graphql.StarwarsSchema(t), operation: func(t *testing.T) graphql.Request { @@ -5685,14 +5777,8 @@ func TestExecutionEngine_Execute(t *testing.T) { Query: `{ hero { friends { - ...on Droid { - name - primaryFunction - } - ...on Human { - name - height - } + ...on Droid { name primaryFunction } + ...on Human { name height } } } }`, @@ -5712,12 +5798,185 @@ func TestExecutionEngine_Execute(t *testing.T) { sendStatusCode: 200, }), ), - &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: costConfig}, + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 127, // Query.hero(max(7,5))+10*(Human(max(7,5))+Human.name(2)+Human.height(1)+Droid.name(2)) + }, + computeStaticCost(), + )) + + t.Run("query hero with assumedSize on friends", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + {"__typename":"Human","name":"Luke Skywalker","height":"12"}, + {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} + ]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 5}, + {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 20}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 247, // Query.hero(max(7,5))+ 20 * (7+2+2+1) + }, + computeStaticCost(), + )) + + t.Run("query hero with assumedSize on friends and weight defined", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + {"__typename":"Human","name":"Luke Skywalker","height":"12"}, + {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} + ]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "friends"}: {HasWeight: true, Weight: 3}, + {TypeName: "Droid", FieldName: "friends"}: {HasWeight: true, Weight: 4}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 5}, + {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 20}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 187, // Query.hero(max(7,5))+ 20 * (4+2+2+1) + }, + computeStaticCost(), + )) + + t.Run("query hero with empty cost structures", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + {"__typename":"Human","name":"Luke Skywalker","height":"12"}, + {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} + ]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{}, + }, customConfig, ), }, expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, - expectedStaticCost: 42, // Query.hero(2)+Human(13=max(13,0))+Human.name(7)+Human.height(3)+Droid.name(17) + expectedStaticCost: 11, // Query.hero(max(1,1))+ 10 * 1 }, computeStaticCost(), )) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 1f14996306..42f9b557d7 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -4,7 +4,6 @@ import "fmt" // StaticCostDefaults contains default cost values when no specific costs are configured var StaticCostDefaults = WeightDefaults{ - Field: 1, EnumScalar: 0, Object: 1, List: 10, // The assumed maximum size of a list for fields that return lists. @@ -12,15 +11,13 @@ var StaticCostDefaults = WeightDefaults{ // WeightDefaults defines default cost values for different GraphQL elements type WeightDefaults struct { - Field int EnumScalar int Object int List int } -// FieldCostConfig defines cost configuration for a specific field of an object or input object. -// Includes @listSize directive fields for objects. -type FieldCostConfig struct { +// FieldWeight defines cost configuration for a specific field of an object or input object. +type FieldWeight struct { // Weight is the cost of this field definition. It could be negative or zero. // Should be used only if HasWeight is true. @@ -32,9 +29,10 @@ type FieldCostConfig struct { // ArgumentWeights maps an argument name to its weight. // Location: ARGUMENT_DEFINITION ArgumentWeights map[string]int +} - // Fields below are defined only on FIELD_DEFINITION from the @listSize directive. - +// FieldListSize contains parsed data from the @listSize directive for an object field. +type FieldListSize struct { // AssumedSize is the default assumed size when no slicing argument is provided. // If 0, the global default list cost is used. AssumedSize int @@ -52,12 +50,34 @@ type FieldCostConfig struct { RequireOneSlicingArgument bool } +// multiplier returns the multiplier for a list field based on the values of arguments. +func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo) int { + multiplier := -1 + for _, slicingArg := range ls.SlicingArguments { + argInfo, ok := arguments[slicingArg] + if ok && argInfo.isSimple && argInfo.intValue > 0 && argInfo.intValue > multiplier { + multiplier = argInfo.intValue + } + } + if multiplier == -1 && ls.AssumedSize > 0 { + multiplier = ls.AssumedSize + } + if multiplier == -1 { + multiplier = StaticCostDefaults.List + } + return multiplier +} + // DataSourceCostConfig holds all cost configurations for a data source. // This data is passed from the composition. type DataSourceCostConfig struct { - // Fields maps field coordinate to its cost config. Cannot be on fields of interfaces. + // Weights maps field coordinate to its weights. Cannot be on fields of interfaces. // Location: FIELD_DEFINITION, INPUT_FIELD_DEFINITION - Fields map[FieldCoordinate]*FieldCostConfig + Weights map[FieldCoordinate]*FieldWeight + + // ListSizes maps field coordinates to their respective list size configurations. + // Location: FIELD_DEFINITION + ListSizes map[FieldCoordinate]*FieldListSize // Types maps TypeName to the weight of the object, scalar or enum definition. // If TypeName is not present, the default value for Enums and Scalars is 0, otherwise 1. @@ -76,8 +96,8 @@ type DataSourceCostConfig struct { // NewDataSourceCostConfig creates a new cost config with defaults func NewDataSourceCostConfig() *DataSourceCostConfig { return &DataSourceCostConfig{ - Fields: make(map[FieldCoordinate]*FieldCostConfig), - Types: make(map[string]int), + Weights: make(map[FieldCoordinate]*FieldWeight), + Types: make(map[string]int), } } @@ -143,53 +163,46 @@ type CostTreeNode struct { // arguments contain the values of arguments passed to the field. arguments map[string]ArgumentInfo - isListType bool - isSimpleType bool - isAbstractType bool + returnsListType bool + returnsSimpleType bool + returnsAbstractType bool isEnclosingTypeAbstract bool } -func (node *CostTreeNode) maxCostImplementingFieldConfig(config *DataSourceCostConfig, fieldName string) *FieldCostConfig { - var maxWeightConfig *FieldCostConfig +func (node *CostTreeNode) maxWeightImplementingField(config *DataSourceCostConfig, fieldName string) *FieldWeight { + var maxWeight *FieldWeight for _, implTypeName := range node.implementingTypeNames { // Get the cost config for the field of an implementing type. - implFieldCoord := FieldCoordinate{implTypeName, fieldName} - fieldConfig := config.Fields[implFieldCoord] + coord := FieldCoordinate{implTypeName, fieldName} + fieldWeight := config.Weights[coord] - if fieldConfig != nil { - if fieldConfig.HasWeight && (maxWeightConfig == nil || fieldConfig.Weight > maxWeightConfig.Weight) { - fmt.Printf("found better maxWeightConfig for %v: %v\n", implFieldCoord, fieldConfig) - maxWeightConfig = fieldConfig + if fieldWeight != nil { + if fieldWeight.HasWeight && (maxWeight == nil || fieldWeight.Weight > maxWeight.Weight) { + fmt.Printf("found better maxWeight for %v: %v\n", coord, fieldWeight) + maxWeight = fieldWeight } } } - return maxWeightConfig - + return maxWeight } -type ArgumentInfo struct { - intValue int - - // The name of an unwrapped type. - typeName string - - // If argument is passed an input object, we want to gather counts - // for all the field coordinates with non-null values used in the argument. - // - // For example, for - // "input A { x: Int, rec: A! }" - // following value is passed: - // { x: 1, rec: { x: 2, rec: { x: 3 } } }, - // then coordCounts will be: - // { {"A", "rec"}: 2, {"A", "x"}: 3 } - // - coordCounts map[FieldCoordinate]int - - // isInputObject is true for an input object passed to the argument, - // otherwise the argument is Scalar or Enum. - isInputObject bool - - isSimple bool +func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostConfig, fieldName string, arguments map[string]ArgumentInfo) *FieldListSize { + var maxMultiplier int + var maxListSize *FieldListSize + for _, implTypeName := range node.implementingTypeNames { + coord := FieldCoordinate{implTypeName, fieldName} + listSize := config.ListSizes[coord] + + if listSize != nil { + multiplier := listSize.multiplier(arguments) + if maxListSize == nil || multiplier > maxMultiplier { + fmt.Printf("found better multiplier for %v: %v\n", coord, multiplier) + maxMultiplier = multiplier + maxListSize = listSize + } + } + } + return maxListSize } // TotalCost calculates the total cost of this node and all descendants @@ -222,6 +235,31 @@ func (node *CostTreeNode) TotalCost() int { return cost } +type ArgumentInfo struct { + intValue int + + // The name of an unwrapped type. + typeName string + + // If argument is passed an input object, we want to gather counts + // for all the field coordinates with non-null values used in the argument. + // + // For example, for + // "input A { x: Int, rec: A! }" + // following value is passed: + // { x: 1, rec: { x: 2, rec: { x: 3 } } }, + // then coordCounts will be: + // { {"A", "rec"}: 2, {"A", "x"}: 3 } + // + coordCounts map[FieldCoordinate]int + + // isInputObject is true for an input object passed to the argument, + // otherwise the argument is Scalar or Enum. + isInputObject bool + + isSimple bool +} + // CostTree represents the complete cost tree for a query type CostTree struct { Root *CostTreeNode @@ -245,9 +283,6 @@ type CostCalculator struct { // costConfigs maps data source hash to its cost configuration costConfigs map[DSHash]*DataSourceCostConfig - - // defaultConfig is used when no data source specific config exists - defaultConfig *DataSourceCostConfig } // NewCostCalculator creates a new cost calculator @@ -316,12 +351,13 @@ func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { } // calculateNodeCosts fills in the cost values for a node based on its data sources. -// It implements IBM GraphQL Cost Specification. -// See: https://ibm.github.io/graphql-specs/cost-spec.html#sec-Field-Cost +// // For this node we sum weights of the field or its returned type for all the data sources. // Each data source can have its own cost configuration. If we plan field on two data sources, // it means more work for the router: we should sum the costs. -// For the multiplier we pick the maximum. +// +// For the multiplier we pick the maximum field weight of implementing types and then +// the maximum among slicing arguments. func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { if len(node.dataSourceHashes) <= 0 { // no data source is responsible for this field @@ -337,34 +373,36 @@ func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { continue } - fieldConfig := dsCostConfig.Fields[node.fieldCoord] + fieldWeight := dsCostConfig.Weights[node.fieldCoord] + listSize := dsCostConfig.ListSizes[node.fieldCoord] // The cost directive is not allowed on fields in an interface. // The cost of a field on an interface can be calculated based on the costs of // the corresponding field on each concrete type implementing that interface, // either directly or indirectly through other interfaces. - if fieldConfig != nil && node.isEnclosingTypeAbstract && parent.isAbstractType { + if fieldWeight != nil && node.isEnclosingTypeAbstract && parent.returnsAbstractType { // Composition should not let interface fields have weights, so we assume that // the enclosing type is concrete. fmt.Printf("WARNING: cost directive on field %v of interface %v\n", node.fieldCoord, parent.fieldCoord) } - if node.isEnclosingTypeAbstract && parent.isAbstractType { - fmt.Printf("WARNING: no dsCostConfig for %v, parent node: %v\n", node.fieldCoord, parent.fieldCoord) - // This field is part of the enclosing interface/union. We should look into - // implementing types and find the max-weighted field. - // Found fieldConfig can be used for all the calculations. - // Should we do the same when there is no weight on the field enclosed into the abstract type? - fieldConfig = parent.maxCostImplementingFieldConfig(dsCostConfig, node.fieldCoord.FieldName) + if node.isEnclosingTypeAbstract && parent.returnsAbstractType { + // This field is part of the enclosing interface/union. + // We look into implementing types and find the max-weighted field. + // Found fieldWeight can be used for all the calculations. + fieldWeight = parent.maxWeightImplementingField(dsCostConfig, node.fieldCoord.FieldName) + // If this field has listSize defined, then do not look into implementing types. + if listSize == nil && node.returnsListType { + listSize = parent.maxMultiplierImplementingField(dsCostConfig, node.fieldCoord.FieldName, node.arguments) + } } - if fieldConfig != nil && fieldConfig.HasWeight { - node.fieldCost += fieldConfig.Weight + if fieldWeight != nil && fieldWeight.HasWeight { + node.fieldCost += fieldWeight.Weight } else { - fmt.Printf("WARNING: no weight for %v, parent node: %v\n", node.fieldCoord, parent) + // Use the weight of the type returned by this field switch { - case node.isSimpleType: - // use the weight of the type returned by this field + case node.returnsSimpleType: node.fieldCost += dsCostConfig.EnumScalarTypeWeight(node.fieldTypeName) - case node.isAbstractType: + case node.returnsAbstractType: // For the abstract field, find the max weight among all implementing types maxWeight := 0 for _, implTypeName := range node.implementingTypeNames { @@ -379,47 +417,45 @@ func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { } } - if fieldConfig != nil { - for argName, arg := range node.arguments { - weight, ok := fieldConfig.ArgumentWeights[argName] - if ok { + for argName, arg := range node.arguments { + if fieldWeight != nil { + if weight, ok := fieldWeight.ArgumentWeights[argName]; ok { node.argumentsCost += weight - } else { - // Take into account the type of the argument. - // If the argument definition itself does not have weight attached, - // but the type of the argument does have weight attached to it. - if arg.isSimple { - node.argumentsCost += dsCostConfig.EnumScalarTypeWeight(arg.typeName) - } else { - node.argumentsCost += dsCostConfig.ObjectTypeWeight(arg.typeName) - } + continue } - + } + // Take into account the type of the argument. + // If the argument definition itself does not have weight attached, + // but the type of the argument does have weight attached to it. + if arg.isSimple { + node.argumentsCost += dsCostConfig.EnumScalarTypeWeight(arg.typeName) + } else if arg.isInputObject { // TODO: arguments should include costs of input object fields + } else { + node.argumentsCost += dsCostConfig.ObjectTypeWeight(arg.typeName) } + } // Compute multiplier as the maximum of data sources. - if !node.isListType || fieldConfig == nil { - node.multiplier = 1 + if !node.returnsListType { continue } - multiplier := -1 - for _, slicingArg := range fieldConfig.SlicingArguments { - argInfo, ok := node.arguments[slicingArg] - if ok && argInfo.isSimple && argInfo.intValue > 0 && argInfo.intValue > multiplier { - multiplier = argInfo.intValue + if listSize != nil { + multiplier := listSize.multiplier(node.arguments) + // If this node returns a list of abstract types, then it should have listSize defined + // to set the multiplier. Spec allows defining listSize on the fields of interfaces. + if multiplier > node.multiplier { + node.multiplier = multiplier } } - if multiplier == -1 && fieldConfig.AssumedSize > 0 { - multiplier = fieldConfig.AssumedSize - } - if multiplier > node.multiplier { - node.multiplier = multiplier - } + } + if node.multiplier == 0 && node.returnsListType { + node.multiplier = StaticCostDefaults.List + } } // GetTree returns the cost tree diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index d4b58952d4..83bbc873ed 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -465,8 +465,7 @@ func (v *Visitor) enterFieldCost(fieldRef int) { fmt.Printf("enterFieldCost: field %s.%s is interface or union, implementing types: %v\n", typeName, fieldName, implementingTypeNames) } - isEnclosingTypeAbstract := v.Walker.EnclosingTypeDefinition.Kind == ast.NodeKindInterfaceTypeDefinition || - v.Walker.EnclosingTypeDefinition.Kind == ast.NodeKindUnionTypeDefinition + isEnclosingTypeAbstract := v.Walker.EnclosingTypeDefinition.Kind.IsAbstractType() fmt.Printf("EnclosingType Kind = %v for %s.%s\n", v.Walker.EnclosingTypeDefinition.Kind, typeName, fieldName) // Create a skeleton node. dataSourceHashes will be filled in leaveFieldCost node := CostTreeNode{ @@ -475,9 +474,9 @@ func (v *Visitor) enterFieldCost(fieldRef int) { multiplier: 1, fieldTypeName: unwrappedTypeName, implementingTypeNames: implementingTypeNames, - isListType: isListType, - isSimpleType: isSimpleType, - isAbstractType: isAbstractType, + returnsListType: isListType, + returnsSimpleType: isSimpleType, + returnsAbstractType: isAbstractType, isEnclosingTypeAbstract: isEnclosingTypeAbstract, arguments: arguments, } @@ -527,35 +526,47 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { case ast.ValueKindBoolean, ast.ValueKindEnum, ast.ValueKindString, ast.ValueKindFloat: argInfo.isSimple = true argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) - case ast.ValueKindNull: - // Ignore any nulls - continue case ast.ValueKindInteger: - // Extract integer value if present (for multipliers like "first", "limit") - argInfo.intValue = int(v.Operation.IntValueAsInt(argValue.Ref)) + // Extract integer value if present (for arguments in directives) argInfo.isSimple = true argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) + argInfo.intValue = int(v.Operation.IntValueAsInt(argValue.Ref)) case ast.ValueKindVariable: - argInfo.isInputObject = true variableValue := v.Operation.VariableValueNameString(argValue.Ref) if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { continue // omit optional argument when the variable is not defined } variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinition, v.Operation.VariableValueNameBytes(argValue.Ref)) if !exists { - break + continue } variableTypeRef := v.Operation.VariableDefinitions[variableDefinition].Type - argInfo.typeName = v.Operation.ResolveTypeNameString(variableTypeRef) + unwrappedVarTypeRef := v.Operation.ResolveUnderlyingType(variableTypeRef) + argInfo.typeName = v.Operation.TypeNameString(unwrappedVarTypeRef) + node, exists := v.Definition.NodeByNameStr(argInfo.typeName) + if !exists { + continue + } + fmt.Printf("variableTypeRef=%v unwrappedVarTypeRef=%v typeName=%v node=%v\n", variableTypeRef, unwrappedVarTypeRef, argInfo.typeName, node) + // Analyze the node to see what kind of variable was passed. + switch node.Kind { + case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: + argInfo.isSimple = true + case ast.NodeKindInputObjectTypeDefinition: + argInfo.isInputObject = true + + } // TODO: we need to analyze variables that contains input object fields. // If these fields has weight attached, use them for calculation. // Variables are not inlined at this stage, so we need to inspect them via AST. case ast.ValueKindList: + // not sure if these are relevant unwrappedTypeRef := v.Operation.ResolveUnderlyingType(argValue.Ref) argInfo.typeName = v.Operation.TypeNameString(unwrappedTypeRef) + fmt.Printf("WARNING: unhandled list argument type: %v typeName=%v\n", argValue.Kind, argInfo.typeName) default: - fmt.Printf("unhandled argument type: %v\n", argValue.Kind) + fmt.Printf("WARNING: unhandled argument type: %v\n", argValue.Kind) continue } From 763e84aa4a5a9ec4ae7d096a6aadeb858fd6c4ec Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 13 Jan 2026 10:10:05 +0200 Subject: [PATCH 10/43] fix comments --- v2/pkg/engine/plan/static_cost.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 42f9b557d7..8ce718bc97 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -51,6 +51,7 @@ type FieldListSize struct { } // multiplier returns the multiplier for a list field based on the values of arguments. +// Does not take into account the SizedFields; TBD later. func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo) int { multiplier := -1 for _, slicingArg := range ls.SlicingArguments { @@ -228,7 +229,7 @@ func (node *CostTreeNode) TotalCost() int { cost = 0 } // Here we do not follow IBM spec. We multiply with field cost. - // If there is weight attached to the type that is returned (resolved) by the field, + // If there is a weight attached to the type that is returned (resolved) by the field, // the more objects we request, the more expensive it should be. cost += (node.fieldCost + childrenCost) * multiplier @@ -243,6 +244,7 @@ type ArgumentInfo struct { // If argument is passed an input object, we want to gather counts // for all the field coordinates with non-null values used in the argument. + // TBD later when input objects are supported. // // For example, for // "input A { x: Int, rec: A! }" @@ -437,11 +439,13 @@ func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { } - // Compute multiplier as the maximum of data sources. + // Return early, since we do not support sizedFields yet. That parameter means + // that lisSize could be applied to fields that return non-lists. if !node.returnsListType { continue } + // Compute multiplier as the maximum of data sources. if listSize != nil { multiplier := listSize.multiplier(node.arguments) // If this node returns a list of abstract types, then it should have listSize defined From 6173a8e67ab789c2e60d693d368c6ff0637ede21 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 13 Jan 2026 13:15:28 +0200 Subject: [PATCH 11/43] add the slicingArguments test --- execution/engine/execution_engine_test.go | 134 ++++++++++++++++------ 1 file changed, 100 insertions(+), 34 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 6386e41436..fa810bdba4 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -5548,24 +5548,12 @@ func TestExecutionEngine_Execute(t *testing.T) { t.Run("static cost computation", func(t *testing.T) { rootNodes := []plan.TypeField{ - { - TypeName: "Query", - FieldNames: []string{"hero", "droid"}, - }, - { - TypeName: "Human", - FieldNames: []string{"name", "height", "friends"}, - }, - { - TypeName: "Droid", - FieldNames: []string{"name", "primaryFunction", "friends"}, - }, + {TypeName: "Query", FieldNames: []string{"hero", "droid"}}, + {TypeName: "Human", FieldNames: []string{"name", "height", "friends"}}, + {TypeName: "Droid", FieldNames: []string{"name", "primaryFunction", "friends"}}, } childNodes := []plan.TypeField{ - { - TypeName: "Character", - FieldNames: []string{"name", "friends"}, - }, + {TypeName: "Character", FieldNames: []string{"name", "friends"}}, } customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ Fetch: &graphql_datasource.FetchConfiguration{ @@ -5702,9 +5690,7 @@ func TestExecutionEngine_Execute(t *testing.T) { mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", - expectedPath: "/", - expectedBody: "", + expectedHost: "example.com", expectedPath: "/", expectedBody: "", sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker","height":"12"}}}`, sendStatusCode: 200, }), @@ -5743,9 +5729,7 @@ func TestExecutionEngine_Execute(t *testing.T) { mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", - expectedPath: "/", - expectedBody: "", + expectedHost: "example.com", expectedPath: "/", expectedBody: "", sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker"}}}`, sendStatusCode: 200, }), @@ -5788,9 +5772,7 @@ func TestExecutionEngine_Execute(t *testing.T) { mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", - expectedPath: "/", - expectedBody: "", + expectedHost: "example.com", expectedPath: "/", expectedBody: "", sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ {"__typename":"Human","name":"Luke Skywalker","height":"12"}, {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} @@ -5841,9 +5823,7 @@ func TestExecutionEngine_Execute(t *testing.T) { mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", - expectedPath: "/", - expectedBody: "", + expectedHost: "example.com", expectedPath: "/", expectedBody: "", sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ {"__typename":"Human","name":"Luke Skywalker","height":"12"}, {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} @@ -5898,9 +5878,7 @@ func TestExecutionEngine_Execute(t *testing.T) { mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", - expectedPath: "/", - expectedBody: "", + expectedHost: "example.com", expectedPath: "/", expectedBody: "", sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ {"__typename":"Human","name":"Luke Skywalker","height":"12"}, {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} @@ -5957,9 +5935,7 @@ func TestExecutionEngine_Execute(t *testing.T) { mustGraphqlDataSourceConfiguration(t, "id", mustFactory(t, testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", - expectedPath: "/", - expectedBody: "", + expectedHost: "example.com", expectedPath: "/", expectedBody: "", sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ {"__typename":"Human","name":"Luke Skywalker","height":"12"}, {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} @@ -5981,6 +5957,96 @@ func TestExecutionEngine_Execute(t *testing.T) { computeStaticCost(), )) + t.Run("custom scheme for listSize", func(t *testing.T) { + listSchema := ` + type Query { + items(first: Int, last: Int): [Item!] + } + type Item @key(fields: "id") { + id: ID + } + ` + schemaSlicing, err := graphql.NewSchemaFromString(listSchema) + require.NoError(t, err) + rootNodes := []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"items"}}, + {TypeName: "Item", FieldNames: []string{"id"}}, + } + childNodes := []plan.TypeField{} + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", + }, + SchemaConfiguration: mustSchemaConfig(t, nil, listSchema), + }) + fieldConfig := []plan.FieldConfiguration{ + { + TypeName: "Query", + FieldName: "items", + Path: []string{"items"}, + Arguments: []plan.ArgumentConfiguration{ + { + Name: "first", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, + }, + { + Name: "last", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, + }, + }, + }, + } + t.Run("multiple slicing arguments as literals", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query MultipleSlicingArguments { + items(first: 5, last: 12) { id } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[ {"id":"2"}, {"id":"3"} ]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 8, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 3, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[{"id":"2"},{"id":"3"}]}}`, + expectedStaticCost: 48, // slicingArgument(12) * (Item(3)+Item.id(1)) + }, + computeStaticCost(), + )) + }) + }) } From 6aec35d1028f90b5ce6745556711a54bb3cfdafa Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 13 Jan 2026 17:21:03 +0200 Subject: [PATCH 12/43] leave comments before PR And keep one test failing since we cannnot read variables yet. --- v2/pkg/engine/plan/visitor.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 83bbc873ed..f2984ac09e 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -556,14 +556,16 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { argInfo.isInputObject = true } - // TODO: we need to analyze variables that contains input object fields. + // TODO 1: read values of variables from the context later, not possible to do it here. + + // TODO 2: we need to analyze variables that contains input object fields. // If these fields has weight attached, use them for calculation. // Variables are not inlined at this stage, so we need to inspect them via AST. case ast.ValueKindList: // not sure if these are relevant - unwrappedTypeRef := v.Operation.ResolveUnderlyingType(argValue.Ref) - argInfo.typeName = v.Operation.TypeNameString(unwrappedTypeRef) + // unwrappedTypeRef := v.Operation.ResolveUnderlyingType(argValue.Ref) + // argInfo.typeName = v.Operation.TypeNameString(unwrappedTypeRef) fmt.Printf("WARNING: unhandled list argument type: %v typeName=%v\n", argValue.Kind, argInfo.typeName) default: fmt.Printf("WARNING: unhandled argument type: %v\n", argValue.Kind) From bad52783376d2ec505f120d77d6da3a506848bc3 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 13 Jan 2026 17:28:32 +0200 Subject: [PATCH 13/43] disable broken test --- execution/engine/execution_engine_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index fa810bdba4..5c6274ae71 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -5958,6 +5958,7 @@ func TestExecutionEngine_Execute(t *testing.T) { )) t.Run("custom scheme for listSize", func(t *testing.T) { + t.Skip("Skipping due to known issue with values and arguments") listSchema := ` type Query { items(first: Int, last: Int): [Item!] From cd10b230b70ece43b96e7371476b0560a83006da Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Wed, 14 Jan 2026 16:28:30 +0200 Subject: [PATCH 14/43] handle slicingArguments correctly --- execution/engine/execution_engine.go | 26 +- execution/engine/execution_engine_test.go | 12 +- v2/pkg/engine/plan/plan.go | 29 +-- v2/pkg/engine/plan/planner.go | 5 - v2/pkg/engine/plan/static_cost.go | 284 +++++++++++----------- v2/pkg/engine/plan/visitor.go | 31 ++- v2/pkg/engine/resolve/context.go | 13 +- 7 files changed, 202 insertions(+), 198 deletions(-) diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index f442888a25..aaafaf53b8 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -38,12 +38,6 @@ func newInternalExecutionContext() *internalExecutionContext { } } -func (e *internalExecutionContext) prepare(ctx context.Context, variables []byte, request resolve.Request) { - e.setContext(ctx) - e.setVariables(variables) - e.setRequest(request) -} - func (e *internalExecutionContext) setRequest(request resolve.Request) { e.resolveContext.Request = request } @@ -69,7 +63,7 @@ type ExecutionEngine struct { executionPlanCache *lru.Cache apolloCompatibilityFlags apollocompatibility.Flags // Holds the plan after Execute(). Used in testing. - lastPlan plan.Plan + lastPlan plan.Plan } type WebsocketBeforeStartHook interface { @@ -196,7 +190,10 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques } execContext := newInternalExecutionContext() - execContext.prepare(ctx, operation.Variables, operation.InternalRequest()) + execContext.setContext(ctx) + execContext.setVariables(operation.Variables) + execContext.setRequest(operation.InternalRequest()) + for i := range options { options[i](execContext) } @@ -212,11 +209,12 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques } var report operationreport.Report - cachedPlan := e.getCachedPlan(execContext, operation.Document(), e.config.schema.Document(), operation.OperationName, &report) + cachedPlan, costCalculator := e.getCachedPlan(execContext, operation.Document(), e.config.schema.Document(), operation.OperationName, &report) if report.HasErrors() { return report } e.lastPlan = cachedPlan + costCalculator.SetVariables(execContext.resolveContext.Variables) if execContext.resolveContext.TracingOptions.Enable && !execContext.resolveContext.TracingOptions.ExcludePlannerStats { planningTime := resolve.GetDurationNanoSinceTraceStart(execContext.resolveContext.Context()) - tracePlanStart @@ -239,33 +237,33 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques } } -func (e *ExecutionEngine) getCachedPlan(ctx *internalExecutionContext, operation, definition *ast.Document, operationName string, report *operationreport.Report) plan.Plan { +func (e *ExecutionEngine) getCachedPlan(ctx *internalExecutionContext, operation, definition *ast.Document, operationName string, report *operationreport.Report) (plan.Plan, *plan.CostCalculator) { hash := pool.Hash64.Get() hash.Reset() defer pool.Hash64.Put(hash) err := astprinter.Print(operation, hash) if err != nil { report.AddInternalError(err) - return nil + return nil, nil } cacheKey := hash.Sum64() if cached, ok := e.executionPlanCache.Get(cacheKey); ok { if p, ok := cached.(plan.Plan); ok { - return p + return p, p.GetStaticCostCalculator() } } planner, _ := plan.NewPlanner(e.config.plannerConfig) planResult := planner.Plan(operation, definition, operationName, report) if report.HasErrors() { - return nil + return nil, nil } ctx.postProcessor.Process(planResult) e.executionPlanCache.Add(cacheKey, planResult) - return planResult + return planResult, planResult.GetStaticCostCalculator() } func (e *ExecutionEngine) GetWebsocketBeforeStartHook() WebsocketBeforeStartHook { diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 5c6274ae71..3cac390196 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -337,7 +337,8 @@ func TestExecutionEngine_Execute(t *testing.T) { if testCase.expectedStaticCost != 0 { lastPlan := engine.lastPlan assert.NotNil(t, lastPlan) - assert.Equal(t, testCase.expectedStaticCost, lastPlan.GetStaticCost()) + costCalc := lastPlan.GetStaticCostCalculator() + assert.Equal(t, testCase.expectedStaticCost, costCalc.GetTotalCost()) } } @@ -5958,7 +5959,6 @@ func TestExecutionEngine_Execute(t *testing.T) { )) t.Run("custom scheme for listSize", func(t *testing.T) { - t.Skip("Skipping due to known issue with values and arguments") listSchema := ` type Query { items(first: Int, last: Int): [Item!] @@ -6154,7 +6154,7 @@ func TestExecutionEngine_GetCachedPlan(t *testing.T) { } report := operationreport.Report{} - cachedPlan := engine.getCachedPlan(firstInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) + cachedPlan, _ := engine.getCachedPlan(firstInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) _, oldestCachedPlan, _ := engine.executionPlanCache.GetOldest() assert.False(t, report.HasErrors()) assert.Equal(t, 1, engine.executionPlanCache.Len()) @@ -6165,7 +6165,7 @@ func TestExecutionEngine_GetCachedPlan(t *testing.T) { http.CanonicalHeaderKey("Authorization"): []string{"123abc"}, } - cachedPlan = engine.getCachedPlan(secondInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) + cachedPlan, _ = engine.getCachedPlan(secondInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) _, oldestCachedPlan, _ = engine.executionPlanCache.GetOldest() assert.False(t, report.HasErrors()) assert.Equal(t, 1, engine.executionPlanCache.Len()) @@ -6182,7 +6182,7 @@ func TestExecutionEngine_GetCachedPlan(t *testing.T) { } report := operationreport.Report{} - cachedPlan := engine.getCachedPlan(firstInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) + cachedPlan, _ := engine.getCachedPlan(firstInternalExecCtx, gqlRequest.Document(), schema.Document(), gqlRequest.OperationName, &report) _, oldestCachedPlan, _ := engine.executionPlanCache.GetOldest() assert.False(t, report.HasErrors()) assert.Equal(t, 1, engine.executionPlanCache.Len()) @@ -6193,7 +6193,7 @@ func TestExecutionEngine_GetCachedPlan(t *testing.T) { http.CanonicalHeaderKey("Authorization"): []string{"xyz098"}, } - cachedPlan = engine.getCachedPlan(secondInternalExecCtx, differentGqlRequest.Document(), schema.Document(), differentGqlRequest.OperationName, &report) + cachedPlan, _ = engine.getCachedPlan(secondInternalExecCtx, differentGqlRequest.Document(), schema.Document(), differentGqlRequest.OperationName, &report) _, oldestCachedPlan, _ = engine.executionPlanCache.GetOldest() assert.False(t, report.HasErrors()) assert.Equal(t, 2, engine.executionPlanCache.Len()) diff --git a/v2/pkg/engine/plan/plan.go b/v2/pkg/engine/plan/plan.go index 1252c2e01c..4ed9a97325 100644 --- a/v2/pkg/engine/plan/plan.go +++ b/v2/pkg/engine/plan/plan.go @@ -14,22 +14,21 @@ const ( type Plan interface { PlanKind() Kind SetFlushInterval(interval int64) - GetStaticCost() int - SetStaticCost(cost int) + GetStaticCostCalculator() *CostCalculator } type SynchronousResponsePlan struct { - Response *resolve.GraphQLResponse - FlushInterval int64 - StaticCost int + Response *resolve.GraphQLResponse + FlushInterval int64 + StaticCostCalculator *CostCalculator } func (s *SynchronousResponsePlan) GetStaticCost() int { - return s.StaticCost + return s.StaticCostCalculator.GetTotalCost() } -func (s *SynchronousResponsePlan) SetStaticCost(cost int) { - s.StaticCost = cost +func (s *SynchronousResponsePlan) GetStaticCostCalculator() *CostCalculator { + return s.StaticCostCalculator } func (s *SynchronousResponsePlan) SetFlushInterval(interval int64) { @@ -41,17 +40,13 @@ func (*SynchronousResponsePlan) PlanKind() Kind { } type SubscriptionResponsePlan struct { - Response *resolve.GraphQLSubscription - FlushInterval int64 - StaticCost int + Response *resolve.GraphQLSubscription + FlushInterval int64 + StaticCostCalculator *CostCalculator } -func (s *SubscriptionResponsePlan) GetStaticCost() int { - return s.StaticCost -} - -func (s *SubscriptionResponsePlan) SetStaticCost(cost int) { - s.StaticCost = cost +func (s *SubscriptionResponsePlan) GetStaticCostCalculator() *CostCalculator { + return s.StaticCostCalculator } func (s *SubscriptionResponsePlan) SetFlushInterval(interval int64) { diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index 713dceab25..b39af11896 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -214,11 +214,6 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string return } - if p.config.ComputeStaticCost { - cost := p.planningVisitor.costCalculator.GetTotalCost() - p.planningVisitor.plan.SetStaticCost(cost) - } - return p.planningVisitor.plan } diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 8ce718bc97..760378d6ee 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -1,6 +1,10 @@ package plan -import "fmt" +import ( + "fmt" + + "github.com/wundergraph/astjson" +) // StaticCostDefaults contains default cost values when no specific costs are configured var StaticCostDefaults = WeightDefaults{ @@ -41,23 +45,39 @@ type FieldListSize struct { // The value of these arguments will be used as the multiplier. SlicingArguments []string - // SizedFields are field names that return the actual size of the list. - // These can be used for more accurate, actual cost estimation. + // SizedFields are contains field names in the returned object that returns lists. + // For these lists we estimate the size based on the value of the slicing arguments or AssumedSize. SizedFields []string // RequireOneSlicingArgument if true, at least one slicing argument must be provided. // If false and no slicing argument is provided, AssumedSize is used. + // It is not used right now since it is required only for validation. RequireOneSlicingArgument bool } -// multiplier returns the multiplier for a list field based on the values of arguments. +// multiplier returns the multiplier based on arguments and variables. +// It picks the maximum value among slicing arguments, otherwise it tries to use AssumedSize. +// // Does not take into account the SizedFields; TBD later. -func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo) int { +func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo, vars *astjson.Value) int { multiplier := -1 for _, slicingArg := range ls.SlicingArguments { - argInfo, ok := arguments[slicingArg] - if ok && argInfo.isSimple && argInfo.intValue > 0 && argInfo.intValue > multiplier { - multiplier = argInfo.intValue + arg, ok := arguments[slicingArg] + if ok && arg.isSimple { + var value int + // Argument could have a variable or literal value. + if arg.hasVariable { + v := vars.Get(arg.varName) + if v == nil || v.Type() != astjson.TypeNumber { + continue + } + value = vars.GetInt(arg.varName) + } else if arg.intValue > 0 { + value = arg.intValue + } + if value > 0 && value > multiplier { + multiplier = value + } } } if multiplier == -1 && ls.AssumedSize > 0 { @@ -127,6 +147,8 @@ func (c *DataSourceCostConfig) ObjectTypeWeight(name string) int { // CostTreeNode represents a node in the cost calculation tree // Based on IBM GraphQL Cost Specification: https://ibm.github.io/graphql-specs/cost-spec.html type CostTreeNode struct { + parent *CostTreeNode + // dataSourceHashes identifies which data sources resolve this field. dataSourceHashes []DSHash @@ -187,7 +209,7 @@ func (node *CostTreeNode) maxWeightImplementingField(config *DataSourceCostConfi return maxWeight } -func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostConfig, fieldName string, arguments map[string]ArgumentInfo) *FieldListSize { +func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostConfig, fieldName string, arguments map[string]ArgumentInfo, vars *astjson.Value) *FieldListSize { var maxMultiplier int var maxListSize *FieldListSize for _, implTypeName := range node.implementingTypeNames { @@ -195,7 +217,7 @@ func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostC listSize := config.ListSizes[coord] if listSize != nil { - multiplier := listSize.multiplier(arguments) + multiplier := listSize.multiplier(arguments, vars) if maxListSize == nil || multiplier > maxMultiplier { fmt.Printf("found better multiplier for %v: %v\n", coord, multiplier) maxMultiplier = multiplier @@ -206,16 +228,18 @@ func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostC return maxListSize } -// TotalCost calculates the total cost of this node and all descendants -func (node *CostTreeNode) TotalCost() int { +// totalCost calculates the total cost of this node and all descendants +func (node *CostTreeNode) totalCost(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value) int { if node == nil { return 0 } + node.setCostsAndMultiplier(configs, variables) + // Sum children (fields) costs var childrenCost int for _, child := range node.children { - childrenCost += child.TotalCost() + childrenCost += child.totalCost(configs, variables) } // Apply multiplier to children cost (for list fields) @@ -236,123 +260,7 @@ func (node *CostTreeNode) TotalCost() int { return cost } -type ArgumentInfo struct { - intValue int - - // The name of an unwrapped type. - typeName string - - // If argument is passed an input object, we want to gather counts - // for all the field coordinates with non-null values used in the argument. - // TBD later when input objects are supported. - // - // For example, for - // "input A { x: Int, rec: A! }" - // following value is passed: - // { x: 1, rec: { x: 2, rec: { x: 3 } } }, - // then coordCounts will be: - // { {"A", "rec"}: 2, {"A", "x"}: 3 } - // - coordCounts map[FieldCoordinate]int - - // isInputObject is true for an input object passed to the argument, - // otherwise the argument is Scalar or Enum. - isInputObject bool - - isSimple bool -} - -// CostTree represents the complete cost tree for a query -type CostTree struct { - Root *CostTreeNode - Total int -} - -// Calculate computes the total cost and checks against max -func (t *CostTree) Calculate() { - if t.Root != nil { - t.Total = t.Root.TotalCost() - } -} - -// CostCalculator manages cost calculation during AST traversal -type CostCalculator struct { - // stack maintains the current path in the cost tree - stack []*CostTreeNode - - // tree is the complete cost tree being built - tree *CostTree - - // costConfigs maps data source hash to its cost configuration - costConfigs map[DSHash]*DataSourceCostConfig -} - -// NewCostCalculator creates a new cost calculator -func NewCostCalculator() *CostCalculator { - tree := &CostTree{ - Root: &CostTreeNode{ - fieldCoord: FieldCoordinate{"_none", "_root"}, - multiplier: 1, - }, - } - c := CostCalculator{ - stack: make([]*CostTreeNode, 0, 16), - costConfigs: make(map[DSHash]*DataSourceCostConfig), - tree: tree, - } - c.stack = append(c.stack, c.tree.Root) - - return &c -} - -// SetDataSourceCostConfig sets the cost config for a specific data source -func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSourceCostConfig) { - c.costConfigs[dsHash] = config -} - -// CurrentNode returns the current node on the stack -func (c *CostCalculator) CurrentNode() *CostTreeNode { - if len(c.stack) == 0 { - return nil - } - return c.stack[len(c.stack)-1] -} - -// EnterField is called when entering a field during AST traversal. -// It creates a skeleton node and pushes it onto the stack. -// The actual cost calculation happens in LeaveField when fieldPlanners data is available. -func (c *CostCalculator) EnterField(node *CostTreeNode) { - // Attach to parent - parent := c.CurrentNode() - if parent != nil { - parent.children = append(parent.children, node) - } - - c.stack = append(c.stack, node) -} - -// LeaveField calculates the cose of the current node and pop from the cost stack. -// It is called when leaving a field during planning. -func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { - if len(c.stack) <= 1 { // Keep root on stack - return - } - - // Find the current node (should match fieldRef) - lastIndex := len(c.stack) - 1 - current := c.stack[lastIndex] - if current.fieldRef != fieldRef { - return - } - - current.dataSourceHashes = dsHashes - parent := c.stack[lastIndex-1] - c.calculateNodeCosts(current, parent) - - c.stack = c.stack[:lastIndex] -} - -// calculateNodeCosts fills in the cost values for a node based on its data sources. +// setCostsAndMultiplier fills in the cost values for a node based on its data sources. // // For this node we sum weights of the field or its returned type for all the data sources. // Each data source can have its own cost configuration. If we plan field on two data sources, @@ -360,16 +268,17 @@ func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { // // For the multiplier we pick the maximum field weight of implementing types and then // the maximum among slicing arguments. -func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { +func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value) { if len(node.dataSourceHashes) <= 0 { // no data source is responsible for this field return } + parent := node.parent node.multiplier = 0 for _, dsHash := range node.dataSourceHashes { - dsCostConfig, ok := c.costConfigs[dsHash] + dsCostConfig, ok := configs[dsHash] if !ok { fmt.Printf("WARNING: no cost dsCostConfig for data source %v\n", dsHash) continue @@ -393,7 +302,7 @@ func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { fieldWeight = parent.maxWeightImplementingField(dsCostConfig, node.fieldCoord.FieldName) // If this field has listSize defined, then do not look into implementing types. if listSize == nil && node.returnsListType { - listSize = parent.maxMultiplierImplementingField(dsCostConfig, node.fieldCoord.FieldName, node.arguments) + listSize = parent.maxMultiplierImplementingField(dsCostConfig, node.fieldCoord.FieldName, node.arguments, variables) } } @@ -447,7 +356,7 @@ func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { // Compute multiplier as the maximum of data sources. if listSize != nil { - multiplier := listSize.multiplier(node.arguments) + multiplier := listSize.multiplier(node.arguments, variables) // If this node returns a list of abstract types, then it should have listSize defined // to set the multiplier. Spec allows defining listSize on the fields of interfaces. if multiplier > node.multiplier { @@ -462,14 +371,109 @@ func (c *CostCalculator) calculateNodeCosts(node, parent *CostTreeNode) { } } -// GetTree returns the cost tree -func (c *CostCalculator) GetTree() *CostTree { - c.tree.Calculate() - return c.tree +type ArgumentInfo struct { + intValue int + + // The name of an unwrapped type. + typeName string + + // If argument is passed an input object, we want to gather counts + // for all the field coordinates with non-null values used in the argument. + // TBD later when input objects are supported. + // + // For example, for + // "input A { x: Int, rec: A! }" + // following value is passed: + // { x: 1, rec: { x: 2, rec: { x: 3 } } }, + // then coordCounts will be: + // { {"A", "rec"}: 2, {"A", "x"}: 3 } + // + coordCounts map[FieldCoordinate]int + + // isInputObject is true for an input object passed to the argument, + // otherwise the argument is Scalar or Enum. + isInputObject bool + + isSimple bool + + // When the argument points to a variable, it contains the name of the variable. + hasVariable bool + + // The name of the variable that has value for this argument. + varName string +} + +// CostCalculator manages cost calculation during AST traversal +type CostCalculator struct { + // stack maintains the current path in the cost tree + stack []*CostTreeNode + + // tree is the complete cost tree being built + tree *CostTreeNode + + // costConfigs maps data source hash to its cost configuration + costConfigs map[DSHash]*DataSourceCostConfig + variables *astjson.Value +} + +// NewCostCalculator creates a new cost calculator +func NewCostCalculator() *CostCalculator { + c := CostCalculator{ + stack: make([]*CostTreeNode, 0, 16), + costConfigs: make(map[DSHash]*DataSourceCostConfig), + tree: &CostTreeNode{ + fieldCoord: FieldCoordinate{"_none", "_root"}, + multiplier: 1, + }, + } + c.stack = append(c.stack, c.tree) + + return &c +} + +// SetDataSourceCostConfig sets the cost config for a specific data source +func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSourceCostConfig) { + c.costConfigs[dsHash] = config +} + +// EnterField is called when entering a field during AST traversal. +// It creates a skeleton node and pushes it onto the stack. +// The actual cost calculation happens in LeaveField when fieldPlanners data is available. +func (c *CostCalculator) EnterField(node *CostTreeNode) { + // Attach to parent + if len(c.stack) > 0 { + parent := c.stack[len(c.stack)-1] + parent.children = append(parent.children, node) + } + + c.stack = append(c.stack, node) +} + +// LeaveField calculates the cose of the current node and pop from the cost stack. +// It is called when leaving a field during planning. +func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { + if len(c.stack) <= 1 { // Keep root on stack + return + } + + // Find the current node (should match fieldRef) + lastIndex := len(c.stack) - 1 + current := c.stack[lastIndex] + if current.fieldRef != fieldRef { + return + } + + current.dataSourceHashes = dsHashes + current.parent = c.stack[lastIndex-1] + + c.stack = c.stack[:lastIndex] } // GetTotalCost returns the calculated total cost func (c *CostCalculator) GetTotalCost() int { - c.tree.Calculate() - return c.tree.Total + return c.tree.totalCost(c.costConfigs, c.variables) +} + +func (c *CostCalculator) SetVariables(variables *astjson.Value) { + c.variables = variables } diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index f2984ac09e..4c0b1779aa 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -67,7 +67,9 @@ type Visitor struct { // fieldEnclosingTypeNames maps fieldRef to the enclosing type name. fieldEnclosingTypeNames map[int]string - // costCalculator calculates IBM static costs during AST traversal + // costCalculator performs static cost analysis during AST traversal. Visitor calls + // enter/leave field hooks to let the calculator build the cost tree. Cost is calculated + // after actual planning. costCalculator *CostCalculator } @@ -516,12 +518,7 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { argValue := v.Operation.ArgumentValue(argRef) argInfo := ArgumentInfo{} - fmt.Printf("extractFieldArguments: argName=%s, argValue=%v\n", argName, argValue) - val, err := v.Operation.PrintValueBytes(argValue, nil) - if err != nil { - panic(err) - } - fmt.Printf("value = %s\n", val) + fmt.Printf("extractFieldArguments: argName = %s, argValueKind = %v\n", argName, argValue.Kind) switch argValue.Kind { case ast.ValueKindBoolean, ast.ValueKindEnum, ast.ValueKindString, ast.ValueKindFloat: argInfo.isSimple = true @@ -536,6 +533,11 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { continue // omit optional argument when the variable is not defined } + + // We cannot read values of variables from the context here. Save it for later. + argInfo.hasVariable = true + argInfo.varName = variableValue + variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinition, v.Operation.VariableValueNameBytes(argValue.Ref)) if !exists { continue @@ -547,7 +549,9 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { if !exists { continue } - fmt.Printf("variableTypeRef=%v unwrappedVarTypeRef=%v typeName=%v node=%v\n", variableTypeRef, unwrappedVarTypeRef, argInfo.typeName, node) + + fmt.Printf("variableTypeRef = %v unwrappedVarTypeRef = %v typeName = %v nodeKind = %v varVal = %v\n", variableTypeRef, unwrappedVarTypeRef, argInfo.typeName, node.Kind, variableValue) + // Analyze the node to see what kind of variable was passed. switch node.Kind { case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: @@ -556,9 +560,8 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { argInfo.isInputObject = true } - // TODO 1: read values of variables from the context later, not possible to do it here. - // TODO 2: we need to analyze variables that contains input object fields. + // TODO: we need to analyze variables that contains input object fields. // If these fields has weight attached, use them for calculation. // Variables are not inlined at this stage, so we need to inspect them via AST. @@ -1181,14 +1184,16 @@ func (v *Visitor) EnterOperationDefinition(ref int) { Response: v.response, } v.plan = &SubscriptionResponsePlan{ - FlushInterval: v.Config.DefaultFlushIntervalMillis, - Response: v.subscription, + FlushInterval: v.Config.DefaultFlushIntervalMillis, + Response: v.subscription, + StaticCostCalculator: v.costCalculator, } return } v.plan = &SynchronousResponsePlan{ - Response: v.response, + Response: v.response, + StaticCostCalculator: v.costCalculator, } } diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 474dd66a60..25bfeac79a 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -16,12 +16,19 @@ import ( // Context should not ever be initialized directly, and should be initialized via the NewContext function type Context struct { - ctx context.Context - Variables *astjson.Value + ctx context.Context + + // Variables contains the variables to be used to render values of variables for the subgraph. + // Resolver takes into account RemapVariables for variable names. + Variables *astjson.Value + + // RemapVariables contains a map from new names to old names. When variables are renamed, + // the resolver will use the new name to look up the old name to render the variable in the query. + RemapVariables map[string]string + Files []*httpclient.FileUpload Request Request RenameTypeNames []RenameTypeName - RemapVariables map[string]string TracingOptions TraceOptions RateLimitOptions RateLimitOptions ExecutionOptions ExecutionOptions From 54596573a0358f8ff8269ae3ccb20e7890cc3314 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Wed, 14 Jan 2026 17:16:14 +0200 Subject: [PATCH 15/43] fix nil deref --- execution/engine/execution_engine.go | 4 +++- v2/pkg/engine/plan/static_cost.go | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index aaafaf53b8..c4a2695128 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -214,7 +214,9 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques return report } e.lastPlan = cachedPlan - costCalculator.SetVariables(execContext.resolveContext.Variables) + if costCalculator != nil { + costCalculator.SetVariables(execContext.resolveContext.Variables) + } if execContext.resolveContext.TracingOptions.Enable && !execContext.resolveContext.TracingOptions.ExcludePlannerStats { planningTime := resolve.GetDurationNanoSinceTraceStart(execContext.resolveContext.Context()) - tracePlanStart diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 760378d6ee..d45bf4bd5f 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -357,8 +357,8 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo // Compute multiplier as the maximum of data sources. if listSize != nil { multiplier := listSize.multiplier(node.arguments, variables) - // If this node returns a list of abstract types, then it should have listSize defined - // to set the multiplier. Spec allows defining listSize on the fields of interfaces. + // If this node returns a list of abstract types, then it could have listSize defined. + // Spec allows defining listSize on the fields of interfaces. if multiplier > node.multiplier { node.multiplier = multiplier } From 3b76d84170c5172b8c920679847cdec6854c11c4 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Thu, 15 Jan 2026 11:22:23 +0200 Subject: [PATCH 16/43] cleanup the code --- execution/engine/execution_engine_test.go | 4 +-- v2/pkg/engine/plan/plan.go | 4 --- v2/pkg/engine/plan/planner.go | 31 ++++++++--------------- v2/pkg/engine/plan/static_cost.go | 6 ++--- v2/pkg/engine/plan/visitor.go | 29 +++++++-------------- 5 files changed, 25 insertions(+), 49 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 3cac390196..e8cff9fd4c 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -336,9 +336,9 @@ func TestExecutionEngine_Execute(t *testing.T) { if testCase.expectedStaticCost != 0 { lastPlan := engine.lastPlan - assert.NotNil(t, lastPlan) + require.NotNil(t, lastPlan) costCalc := lastPlan.GetStaticCostCalculator() - assert.Equal(t, testCase.expectedStaticCost, costCalc.GetTotalCost()) + require.Equal(t, testCase.expectedStaticCost, costCalc.GetTotalCost()) } } diff --git a/v2/pkg/engine/plan/plan.go b/v2/pkg/engine/plan/plan.go index 4ed9a97325..06377ffd24 100644 --- a/v2/pkg/engine/plan/plan.go +++ b/v2/pkg/engine/plan/plan.go @@ -23,10 +23,6 @@ type SynchronousResponsePlan struct { StaticCostCalculator *CostCalculator } -func (s *SynchronousResponsePlan) GetStaticCost() int { - return s.StaticCostCalculator.GetTotalCost() -} - func (s *SynchronousResponsePlan) GetStaticCostCalculator() *CostCalculator { return s.StaticCostCalculator } diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index b39af11896..53955b8864 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -60,22 +60,10 @@ func NewPlanner(config Configuration) (*Planner, error) { planningWalker := astvisitor.NewWalkerWithID(48, "PlanningWalker") - // Initialize cost calculator and configure from data sources - var costCalc *CostCalculator - if config.ComputeStaticCost { - costCalc = NewCostCalculator() - for _, ds := range config.DataSources { - if costConfig := ds.GetCostConfig(); costConfig != nil { - costCalc.SetDataSourceCostConfig(ds.Hash(), costConfig) - } - } - } - planningVisitor := &Visitor{ Walker: &planningWalker, fieldConfigs: map[int]*FieldConfiguration{}, disableResolveFieldPositions: config.DisableResolveFieldPositions, - costCalculator: costCalc, } p := &Planner{ @@ -88,14 +76,6 @@ func NewPlanner(config Configuration) (*Planner, error) { return p, nil } -func (p *Planner) SetConfig(config Configuration) { - p.config = config -} - -func (p *Planner) SetDebugConfig(config DebugConfiguration) { - p.config.Debug = config -} - type _opts struct { includeQueryPlanInResponse bool } @@ -169,6 +149,17 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string p.planningVisitor.fieldDependencyKind = selectionsConfig.fieldDependencyKind p.planningVisitor.fieldRefDependants = inverseMap(selectionsConfig.fieldRefDependsOn) + // Initialize cost calculator and configure from data sources + if p.config.ComputeStaticCost { + calc := NewCostCalculator() + for _, ds := range p.config.DataSources { + if costConfig := ds.GetCostConfig(); costConfig != nil { + calc.SetDataSourceCostConfig(ds.Hash(), costConfig) + } + } + p.planningVisitor.costCalculator = calc + } + p.planningWalker.ResetVisitors() p.planningWalker.SetVisitorFilter(p.planningVisitor) p.planningWalker.RegisterDocumentVisitor(p.planningVisitor) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index d45bf4bd5f..d05f7fa54f 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -117,8 +117,9 @@ type DataSourceCostConfig struct { // NewDataSourceCostConfig creates a new cost config with defaults func NewDataSourceCostConfig() *DataSourceCostConfig { return &DataSourceCostConfig{ - Weights: make(map[FieldCoordinate]*FieldWeight), - Types: make(map[string]int), + Weights: make(map[FieldCoordinate]*FieldWeight), + ListSizes: make(map[FieldCoordinate]*FieldListSize), + Types: make(map[string]int), } } @@ -201,7 +202,6 @@ func (node *CostTreeNode) maxWeightImplementingField(config *DataSourceCostConfi if fieldWeight != nil { if fieldWeight.HasWeight && (maxWeight == nil || fieldWeight.Weight > maxWeight.Weight) { - fmt.Printf("found better maxWeight for %v: %v\n", coord, fieldWeight) maxWeight = fieldWeight } } diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 4c0b1779aa..d1e37d57b4 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -463,12 +463,7 @@ func (v *Visitor) enterFieldCost(fieldRef int) { } } - if len(implementingTypeNames) > 0 { - fmt.Printf("enterFieldCost: field %s.%s is interface or union, implementing types: %v\n", typeName, fieldName, implementingTypeNames) - } - isEnclosingTypeAbstract := v.Walker.EnclosingTypeDefinition.Kind.IsAbstractType() - fmt.Printf("EnclosingType Kind = %v for %s.%s\n", v.Walker.EnclosingTypeDefinition.Kind, typeName, fieldName) // Create a skeleton node. dataSourceHashes will be filled in leaveFieldCost node := CostTreeNode{ fieldRef: fieldRef, @@ -518,16 +513,15 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { argValue := v.Operation.ArgumentValue(argRef) argInfo := ArgumentInfo{} - fmt.Printf("extractFieldArguments: argName = %s, argValueKind = %v\n", argName, argValue.Kind) switch argValue.Kind { - case ast.ValueKindBoolean, ast.ValueKindEnum, ast.ValueKindString, ast.ValueKindFloat: - argInfo.isSimple = true - argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) - case ast.ValueKindInteger: - // Extract integer value if present (for arguments in directives) - argInfo.isSimple = true - argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) - argInfo.intValue = int(v.Operation.IntValueAsInt(argValue.Ref)) + // case ast.ValueKindBoolean, ast.ValueKindEnum, ast.ValueKindString, ast.ValueKindFloat: + // argInfo.isSimple = true + // argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) + // case ast.ValueKindInteger: + // // Extract integer value if present (for arguments in directives) + // argInfo.isSimple = true + // argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) + // argInfo.intValue = int(v.Operation.IntValueAsInt(argValue.Ref)) case ast.ValueKindVariable: variableValue := v.Operation.VariableValueNameString(argValue.Ref) if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { @@ -550,7 +544,7 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { continue } - fmt.Printf("variableTypeRef = %v unwrappedVarTypeRef = %v typeName = %v nodeKind = %v varVal = %v\n", variableTypeRef, unwrappedVarTypeRef, argInfo.typeName, node.Kind, variableValue) + // fmt.Printf("variableTypeRef = %v unwrappedVarTypeRef = %v typeName = %v nodeKind = %v varVal = %v\n", variableTypeRef, unwrappedVarTypeRef, argInfo.typeName, node.Kind, variableValue) // Analyze the node to see what kind of variable was passed. switch node.Kind { @@ -565,11 +559,6 @@ func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { // If these fields has weight attached, use them for calculation. // Variables are not inlined at this stage, so we need to inspect them via AST. - case ast.ValueKindList: - // not sure if these are relevant - // unwrappedTypeRef := v.Operation.ResolveUnderlyingType(argValue.Ref) - // argInfo.typeName = v.Operation.TypeNameString(unwrappedTypeRef) - fmt.Printf("WARNING: unhandled list argument type: %v typeName=%v\n", argValue.Kind, argInfo.typeName) default: fmt.Printf("WARNING: unhandled argument type: %v\n", argValue.Kind) continue From eca56d863cb93b7975198f9c41c5df3024201bcf Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Thu, 15 Jan 2026 11:28:02 +0200 Subject: [PATCH 17/43] rename vars for readability --- v2/pkg/engine/plan/visitor.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index d1e37d57b4..1b97fd342c 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -482,8 +482,8 @@ func (v *Visitor) enterFieldCost(fieldRef int) { // getFieldDataSourceHashes returns all data source hashes for the field. // A field can be planned on multiple data sources in federation scenarios. -func (v *Visitor) getFieldDataSourceHashes(ref int) []DSHash { - plannerIDs, ok := v.fieldPlanners[ref] +func (v *Visitor) getFieldDataSourceHashes(fieldRef int) []DSHash { + plannerIDs, ok := v.fieldPlanners[fieldRef] if !ok || len(plannerIDs) == 0 { return nil } @@ -769,10 +769,10 @@ func (v *Visitor) addInterfaceObjectNameToTypeNames(fieldRef int, typeName []byt return onTypeNames } -func (v *Visitor) LeaveField(ref int) { - v.debugOnLeaveNode(ast.NodeKindField, ref) +func (v *Visitor) LeaveField(fieldRef int) { + v.debugOnLeaveNode(ast.NodeKindField, fieldRef) - if v.skipField(ref) { + if v.skipField(fieldRef) { // we should also check skips on field leave // cause on nested keys we could mistakenly remove wrong object // from the stack of the current objects @@ -781,13 +781,13 @@ func (v *Visitor) LeaveField(ref int) { // This is done in LeaveField because fieldPlanners become available before LeaveField. if v.costCalculator != nil { - v.costCalculator.LeaveField(ref, v.getFieldDataSourceHashes(ref)) + v.costCalculator.LeaveField(fieldRef, v.getFieldDataSourceHashes(fieldRef)) } - if v.currentFields[len(v.currentFields)-1].popOnField == ref { + if v.currentFields[len(v.currentFields)-1].popOnField == fieldRef { v.currentFields = v.currentFields[:len(v.currentFields)-1] } - fieldDefinition, ok := v.Walker.FieldDefinition(ref) + fieldDefinition, ok := v.Walker.FieldDefinition(fieldRef) if !ok { return } From 35ec9837efb3c52865ab0c12487979269037a0f0 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Thu, 15 Jan 2026 12:59:38 +0200 Subject: [PATCH 18/43] add test and comments --- execution/engine/execution_engine_test.go | 48 +++++++++++++++++++++++ v2/pkg/engine/plan/static_cost.go | 22 +++++++++++ 2 files changed, 70 insertions(+) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index e8cff9fd4c..4ddb07edb8 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -6046,6 +6046,54 @@ func TestExecutionEngine_Execute(t *testing.T) { }, computeStaticCost(), )) + t.Run("slicing argument as a variable", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query SlicingWithVariable($limit: Int!) { + items(first: $limit) { id } + }`, + Variables: []byte(`{"limit": 25}`), + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[ {"id":"2"}, {"id":"3"} ]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 8, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 3, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[{"id":"2"},{"id":"3"}]}}`, + expectedStaticCost: 100, // slicingArgument($limit=25) * (Item(3)+Item.id(1)) + }, + computeStaticCost(), + )) + }) }) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index d05f7fa54f..ad1560bfcb 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -1,5 +1,27 @@ package plan +/* + +Static Cost Analysis. + +Planning visitor collects information for the costCalculator via EnterField and LeaveField hooks. +Calculator builds a tree of nodes, each node corresponding to the requested field. +After the planning is done, a callee could get a ref to the calculator and request cost calculation. +Cost calculation walks the previously built tree and using variables provided with operation, +estimates the static cost. + +It builds on top of IBM spec for @cost and @listSize directive with a few changes: +* It uses Int! for weights instead of Float!. +* When weight is specified for the type and a field returns the list of that type, +this weight (along with children's costs) is multiplied too. + +A few things on the TBD list: +* Support of SizedFields of @listSize +* Weights on fields of InputObjects with recursion +* Weights on arguments of directives + +*/ + import ( "fmt" From 7cd18149358ece3cded69c6f337705c1f3feb3c9 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Thu, 15 Jan 2026 13:18:08 +0200 Subject: [PATCH 19/43] fix edge cases --- v2/pkg/engine/plan/static_cost.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index ad1560bfcb..e5e27576a7 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -89,8 +89,10 @@ func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo, vars *ast var value int // Argument could have a variable or literal value. if arg.hasVariable { - v := vars.Get(arg.varName) - if v == nil || v.Type() != astjson.TypeNumber { + if vars == nil { + continue + } + if v := vars.Get(arg.varName); v == nil || v.Type() != astjson.TypeNumber { continue } value = vars.GetInt(arg.varName) @@ -159,7 +161,7 @@ func (c *DataSourceCostConfig) EnumScalarTypeWeight(enumName string) int { // ObjectTypeWeight returns the default object cost func (c *DataSourceCostConfig) ObjectTypeWeight(name string) int { if c == nil { - return 0 + return StaticCostDefaults.Object } if cost, ok := c.Types[name]; ok { return cost @@ -297,6 +299,8 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo } parent := node.parent + node.fieldCost = 0 + node.argumentsCost = 0 node.multiplier = 0 for _, dsHash := range node.dataSourceHashes { @@ -491,7 +495,7 @@ func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { c.stack = c.stack[:lastIndex] } -// GetTotalCost returns the calculated total cost +// GetTotalCost returns the calculated total cost. func (c *CostCalculator) GetTotalCost() int { return c.tree.totalCost(c.costConfigs, c.variables) } From 7ebcea73bd0b290f02650fdd8f3a92159328d060 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Fri, 16 Jan 2026 14:18:36 +0200 Subject: [PATCH 20/43] fix top comment --- v2/pkg/engine/plan/static_cost.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index e5e27576a7..74866bfa15 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -10,12 +10,16 @@ After the planning is done, a callee could get a ref to the calculator and reque Cost calculation walks the previously built tree and using variables provided with operation, estimates the static cost. -It builds on top of IBM spec for @cost and @listSize directive with a few changes: -* It uses Int! for weights instead of Float!. +https://ibm.github.io/graphql-specs/cost-spec.html + +It builds on top of IBM spec for @cost and @listSize directive with a few changes. + +* We use Int! for weights instead of floats packed in String!. * When weight is specified for the type and a field returns the list of that type, this weight (along with children's costs) is multiplied too. A few things on the TBD list: + * Support of SizedFields of @listSize * Weights on fields of InputObjects with recursion * Weights on arguments of directives From dccdd7c61a9275f69a87aca22e66417981c95f1b Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 19 Jan 2026 17:15:44 +0200 Subject: [PATCH 21/43] build cost tree in the dedicated visitor --- v2/pkg/engine/plan/plan.go | 25 ++- v2/pkg/engine/plan/planner.go | 37 ++-- v2/pkg/engine/plan/static_cost.go | 52 +---- v2/pkg/engine/plan/static_cost_visitor.go | 220 ++++++++++++++++++++++ v2/pkg/engine/plan/visitor.go | 194 ++----------------- 5 files changed, 286 insertions(+), 242 deletions(-) create mode 100644 v2/pkg/engine/plan/static_cost_visitor.go diff --git a/v2/pkg/engine/plan/plan.go b/v2/pkg/engine/plan/plan.go index 06377ffd24..1cca76d896 100644 --- a/v2/pkg/engine/plan/plan.go +++ b/v2/pkg/engine/plan/plan.go @@ -15,6 +15,7 @@ type Plan interface { PlanKind() Kind SetFlushInterval(interval int64) GetStaticCostCalculator() *CostCalculator + SetStaticCostCalculator(calc *CostCalculator) } type SynchronousResponsePlan struct { @@ -23,10 +24,6 @@ type SynchronousResponsePlan struct { StaticCostCalculator *CostCalculator } -func (s *SynchronousResponsePlan) GetStaticCostCalculator() *CostCalculator { - return s.StaticCostCalculator -} - func (s *SynchronousResponsePlan) SetFlushInterval(interval int64) { s.FlushInterval = interval } @@ -35,16 +32,20 @@ func (*SynchronousResponsePlan) PlanKind() Kind { return SynchronousResponseKind } +func (s *SynchronousResponsePlan) GetStaticCostCalculator() *CostCalculator { + return s.StaticCostCalculator +} + +func (s *SynchronousResponsePlan) SetStaticCostCalculator(c *CostCalculator) { + s.StaticCostCalculator = c +} + type SubscriptionResponsePlan struct { Response *resolve.GraphQLSubscription FlushInterval int64 StaticCostCalculator *CostCalculator } -func (s *SubscriptionResponsePlan) GetStaticCostCalculator() *CostCalculator { - return s.StaticCostCalculator -} - func (s *SubscriptionResponsePlan) SetFlushInterval(interval int64) { s.FlushInterval = interval } @@ -52,3 +53,11 @@ func (s *SubscriptionResponsePlan) SetFlushInterval(interval int64) { func (*SubscriptionResponsePlan) PlanKind() Kind { return SubscriptionResponseKind } + +func (s *SubscriptionResponsePlan) GetStaticCostCalculator() *CostCalculator { + return s.StaticCostCalculator +} + +func (s *SubscriptionResponsePlan) SetStaticCostCalculator(c *CostCalculator) { + s.StaticCostCalculator = c +} diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index 53955b8864..426b0e5374 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -18,6 +18,7 @@ type Planner struct { planningWalker *astvisitor.Walker planningVisitor *Visitor + costVisitor *StaticCostVisitor nodeSelectionBuilder *NodeSelectionBuilder planningPathBuilder *PathBuilder @@ -149,17 +150,6 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string p.planningVisitor.fieldDependencyKind = selectionsConfig.fieldDependencyKind p.planningVisitor.fieldRefDependants = inverseMap(selectionsConfig.fieldRefDependsOn) - // Initialize cost calculator and configure from data sources - if p.config.ComputeStaticCost { - calc := NewCostCalculator() - for _, ds := range p.config.DataSources { - if costConfig := ds.GetCostConfig(); costConfig != nil { - calc.SetDataSourceCostConfig(ds.Hash(), costConfig) - } - } - p.planningVisitor.costCalculator = calc - } - p.planningWalker.ResetVisitors() p.planningWalker.SetVisitorFilter(p.planningVisitor) p.planningWalker.RegisterDocumentVisitor(p.planningVisitor) @@ -169,6 +159,17 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string p.planningWalker.RegisterEnterDirectiveVisitor(p.planningVisitor) p.planningWalker.RegisterInlineFragmentVisitor(p.planningVisitor) + // Register cost visitor on the same walker (will be invoked after planningVisitor hooks) + if p.config.ComputeStaticCost { + p.costVisitor = NewStaticCostVisitor(p.planningWalker, operation, definition) + p.costVisitor.planners = plannersConfigurations + p.costVisitor.fieldPlanners = &p.planningVisitor.fieldPlanners + p.costVisitor.operationDefinition = &p.planningVisitor.operationDefinitionRef + + p.planningWalker.RegisterEnterFieldVisitor(p.costVisitor) + p.planningWalker.RegisterLeaveFieldVisitor(p.costVisitor) + } + for key := range p.planningVisitor.planners { if p.config.MinifySubgraphOperations { if dataSourceWithMinify, ok := p.planningVisitor.planners[key].Planner().(SubgraphRequestMinifier); ok { @@ -205,6 +206,20 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string return } + if p.config.ComputeStaticCost { + // Initialize cost calculator and configure from data sources + costCalc := NewCostCalculator() + for _, ds := range p.config.DataSources { + if costConfig := ds.GetCostConfig(); costConfig != nil { + costCalc.SetDataSourceCostConfig(ds.Hash(), costConfig) + } + } + // The root tree pointing to the costTreeNode is the ultimate result of costVisitor. + // Store is as part of this plan for later, should be part of the cached plan too. + costCalc.tree = p.costVisitor.finalCostTree() + p.planningVisitor.plan.SetStaticCostCalculator(costCalc) + } + return p.planningVisitor.plan } diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 74866bfa15..4e3cdd8e17 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -435,29 +435,21 @@ type ArgumentInfo struct { // CostCalculator manages cost calculation during AST traversal type CostCalculator struct { - // stack maintains the current path in the cost tree - stack []*CostTreeNode - - // tree is the complete cost tree being built + // tree points to the root of the complete cost tree. tree *CostTreeNode // costConfigs maps data source hash to its cost configuration costConfigs map[DSHash]*DataSourceCostConfig - variables *astjson.Value + + // variables are passed by the resolver's context. + variables *astjson.Value } // NewCostCalculator creates a new cost calculator func NewCostCalculator() *CostCalculator { c := CostCalculator{ - stack: make([]*CostTreeNode, 0, 16), costConfigs: make(map[DSHash]*DataSourceCostConfig), - tree: &CostTreeNode{ - fieldCoord: FieldCoordinate{"_none", "_root"}, - multiplier: 1, - }, } - c.stack = append(c.stack, c.tree) - return &c } @@ -466,37 +458,8 @@ func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSour c.costConfigs[dsHash] = config } -// EnterField is called when entering a field during AST traversal. -// It creates a skeleton node and pushes it onto the stack. -// The actual cost calculation happens in LeaveField when fieldPlanners data is available. -func (c *CostCalculator) EnterField(node *CostTreeNode) { - // Attach to parent - if len(c.stack) > 0 { - parent := c.stack[len(c.stack)-1] - parent.children = append(parent.children, node) - } - - c.stack = append(c.stack, node) -} - -// LeaveField calculates the cose of the current node and pop from the cost stack. -// It is called when leaving a field during planning. -func (c *CostCalculator) LeaveField(fieldRef int, dsHashes []DSHash) { - if len(c.stack) <= 1 { // Keep root on stack - return - } - - // Find the current node (should match fieldRef) - lastIndex := len(c.stack) - 1 - current := c.stack[lastIndex] - if current.fieldRef != fieldRef { - return - } - - current.dataSourceHashes = dsHashes - current.parent = c.stack[lastIndex-1] - - c.stack = c.stack[:lastIndex] +func (c *CostCalculator) SetVariables(variables *astjson.Value) { + c.variables = variables } // GetTotalCost returns the calculated total cost. @@ -504,6 +467,3 @@ func (c *CostCalculator) GetTotalCost() int { return c.tree.totalCost(c.costConfigs, c.variables) } -func (c *CostCalculator) SetVariables(variables *astjson.Value) { - c.variables = variables -} diff --git a/v2/pkg/engine/plan/static_cost_visitor.go b/v2/pkg/engine/plan/static_cost_visitor.go new file mode 100644 index 0000000000..19c90ad92f --- /dev/null +++ b/v2/pkg/engine/plan/static_cost_visitor.go @@ -0,0 +1,220 @@ +package plan + +import ( + "fmt" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" +) + +// StaticCostVisitor builds the cost tree during AST traversal. +// It is registered on the same walker as the planning Visitor and uses +// data from the planning visitor (fieldPlanners, planners) to determine +// which data sources are responsible for each field. +type StaticCostVisitor struct { + Walker *astvisitor.Walker + + // AST documents - set before walking + Operation *ast.Document + Definition *ast.Document + + // References to planning visitor data - set before walking + planners []PlannerConfiguration + fieldPlanners *map[int][]int // Pointer to Visitor.fieldPlanners + + // Pointer to the main visitor's operationDefinition (set during EnterDocument) + operationDefinition *int + + // stack to keep track of the current node + stack []*CostTreeNode + + // The final cost tree that is built during plan traversal. + tree *CostTreeNode +} + +// NewStaticCostVisitor creates a new cost tree visitor +func NewStaticCostVisitor(walker *astvisitor.Walker, operation, definition *ast.Document) *StaticCostVisitor { + stack := make([]*CostTreeNode, 0, 16) + rootNode := CostTreeNode{ + fieldCoord: FieldCoordinate{"_none", "_root"}, + multiplier: 1, + } + stack = append(stack, &rootNode) + return &StaticCostVisitor{ + Walker: walker, + Operation: operation, + Definition: definition, + stack: stack, + tree: &rootNode, + } +} + +// EnterField creates a partial cost node when entering a field. +// The node is filled in full in the LeaveField when fieldPlanners data is available. +func (v *StaticCostVisitor) EnterField(fieldRef int) { + typeName := v.Walker.EnclosingTypeDefinition.NameString(v.Definition) + fieldName := v.Operation.FieldNameUnsafeString(fieldRef) + + fieldDefinitionRef, ok := v.Walker.FieldDefinition(fieldRef) + if !ok { + return + } + fieldDefinitionTypeRef := v.Definition.FieldDefinitionType(fieldDefinitionRef) + isListType := v.Definition.TypeIsList(fieldDefinitionTypeRef) + isSimpleType := v.Definition.TypeIsEnum(fieldDefinitionTypeRef, v.Definition) || v.Definition.TypeIsScalar(fieldDefinitionTypeRef, v.Definition) + unwrappedTypeName := v.Definition.ResolveTypeNameString(fieldDefinitionTypeRef) + + arguments := v.extractFieldArguments(fieldRef) + + // Check and push through if the unwrapped type of this field is interface or union. + unwrappedTypeNode, exists := v.Definition.NodeByNameStr(unwrappedTypeName) + var implementingTypeNames []string + var isAbstractType bool + if exists { + if unwrappedTypeNode.Kind == ast.NodeKindInterfaceTypeDefinition { + impl, ok := v.Definition.InterfaceTypeDefinitionImplementedByObjectWithNames(unwrappedTypeNode.Ref) + if ok { + implementingTypeNames = append(implementingTypeNames, impl...) + isAbstractType = true + } + } + if unwrappedTypeNode.Kind == ast.NodeKindUnionTypeDefinition { + impl, ok := v.Definition.UnionTypeDefinitionMemberTypeNames(unwrappedTypeNode.Ref) + if ok { + implementingTypeNames = append(implementingTypeNames, impl...) + isAbstractType = true + } + } + } + + isEnclosingTypeAbstract := v.Walker.EnclosingTypeDefinition.Kind.IsAbstractType() + // Create a skeleton node. dataSourceHashes will be filled in leaveFieldCost + node := CostTreeNode{ + fieldRef: fieldRef, + fieldCoord: FieldCoordinate{typeName, fieldName}, + multiplier: 1, + fieldTypeName: unwrappedTypeName, + implementingTypeNames: implementingTypeNames, + returnsListType: isListType, + returnsSimpleType: isSimpleType, + returnsAbstractType: isAbstractType, + isEnclosingTypeAbstract: isEnclosingTypeAbstract, + arguments: arguments, + } + + // Attach to parent + if len(v.stack) > 0 { + parent := v.stack[len(v.stack)-1] + parent.children = append(parent.children, &node) + } + + v.stack = append(v.stack, &node) + +} + +// LeaveField fills DataSource hashes for the current node and pop it from the cost stack. +func (v *StaticCostVisitor) LeaveField(fieldRef int) { + dsHashes := v.getFieldDataSourceHashes(fieldRef) + + if len(v.stack) <= 1 { // Keep root on stack + return + } + + // Find the current node (should match fieldRef) + lastIndex := len(v.stack) - 1 + current := v.stack[lastIndex] + if current.fieldRef != fieldRef { + return + } + + current.dataSourceHashes = dsHashes + current.parent = v.stack[lastIndex-1] + + v.stack = v.stack[:lastIndex] +} + +// getFieldDataSourceHashes returns all data source hashes for the field. +// A field can be planned on multiple data sources in federation scenarios. +func (v *StaticCostVisitor) getFieldDataSourceHashes(fieldRef int) []DSHash { + plannerIDs, ok := (*v.fieldPlanners)[fieldRef] + if !ok || len(plannerIDs) == 0 { + return nil + } + + dsHashes := make([]DSHash, 0, len(plannerIDs)) + for _, plannerID := range plannerIDs { + if plannerID >= 0 && plannerID < len(v.planners) { + dsHash := v.planners[plannerID].DataSourceConfiguration().Hash() + dsHashes = append(dsHashes, dsHash) + } + } + return dsHashes +} + +// extractFieldArguments extracts arguments from a field for cost calculation +// This implementation does not go deep for input objects yet. +// It should return unwrapped type names for arguments and that is it for now. +func (v *StaticCostVisitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { + argRefs := v.Operation.FieldArguments(fieldRef) + if len(argRefs) == 0 { + return nil + } + + arguments := make(map[string]ArgumentInfo, len(argRefs)) + for _, argRef := range argRefs { + argName := v.Operation.ArgumentNameString(argRef) + argValue := v.Operation.ArgumentValue(argRef) + argInfo := ArgumentInfo{} + + switch argValue.Kind { + case ast.ValueKindVariable: + variableValue := v.Operation.VariableValueNameString(argValue.Ref) + if !v.Operation.OperationDefinitionHasVariableDefinition(*v.operationDefinition, variableValue) { + continue // omit optional argument when the variable is not defined + } + + // We cannot read values of variables from the context here. Save it for later. + argInfo.hasVariable = true + argInfo.varName = variableValue + + variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(*v.operationDefinition, v.Operation.VariableValueNameBytes(argValue.Ref)) + if !exists { + continue + } + variableTypeRef := v.Operation.VariableDefinitions[variableDefinition].Type + unwrappedVarTypeRef := v.Operation.ResolveUnderlyingType(variableTypeRef) + argInfo.typeName = v.Operation.TypeNameString(unwrappedVarTypeRef) + node, exists := v.Definition.NodeByNameStr(argInfo.typeName) + if !exists { + continue + } + + // fmt.Printf("variableTypeRef = %v unwrappedVarTypeRef = %v typeName = %v nodeKind = %v varVal = %v\n", variableTypeRef, unwrappedVarTypeRef, argInfo.typeName, node.Kind, variableValue) + + // Analyze the node to see what kind of variable was passed. + switch node.Kind { + case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: + argInfo.isSimple = true + case ast.NodeKindInputObjectTypeDefinition: + argInfo.isInputObject = true + + } + + // TODO: we need to analyze variables that contains input object fields. + // If these fields has weight attached, use them for calculation. + // Variables are not inlined at this stage, so we need to inspect them via AST. + + default: + fmt.Printf("WARNING: unhandled argument type: %v\n", argValue.Kind) + continue + } + + arguments[argName] = argInfo + } + + return arguments +} + +func (v *StaticCostVisitor) finalCostTree() *CostTreeNode { + return v.tree +} diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 1b97fd342c..34d7014fe3 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -39,7 +39,7 @@ type Visitor struct { response *resolve.GraphQLResponse subscription *resolve.GraphQLSubscription OperationName string - operationDefinition int + operationDefinitionRef int objects []*resolve.Object currentFields []objectFields currentField *resolve.Field @@ -57,20 +57,15 @@ type Visitor struct { pathCache map[astvisitor.VisitorKind]map[int]string // plannerFields maps plannerID to fieldRefs planned on this planner. - // It is available just before the LeaveField. + // Values added in AllowVisitor callback which is fired before calling LeaveField plannerFields map[int][]int // fieldPlanners maps fieldRef to the plannerIDs where it was planned on. - // It is available just before the LeaveField. + // Values added in AllowVisitor callback which is fired before calling LeaveField fieldPlanners map[int][]int // fieldEnclosingTypeNames maps fieldRef to the enclosing type name. fieldEnclosingTypeNames map[int]string - - // costCalculator performs static cost analysis during AST traversal. Visitor calls - // enter/leave field hooks to let the calculator build the cost tree. Cost is calculated - // after actual planning. - costCalculator *CostCalculator } type indirectInterfaceField struct { @@ -141,6 +136,10 @@ func (v *Visitor) AllowVisitor(kind astvisitor.VisitorKind, ref int, visitor any // main planner visitor should always be allowed return true } + if _, isCostVisitor := visitor.(*StaticCostVisitor); isCostVisitor { + // cost tree visitor should always be allowed + return true + } var ( path string isFragmentPath bool @@ -406,9 +405,6 @@ func (v *Visitor) EnterField(ref int) { *v.currentFields[len(v.currentFields)-1].fields = append(*v.currentFields[len(v.currentFields)-1].fields, v.currentField) v.mapFieldConfig(ref) - - // Enter cost calculation for this field (skeleton node, actual costs calculated in LeaveField) - v.enterFieldCost(ref) } func (v *Visitor) mapFieldConfig(ref int) { @@ -421,155 +417,6 @@ func (v *Visitor) mapFieldConfig(ref int) { v.fieldConfigs[ref] = fieldConfig } -// enterFieldCost creates a skeleton cost node when entering a field. -// Actual cost calculation is deferred to leaveFieldCost when fieldPlanners data is available. -func (v *Visitor) enterFieldCost(fieldRef int) { - if v.costCalculator == nil { - return - } - - typeName := v.Walker.EnclosingTypeDefinition.NameString(v.Definition) - fieldName := v.Operation.FieldNameUnsafeString(fieldRef) - - fieldDefinition, ok := v.Walker.FieldDefinition(fieldRef) - if !ok { - return - } - fieldDefinitionTypeRef := v.Definition.FieldDefinitionType(fieldDefinition) - isListType := v.Definition.TypeIsList(fieldDefinitionTypeRef) - isSimpleType := v.Definition.TypeIsEnum(fieldDefinitionTypeRef, v.Definition) || v.Definition.TypeIsScalar(fieldDefinitionTypeRef, v.Definition) - unwrappedTypeName := v.Definition.ResolveTypeNameString(fieldDefinitionTypeRef) - - arguments := v.extractFieldArguments(fieldRef) - - // Check and push through if the unwrapped type of this field is interface or union. - unwrappedTypeNode, exists := v.Definition.NodeByNameStr(unwrappedTypeName) - var implementingTypeNames []string - var isAbstractType bool - if exists { - if unwrappedTypeNode.Kind == ast.NodeKindInterfaceTypeDefinition { - impl, ok := v.Definition.InterfaceTypeDefinitionImplementedByObjectWithNames(unwrappedTypeNode.Ref) - if ok { - implementingTypeNames = append(implementingTypeNames, impl...) - isAbstractType = true - } - } - if unwrappedTypeNode.Kind == ast.NodeKindUnionTypeDefinition { - impl, ok := v.Definition.UnionTypeDefinitionMemberTypeNames(unwrappedTypeNode.Ref) - if ok { - implementingTypeNames = append(implementingTypeNames, impl...) - isAbstractType = true - } - } - } - - isEnclosingTypeAbstract := v.Walker.EnclosingTypeDefinition.Kind.IsAbstractType() - // Create a skeleton node. dataSourceHashes will be filled in leaveFieldCost - node := CostTreeNode{ - fieldRef: fieldRef, - fieldCoord: FieldCoordinate{typeName, fieldName}, - multiplier: 1, - fieldTypeName: unwrappedTypeName, - implementingTypeNames: implementingTypeNames, - returnsListType: isListType, - returnsSimpleType: isSimpleType, - returnsAbstractType: isAbstractType, - isEnclosingTypeAbstract: isEnclosingTypeAbstract, - arguments: arguments, - } - v.costCalculator.EnterField(&node) -} - -// getFieldDataSourceHashes returns all data source hashes for the field. -// A field can be planned on multiple data sources in federation scenarios. -func (v *Visitor) getFieldDataSourceHashes(fieldRef int) []DSHash { - plannerIDs, ok := v.fieldPlanners[fieldRef] - if !ok || len(plannerIDs) == 0 { - return nil - } - - dsHashes := make([]DSHash, 0, len(plannerIDs)) - for _, plannerID := range plannerIDs { - if plannerID >= 0 && plannerID < len(v.planners) { - dsHash := v.planners[plannerID].DataSourceConfiguration().Hash() - dsHashes = append(dsHashes, dsHash) - } - } - return dsHashes -} - -// extractFieldArguments extracts arguments from a field for cost calculation -// This implementation does not go deep for input objects yet. -// It should return unwrapped type names for arguments and that is it for now. -func (v *Visitor) extractFieldArguments(fieldRef int) map[string]ArgumentInfo { - argRefs := v.Operation.FieldArguments(fieldRef) - if len(argRefs) == 0 { - return nil - } - - arguments := make(map[string]ArgumentInfo, len(argRefs)) - for _, argRef := range argRefs { - argName := v.Operation.ArgumentNameString(argRef) - argValue := v.Operation.ArgumentValue(argRef) - argInfo := ArgumentInfo{} - - switch argValue.Kind { - // case ast.ValueKindBoolean, ast.ValueKindEnum, ast.ValueKindString, ast.ValueKindFloat: - // argInfo.isSimple = true - // argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) - // case ast.ValueKindInteger: - // // Extract integer value if present (for arguments in directives) - // argInfo.isSimple = true - // argInfo.typeName = v.Operation.TypeNameString(argValue.Ref) - // argInfo.intValue = int(v.Operation.IntValueAsInt(argValue.Ref)) - case ast.ValueKindVariable: - variableValue := v.Operation.VariableValueNameString(argValue.Ref) - if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { - continue // omit optional argument when the variable is not defined - } - - // We cannot read values of variables from the context here. Save it for later. - argInfo.hasVariable = true - argInfo.varName = variableValue - - variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinition, v.Operation.VariableValueNameBytes(argValue.Ref)) - if !exists { - continue - } - variableTypeRef := v.Operation.VariableDefinitions[variableDefinition].Type - unwrappedVarTypeRef := v.Operation.ResolveUnderlyingType(variableTypeRef) - argInfo.typeName = v.Operation.TypeNameString(unwrappedVarTypeRef) - node, exists := v.Definition.NodeByNameStr(argInfo.typeName) - if !exists { - continue - } - - // fmt.Printf("variableTypeRef = %v unwrappedVarTypeRef = %v typeName = %v nodeKind = %v varVal = %v\n", variableTypeRef, unwrappedVarTypeRef, argInfo.typeName, node.Kind, variableValue) - - // Analyze the node to see what kind of variable was passed. - switch node.Kind { - case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: - argInfo.isSimple = true - case ast.NodeKindInputObjectTypeDefinition: - argInfo.isInputObject = true - - } - - // TODO: we need to analyze variables that contains input object fields. - // If these fields has weight attached, use them for calculation. - // Variables are not inlined at this stage, so we need to inspect them via AST. - - default: - fmt.Printf("WARNING: unhandled argument type: %v\n", argValue.Kind) - continue - } - - arguments[argName] = argInfo - } - - return arguments -} - func (v *Visitor) resolveFieldInfo(ref, typeRef int, onTypeNames [][]byte) *resolve.FieldInfo { if v.Config.DisableIncludeInfo { return nil @@ -779,19 +626,14 @@ func (v *Visitor) LeaveField(fieldRef int) { return } - // This is done in LeaveField because fieldPlanners become available before LeaveField. - if v.costCalculator != nil { - v.costCalculator.LeaveField(fieldRef, v.getFieldDataSourceHashes(fieldRef)) - } - if v.currentFields[len(v.currentFields)-1].popOnField == fieldRef { v.currentFields = v.currentFields[:len(v.currentFields)-1] } - fieldDefinition, ok := v.Walker.FieldDefinition(fieldRef) + fieldDefinitionRef, ok := v.Walker.FieldDefinition(fieldRef) if !ok { return } - fieldDefinitionTypeNode := v.Definition.FieldDefinitionTypeNode(fieldDefinition) + fieldDefinitionTypeNode := v.Definition.FieldDefinitionTypeNode(fieldDefinitionRef) switch fieldDefinitionTypeNode.Kind { case ast.NodeKindObjectTypeDefinition, ast.NodeKindInterfaceTypeDefinition, ast.NodeKindUnionTypeDefinition: v.objects = v.objects[:len(v.objects)-1] @@ -1133,14 +975,14 @@ func (v *Visitor) valueRequiresExportedVariable(value ast.Value) bool { } } -func (v *Visitor) EnterOperationDefinition(ref int) { - operationName := v.Operation.OperationDefinitionNameString(ref) +func (v *Visitor) EnterOperationDefinition(opRef int) { + operationName := v.Operation.OperationDefinitionNameString(opRef) if v.OperationName != operationName { v.Walker.SkipNode() return } - v.operationDefinition = ref + v.operationDefinitionRef = opRef rootObject := &resolve.Object{ Fields: []*resolve.Field{}, @@ -1173,16 +1015,14 @@ func (v *Visitor) EnterOperationDefinition(ref int) { Response: v.response, } v.plan = &SubscriptionResponsePlan{ - FlushInterval: v.Config.DefaultFlushIntervalMillis, - Response: v.subscription, - StaticCostCalculator: v.costCalculator, + FlushInterval: v.Config.DefaultFlushIntervalMillis, + Response: v.subscription, } return } v.plan = &SynchronousResponsePlan{ - Response: v.response, - StaticCostCalculator: v.costCalculator, + Response: v.response, } } @@ -1334,10 +1174,10 @@ func (v *Visitor) resolveInputTemplates(config *objectFetchConfiguration, input return v.renderJSONValueTemplate(value, variables, inputValueDefinition) } variableValue := v.Operation.VariableValueNameString(value.Ref) - if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinition, variableValue) { + if !v.Operation.OperationDefinitionHasVariableDefinition(v.operationDefinitionRef, variableValue) { break // omit optional argument when variable is not defined } - variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinition, v.Operation.VariableValueNameBytes(value.Ref)) + variableDefinition, exists := v.Operation.VariableDefinitionByNameAndOperation(v.operationDefinitionRef, v.Operation.VariableValueNameBytes(value.Ref)) if !exists { break } From 122d5a5a829a4ec1f062291f9c7f1b135d5015a8 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 19 Jan 2026 17:19:35 +0200 Subject: [PATCH 22/43] remove dead code --- v2/pkg/engine/plan/static_cost_visitor.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost_visitor.go b/v2/pkg/engine/plan/static_cost_visitor.go index 19c90ad92f..a2949638a6 100644 --- a/v2/pkg/engine/plan/static_cost_visitor.go +++ b/v2/pkg/engine/plan/static_cost_visitor.go @@ -1,8 +1,6 @@ package plan import ( - "fmt" - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" ) @@ -203,10 +201,6 @@ func (v *StaticCostVisitor) extractFieldArguments(fieldRef int) map[string]Argum // TODO: we need to analyze variables that contains input object fields. // If these fields has weight attached, use them for calculation. // Variables are not inlined at this stage, so we need to inspect them via AST. - - default: - fmt.Printf("WARNING: unhandled argument type: %v\n", argValue.Kind) - continue } arguments[argName] = argInfo From daf54f0f1cb848012a91048073070ecb8d90bfa1 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 19 Jan 2026 17:40:59 +0200 Subject: [PATCH 23/43] unwrap ifs and add a comment --- v2/pkg/engine/plan/static_cost.go | 47 +++++++++++++++++++------------ 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 4e3cdd8e17..2197602c16 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -89,25 +89,29 @@ func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo, vars *ast multiplier := -1 for _, slicingArg := range ls.SlicingArguments { arg, ok := arguments[slicingArg] - if ok && arg.isSimple { - var value int - // Argument could have a variable or literal value. - if arg.hasVariable { - if vars == nil { - continue - } - if v := vars.Get(arg.varName); v == nil || v.Type() != astjson.TypeNumber { - continue - } - value = vars.GetInt(arg.varName) - } else if arg.intValue > 0 { - value = arg.intValue + if !ok || !arg.isSimple { + continue + } + + var value int + // Argument could be a variable or literal value. + if arg.hasVariable { + if vars == nil { + continue } - if value > 0 && value > multiplier { - multiplier = value + if v := vars.Get(arg.varName); v == nil || v.Type() != astjson.TypeNumber { + continue } + value = vars.GetInt(arg.varName) + } else if arg.intValue > 0 { + value = arg.intValue + } + + if value > 0 && value > multiplier { + multiplier = value } } + if multiplier == -1 && ls.AssumedSize > 0 { multiplier = ls.AssumedSize } @@ -280,9 +284,17 @@ func (node *CostTreeNode) totalCost(configs map[DSHash]*DataSourceCostConfig, va // If arguments and directive weights decrease the field cost, floor it to zero. cost = 0 } - // Here we do not follow IBM spec. We multiply with field cost. + // Here we do not follow IBM spec. IBM spec does not use the cost of the object itself + // in multiplication. It assumes that the weight of the type should be just summed up + // without regard to the size of the list. + // + // We, instead, multiply with field cost. // If there is a weight attached to the type that is returned (resolved) by the field, - // the more objects we request, the more expensive it should be. + // the more objects are requested, the more expensive it should be. + // This, in turn, has some ambiguity for definitions of the weights for the list types. + // "A: [Obj] @cost(weight: 5)" means that the cost of the field is 5 for each object in the list. + // "type Object @cost(weight: 5) { ... }" does exactly the same thing. + // Weight defined on a field has priority over the weight defined on a type. cost += (node.fieldCost + childrenCost) * multiplier return cost @@ -466,4 +478,3 @@ func (c *CostCalculator) SetVariables(variables *astjson.Value) { func (c *CostCalculator) GetTotalCost() int { return c.tree.totalCost(c.costConfigs, c.variables) } - From edecf5647b9e974b63bc719d4fe24ee9d3dd6918 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 19 Jan 2026 18:13:21 +0200 Subject: [PATCH 24/43] fix a comment --- v2/pkg/engine/plan/static_cost_visitor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/plan/static_cost_visitor.go b/v2/pkg/engine/plan/static_cost_visitor.go index a2949638a6..ba466baab9 100644 --- a/v2/pkg/engine/plan/static_cost_visitor.go +++ b/v2/pkg/engine/plan/static_cost_visitor.go @@ -200,7 +200,7 @@ func (v *StaticCostVisitor) extractFieldArguments(fieldRef int) map[string]Argum // TODO: we need to analyze variables that contains input object fields. // If these fields has weight attached, use them for calculation. - // Variables are not inlined at this stage, so we need to inspect them via AST. + // Inline values extracted into variables here, so we need to inspect them via AST. } arguments[argName] = argInfo From 81f9caa4cae887d5b8c256e363966eab93370f83 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 19 Jan 2026 18:22:13 +0200 Subject: [PATCH 25/43] push to the stack anyway --- v2/pkg/engine/plan/static_cost_visitor.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/plan/static_cost_visitor.go b/v2/pkg/engine/plan/static_cost_visitor.go index ba466baab9..3d8fbbb571 100644 --- a/v2/pkg/engine/plan/static_cost_visitor.go +++ b/v2/pkg/engine/plan/static_cost_visitor.go @@ -55,6 +55,8 @@ func (v *StaticCostVisitor) EnterField(fieldRef int) { fieldDefinitionRef, ok := v.Walker.FieldDefinition(fieldRef) if !ok { + // Push the sentinel node, so the LeaveField would pop the stack correctly. + v.stack = append(v.stack, &CostTreeNode{fieldRef: fieldRef}) return } fieldDefinitionTypeRef := v.Definition.FieldDefinitionType(fieldDefinitionRef) @@ -107,7 +109,6 @@ func (v *StaticCostVisitor) EnterField(fieldRef int) { } v.stack = append(v.stack, &node) - } // LeaveField fills DataSource hashes for the current node and pop it from the cost stack. From f795906ded6348c7f6587c661057626f2abefce1 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 20 Jan 2026 13:27:06 +0200 Subject: [PATCH 26/43] test nested lists --- execution/engine/execution_engine_test.go | 906 +++++++++++++--------- 1 file changed, 554 insertions(+), 352 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 4ddb07edb8..7f447962ec 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -5548,324 +5548,325 @@ func TestExecutionEngine_Execute(t *testing.T) { }) t.Run("static cost computation", func(t *testing.T) { - rootNodes := []plan.TypeField{ - {TypeName: "Query", FieldNames: []string{"hero", "droid"}}, - {TypeName: "Human", FieldNames: []string{"name", "height", "friends"}}, - {TypeName: "Droid", FieldNames: []string{"name", "primaryFunction", "friends"}}, - } - childNodes := []plan.TypeField{ - {TypeName: "Character", FieldNames: []string{"name", "friends"}}, - } - customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ - Fetch: &graphql_datasource.FetchConfiguration{ - URL: "https://example.com/", - Method: "GET", - }, - SchemaConfiguration: mustSchemaConfig( - t, - nil, - string(graphql.StarwarsSchema(t).RawSchema()), - ), - }) - - t.Run("droid with weighted plain fields", runWithoutError( - ExecutionEngineTestCase{ - schema: graphql.StarwarsSchema(t), - operation: func(t *testing.T) graphql.Request { - return graphql.Request{ - Query: `{ - droid(id: "R2D2") { - name - primaryFunction - } - }`, - } + t.Run("star wars", func(t *testing.T) { + rootNodes := []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"hero", "droid"}}, + {TypeName: "Human", FieldNames: []string{"name", "height", "friends"}}, + {TypeName: "Droid", FieldNames: []string{"name", "primaryFunction", "friends"}}, + } + childNodes := []plan.TypeField{ + {TypeName: "Character", FieldNames: []string{"name", "friends"}}, + } + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", }, - dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, "id", - mustFactory(t, - testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", expectedPath: "/", expectedBody: "", - sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, - sendStatusCode: 200, - }), + SchemaConfiguration: mustSchemaConfig( + t, + nil, + string(graphql.StarwarsSchema(t).RawSchema()), + ), + }) + + t.Run("droid with weighted plain fields", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + droid(id: "R2D2") { + name + primaryFunction + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + }}, + customConfig, ), - &plan.DataSourceMetadata{ - RootNodes: rootNodes, - ChildNodes: childNodes, - CostConfig: &plan.DataSourceCostConfig{ - Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ - {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + fields: []plan.FieldConfiguration{ + { + TypeName: "Query", FieldName: "droid", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "id", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, }, - }}, - customConfig, - ), - }, - fields: []plan.FieldConfiguration{ - { - TypeName: "Query", FieldName: "droid", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "id", - SourceType: plan.FieldArgumentSource, - RenderConfig: plan.RenderArgumentAsGraphQLValue, }, }, }, + expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + expectedStaticCost: 18, // Query.droid (1) + droid.name (17) }, - expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, - expectedStaticCost: 18, // Query.droid (1) + droid.name (17) - }, - computeStaticCost(), - )) + computeStaticCost(), + )) - t.Run("droid with weighted plain fields and an argument", runWithoutError( - ExecutionEngineTestCase{ - schema: graphql.StarwarsSchema(t), - operation: func(t *testing.T) graphql.Request { - return graphql.Request{ - Query: `{ - droid(id: "R2D2") { - name - primaryFunction - } - }`, - } - }, - dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, "id", - mustFactory(t, - testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", expectedPath: "/", expectedBody: "", - sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, - sendStatusCode: 200, - }), - ), - &plan.DataSourceMetadata{ - RootNodes: rootNodes, - ChildNodes: childNodes, - CostConfig: &plan.DataSourceCostConfig{ - Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ - {TypeName: "Query", FieldName: "droid"}: { - ArgumentWeights: map[string]int{"id": 3}, - HasWeight: false, + t.Run("droid with weighted plain fields and an argument", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + droid(id: "R2D2") { + name + primaryFunction + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "droid"}: { + ArgumentWeights: map[string]int{"id": 3}, + HasWeight: false, + }, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, }, - {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }}, + customConfig, + ), + }, + fields: []plan.FieldConfiguration{ + { + TypeName: "Query", FieldName: "droid", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "id", + SourceType: plan.FieldArgumentSource, + RenderConfig: plan.RenderArgumentAsGraphQLValue, }, - }}, - customConfig, - ), - }, - fields: []plan.FieldConfiguration{ - { - TypeName: "Query", FieldName: "droid", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "id", - SourceType: plan.FieldArgumentSource, - RenderConfig: plan.RenderArgumentAsGraphQLValue, }, }, }, + expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, + expectedStaticCost: 21, // Query.droid (1) + Query.droid.id (3) + droid.name (17) }, - expectedResponse: `{"data":{"droid":{"name":"R2D2","primaryFunction":"no"}}}`, - expectedStaticCost: 21, // Query.droid (1) + Query.droid.id (3) + droid.name (17) - }, - computeStaticCost(), - )) + computeStaticCost(), + )) - t.Run("hero field has weight (returns interface) and with concrete fragment", runWithoutError( - ExecutionEngineTestCase{ - schema: graphql.StarwarsSchema(t), - operation: func(t *testing.T) graphql.Request { - return graphql.Request{ - Query: `{ - hero { - name - ... on Human { height } - } - }`, - } - }, - dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, "id", - mustFactory(t, - testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", expectedPath: "/", expectedBody: "", - sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker","height":"12"}}}`, - sendStatusCode: 200, - }), + t.Run("hero field has weight (returns interface) and with concrete fragment", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + name + ... on Human { height } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker","height":"12"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 3}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + Types: map[string]int{ + "Human": 13, + }, + }}, + customConfig, ), - &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ - Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ - {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, - {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 3}, - {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, - {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, - }, - Types: map[string]int{ - "Human": 13, - }, - }}, - customConfig, - ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker","height":"12"}}}`, + expectedStaticCost: 22, // Query.hero (2) + Human.height (3) + Droid.name (17=max(7, 17)) }, - expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker","height":"12"}}}`, - expectedStaticCost: 22, // Query.hero (2) + Human.height (3) + Droid.name (17=max(7, 17)) - }, - computeStaticCost(), - )) + computeStaticCost(), + )) - t.Run("hero field has no weight (returns interface) and with concrete fragment", runWithoutError( - ExecutionEngineTestCase{ - schema: graphql.StarwarsSchema(t), - operation: func(t *testing.T) graphql.Request { - return graphql.Request{ - Query: `{ - hero { name } - }`, - } - }, - dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, "id", - mustFactory(t, - testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", expectedPath: "/", expectedBody: "", - sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker"}}}`, - sendStatusCode: 200, - }), + t.Run("hero field has no weight (returns interface) and with concrete fragment", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { name } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke Skywalker"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, + }, + Types: map[string]int{ + "Human": 13, + "Droid": 11, + }, + }}, + customConfig, ), - &plan.DataSourceMetadata{RootNodes: rootNodes, ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ - Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ - {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 7}, - {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 17}, - }, - Types: map[string]int{ - "Human": 13, - "Droid": 11, - }, - }}, - customConfig, - ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker"}}}`, + expectedStaticCost: 30, // Query.Human (13) + Droid.name (17=max(7, 17)) }, - expectedResponse: `{"data":{"hero":{"name":"Luke Skywalker"}}}`, - expectedStaticCost: 30, // Query.Human (13) + Droid.name (17=max(7, 17)) - }, - computeStaticCost(), - )) + computeStaticCost(), + )) - t.Run("query hero without assumedSize on friends", runWithoutError( - ExecutionEngineTestCase{ - schema: graphql.StarwarsSchema(t), - operation: func(t *testing.T) graphql.Request { - return graphql.Request{ - Query: `{ - hero { - friends { - ...on Droid { name primaryFunction } - ...on Human { name height } + t.Run("query hero without assumedSize on friends", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } } - } - }`, - } - }, - dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, "id", - mustFactory(t, - testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", expectedPath: "/", expectedBody: "", - sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ {"__typename":"Human","name":"Luke Skywalker","height":"12"}, {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} ]}}}`, - sendStatusCode: 200, - }), - ), - &plan.DataSourceMetadata{ - RootNodes: rootNodes, - ChildNodes: childNodes, - CostConfig: &plan.DataSourceCostConfig{ - Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ - {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, - {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, - {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, - }, - Types: map[string]int{ - "Human": 7, - "Droid": 5, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, }, }, - }, - customConfig, - ), + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 127, // Query.hero(max(7,5))+10*(Human(max(7,5))+Human.name(2)+Human.height(1)+Droid.name(2)) }, - expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, - expectedStaticCost: 127, // Query.hero(max(7,5))+10*(Human(max(7,5))+Human.name(2)+Human.height(1)+Droid.name(2)) - }, - computeStaticCost(), - )) + computeStaticCost(), + )) - t.Run("query hero with assumedSize on friends", runWithoutError( - ExecutionEngineTestCase{ - schema: graphql.StarwarsSchema(t), - operation: func(t *testing.T) graphql.Request { - return graphql.Request{ - Query: `{ - hero { - friends { - ...on Droid { name primaryFunction } - ...on Human { name height } + t.Run("query hero with assumedSize on friends", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } } - } - }`, - } - }, - dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, "id", - mustFactory(t, - testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", expectedPath: "/", expectedBody: "", - sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ {"__typename":"Human","name":"Luke Skywalker","height":"12"}, {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} ]}}}`, - sendStatusCode: 200, - }), - ), - &plan.DataSourceMetadata{ - RootNodes: rootNodes, - ChildNodes: childNodes, - CostConfig: &plan.DataSourceCostConfig{ - Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ - {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, - {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, - {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, - }, - ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ - {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 5}, - {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 20}, - }, - Types: map[string]int{ - "Human": 7, - "Droid": 5, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 5}, + {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 20}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, }, }, - }, - customConfig, - ), + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 247, // Query.hero(max(7,5))+ 20 * (7+2+2+1) }, - expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, - expectedStaticCost: 247, // Query.hero(max(7,5))+ 20 * (7+2+2+1) - }, - computeStaticCost(), - )) + computeStaticCost(), + )) - t.Run("query hero with assumedSize on friends and weight defined", runWithoutError( - ExecutionEngineTestCase{ - schema: graphql.StarwarsSchema(t), - operation: func(t *testing.T) graphql.Request { - return graphql.Request{ - Query: `{ + t.Run("query hero with assumedSize on friends and weight defined", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ hero { friends { ...on Droid { name primaryFunction } @@ -5873,90 +5874,91 @@ func TestExecutionEngine_Execute(t *testing.T) { } } }`, - } - }, - dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, "id", - mustFactory(t, - testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", expectedPath: "/", expectedBody: "", - sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ {"__typename":"Human","name":"Luke Skywalker","height":"12"}, {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} ]}}}`, - sendStatusCode: 200, - }), - ), - &plan.DataSourceMetadata{ - RootNodes: rootNodes, - ChildNodes: childNodes, - CostConfig: &plan.DataSourceCostConfig{ - Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ - {TypeName: "Human", FieldName: "friends"}: {HasWeight: true, Weight: 3}, - {TypeName: "Droid", FieldName: "friends"}: {HasWeight: true, Weight: 4}, - {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, - {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, - {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, - }, - ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ - {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 5}, - {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 20}, - }, - Types: map[string]int{ - "Human": 7, - "Droid": 5, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Human", FieldName: "friends"}: {HasWeight: true, Weight: 3}, + {TypeName: "Droid", FieldName: "friends"}: {HasWeight: true, Weight: 4}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 1}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 2}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 5}, + {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 20}, + }, + Types: map[string]int{ + "Human": 7, + "Droid": 5, + }, }, }, - }, - customConfig, - ), + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 187, // Query.hero(max(7,5))+ 20 * (4+2+2+1) }, - expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, - expectedStaticCost: 187, // Query.hero(max(7,5))+ 20 * (4+2+2+1) - }, - computeStaticCost(), - )) + computeStaticCost(), + )) - t.Run("query hero with empty cost structures", runWithoutError( - ExecutionEngineTestCase{ - schema: graphql.StarwarsSchema(t), - operation: func(t *testing.T) graphql.Request { - return graphql.Request{ - Query: `{ - hero { - friends { - ...on Droid { name primaryFunction } - ...on Human { name height } + t.Run("query hero with empty cost structures", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + hero { + friends { + ...on Droid { name primaryFunction } + ...on Human { name height } + } } - } - }`, - } - }, - dataSources: []plan.DataSource{ - mustGraphqlDataSourceConfiguration(t, "id", - mustFactory(t, - testNetHttpClient(t, roundTripperTestCase{ - expectedHost: "example.com", expectedPath: "/", expectedBody: "", - sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","friends":[ {"__typename":"Human","name":"Luke Skywalker","height":"12"}, {"__typename":"Droid","name":"R2DO","primaryFunction":"joke"} ]}}}`, - sendStatusCode: 200, - }), + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{}, + }, + customConfig, ), - &plan.DataSourceMetadata{ - RootNodes: rootNodes, - ChildNodes: childNodes, - CostConfig: &plan.DataSourceCostConfig{}, - }, - customConfig, - ), + }, + expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, + expectedStaticCost: 11, // Query.hero(max(1,1))+ 10 * 1 }, - expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, - expectedStaticCost: 11, // Query.hero(max(1,1))+ 10 * 1 - }, - computeStaticCost(), - )) + computeStaticCost(), + )) + }) t.Run("custom scheme for listSize", func(t *testing.T) { listSchema := ` @@ -6096,6 +6098,206 @@ func TestExecutionEngine_Execute(t *testing.T) { }) + t.Run("nested lists with compounding multipliers", func(t *testing.T) { + nestedSchema := ` + type Query { + users(first: Int): [User!] + } + type User @key(fields: "id") { + id: ID! + posts(first: Int): [Post!] + } + type Post @key(fields: "id") { + id: ID! + comments(first: Int): [Comment!] + } + type Comment @key(fields: "id") { + id: ID! + text: String! + } + ` + schemaNested, err := graphql.NewSchemaFromString(nestedSchema) + require.NoError(t, err) + + rootNodes := []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"users"}}, + {TypeName: "User", FieldNames: []string{"id", "posts"}}, + {TypeName: "Post", FieldNames: []string{"id", "comments"}}, + {TypeName: "Comment", FieldNames: []string{"id", "text"}}, + } + childNodes := []plan.TypeField{} + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", + }, + SchemaConfiguration: mustSchemaConfig(t, nil, nestedSchema), + }) + fieldConfig := []plan.FieldConfiguration{ + { + TypeName: "Query", FieldName: "users", Path: []string{"users"}, + Arguments: []plan.ArgumentConfiguration{ + {Name: "first", SourceType: plan.FieldArgumentSource, RenderConfig: plan.RenderArgumentAsGraphQLValue}, + }, + }, + { + TypeName: "User", FieldName: "posts", Path: []string{"posts"}, + Arguments: []plan.ArgumentConfiguration{ + {Name: "first", SourceType: plan.FieldArgumentSource, RenderConfig: plan.RenderArgumentAsGraphQLValue}, + }, + }, + { + TypeName: "Post", FieldName: "comments", Path: []string{"comments"}, + Arguments: []plan.ArgumentConfiguration{ + {Name: "first", SourceType: plan.FieldArgumentSource, RenderConfig: plan.RenderArgumentAsGraphQLValue}, + }, + }, + } + + t.Run("nested lists with slicing arguments", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaNested, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + users(first: 10) { + posts(first: 5) { + comments(first: 3) { text } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"users":[{"posts":[{"comments":[{"text":"hello"}]}]}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Comment", FieldName: "text"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "users"}: { + AssumedSize: 100, + SlicingArguments: []string{"first"}, + }, + {TypeName: "User", FieldName: "posts"}: { + AssumedSize: 50, + SlicingArguments: []string{"first"}, + }, + {TypeName: "Post", FieldName: "comments"}: { + AssumedSize: 20, + SlicingArguments: []string{"first"}, + }, + }, + Types: map[string]int{ + "User": 4, + "Post": 3, + "Comment": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"users":[{"posts":[{"comments":[{"text":"hello"}]}]}]}}`, + // Cost calculation: + // users(first:10) -> multiplier 10 + // User type weight: 4 + // posts(first:5) -> multiplier 5 + // Post type weight: 3 + // comments(first:3) -> multiplier 3 + // Comment type weight: 2 + // text weight: 1 + // Total: 10 * (4 + 5 * (3 + 3 * (2 + 1))) + expectedStaticCost: 640, + }, + computeStaticCost(), + )) + + t.Run("nested lists fallback to assumedSize when slicing arg not provided", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaNested, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + users(first: 2) { + posts { + comments(first: 4) { text } + } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"users":[{"posts":[{"comments":[{"text":"hi"}]}]}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Comment", FieldName: "text"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "users"}: { + AssumedSize: 100, + SlicingArguments: []string{"first"}, + }, + {TypeName: "User", FieldName: "posts"}: { + AssumedSize: 50, // no slicing arg, should use this + }, + {TypeName: "Post", FieldName: "comments"}: { + AssumedSize: 20, + SlicingArguments: []string{"first"}, + }, + }, + Types: map[string]int{ + "User": 4, + "Post": 3, + "Comment": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"users":[{"posts":[{"comments":[{"text":"hi"}]}]}]}}`, + // Cost calculation: + // users(first:2) -> multiplier 2 + // User type weight: 4 + // posts (no arg) -> assumedSize 50 + // Post type weight: 3 + // comments(first:4) -> multiplier 4 + // Comment type weight: 2 + // text weight: 1 + // Total: 2 * (4 + 50 * (3 + 4 * (2 + 1))) + expectedStaticCost: 1508, + }, + computeStaticCost(), + )) + }) + }) } From 5245e736d01e5be1f16078339316e59558fece31 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 20 Jan 2026 13:51:00 +0200 Subject: [PATCH 27/43] add more fragment tests --- execution/engine/execution_engine_test.go | 128 +++++++++++++++++++++- 1 file changed, 122 insertions(+), 6 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 7f447962ec..d2811fcb78 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -5958,6 +5958,122 @@ func TestExecutionEngine_Execute(t *testing.T) { }, computeStaticCost(), )) + + t.Run("named fragment on interface", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: ` + fragment CharacterFields on Character { + name + friends { name } + } + { hero { ...CharacterFields } } + `, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke","friends":[{"name":"Leia"}]}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 3}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 5}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Human", FieldName: "friends"}: {AssumedSize: 4}, + {TypeName: "Droid", FieldName: "friends"}: {AssumedSize: 6}, + }, + Types: map[string]int{ + "Human": 2, + "Droid": 3, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke","friends":[{"name":"Leia"}]}}}`, + // Cost calculation: + // Query.hero: 2 + // Character.name: max(Human.name=3, Droid.name=5) = 5 + // friends listSize: max(4, 6) = 6 + // Character type: max(Human=2, Droid=3) = 3 + // name: max(Human.name=3, Droid.name=5) = 5 + // Total: 2 + 5 + 6 * (3 + 5) + expectedStaticCost: 55, + }, + computeStaticCost(), + )) + + t.Run("named fragment with concrete type", runWithoutError( + ExecutionEngineTestCase{ + schema: graphql.StarwarsSchema(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: ` + fragment HumanFields on Human { + name + height + } + { hero { ...HumanFields } } + `, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"hero":{"__typename":"Human","name":"Luke","height":"1.72"}}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "hero"}: {HasWeight: true, Weight: 2}, + {TypeName: "Human", FieldName: "name"}: {HasWeight: true, Weight: 3}, + {TypeName: "Human", FieldName: "height"}: {HasWeight: true, Weight: 7}, + {TypeName: "Droid", FieldName: "name"}: {HasWeight: true, Weight: 5}, + }, + Types: map[string]int{ + "Human": 1, + "Droid": 1, + }, + }, + }, + customConfig, + ), + }, + expectedResponse: `{"data":{"hero":{"name":"Luke","height":"1.72"}}}`, + // Cost calculation: + // Query.hero: 2 + // Human.name: 3 + // Human.height: 7 + // Total: 2 + 3 + 7 + expectedStaticCost: 12, + }, + computeStaticCost(), + )) + }) t.Run("custom scheme for listSize", func(t *testing.T) { @@ -6213,11 +6329,11 @@ func TestExecutionEngine_Execute(t *testing.T) { fields: fieldConfig, expectedResponse: `{"data":{"users":[{"posts":[{"comments":[{"text":"hello"}]}]}]}}`, // Cost calculation: - // users(first:10) -> multiplier 10 + // users(first:10): multiplier 10 // User type weight: 4 - // posts(first:5) -> multiplier 5 + // posts(first:5): multiplier 5 // Post type weight: 3 - // comments(first:3) -> multiplier 3 + // comments(first:3): multiplier 3 // Comment type weight: 2 // text weight: 1 // Total: 10 * (4 + 5 * (3 + 3 * (2 + 1))) @@ -6284,11 +6400,11 @@ func TestExecutionEngine_Execute(t *testing.T) { fields: fieldConfig, expectedResponse: `{"data":{"users":[{"posts":[{"comments":[{"text":"hi"}]}]}]}}`, // Cost calculation: - // users(first:2) -> multiplier 2 + // users(first:2): multiplier 2 // User type weight: 4 - // posts (no arg) -> assumedSize 50 + // posts (no arg): assumedSize 50 // Post type weight: 3 - // comments(first:4) -> multiplier 4 + // comments(first:4): multiplier 4 // Comment type weight: 2 // text weight: 1 // Total: 2 * (4 + 50 * (3 + 4 * (2 + 1))) From 13ab9877af72d397e099245e1921e25db5219392 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 20 Jan 2026 15:31:41 +0200 Subject: [PATCH 28/43] add tests for unions --- execution/engine/execution_engine_test.go | 177 +++++++++++++++++++++- v2/pkg/engine/plan/static_cost.go | 94 ++++++++++++ 2 files changed, 267 insertions(+), 4 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index d2811fcb78..43d4093daf 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -338,6 +338,7 @@ func TestExecutionEngine_Execute(t *testing.T) { lastPlan := engine.lastPlan require.NotNil(t, lastPlan) costCalc := lastPlan.GetStaticCostCalculator() + fmt.Println(costCalc.DebugPrint()) require.Equal(t, testCase.expectedStaticCost, costCalc.GetTotalCost()) } @@ -6064,10 +6065,6 @@ func TestExecutionEngine_Execute(t *testing.T) { ), }, expectedResponse: `{"data":{"hero":{"name":"Luke","height":"1.72"}}}`, - // Cost calculation: - // Query.hero: 2 - // Human.name: 3 - // Human.height: 7 // Total: 2 + 3 + 7 expectedStaticCost: 12, }, @@ -6076,6 +6073,178 @@ func TestExecutionEngine_Execute(t *testing.T) { }) + t.Run("union types", func(t *testing.T) { + unionSchema := ` + type Query { + search(term: String!): [SearchResult!] + } + union SearchResult = User | Post | Comment + type User @key(fields: "id") { + id: ID! + name: String! + email: String! + } + type Post @key(fields: "id") { + id: ID! + title: String! + body: String! + } + type Comment @key(fields: "id") { + id: ID! + text: String! + } + ` + schemaUnion, err := graphql.NewSchemaFromString(unionSchema) + require.NoError(t, err) + + unionRootNodes := []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"search"}}, + {TypeName: "User", FieldNames: []string{"id", "name", "email"}}, + {TypeName: "Post", FieldNames: []string{"id", "title", "body"}}, + {TypeName: "Comment", FieldNames: []string{"id", "text"}}, + } + unionChildNodes := []plan.TypeField{} + unionCustomConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + Fetch: &graphql_datasource.FetchConfiguration{ + URL: "https://example.com/", + Method: "GET", + }, + SchemaConfiguration: mustSchemaConfig(t, nil, unionSchema), + }) + unionFieldConfig := []plan.FieldConfiguration{ + { + TypeName: "Query", + FieldName: "search", + Path: []string{"search"}, + Arguments: []plan.ArgumentConfiguration{ + {Name: "term", SourceType: plan.FieldArgumentSource, RenderConfig: plan.RenderArgumentAsGraphQLValue}, + }, + }, + } + + t.Run("union with all member types", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaUnion, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + search(term: "test") { + ... on User { name email } + ... on Post { title body } + ... on Comment { text } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"search":[{"__typename":"User","name":"John","email":"john@test.com"}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: unionRootNodes, + ChildNodes: unionChildNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "User", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "User", FieldName: "email"}: {HasWeight: true, Weight: 3}, + {TypeName: "Post", FieldName: "title"}: {HasWeight: true, Weight: 4}, + {TypeName: "Post", FieldName: "body"}: {HasWeight: true, Weight: 5}, + {TypeName: "Comment", FieldName: "text"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "search"}: {AssumedSize: 5}, + }, + Types: map[string]int{ + "User": 2, + "Post": 3, + "Comment": 1, + }, + }, + }, + unionCustomConfig, + ), + }, + fields: unionFieldConfig, + expectedResponse: `{"data":{"search":[{"name":"John","email":"john@test.com"}]}}`, + // search listSize: 10 + // For each SearchResult, use max across all union members: + // Type weight: max(User=2, Post=3, Comment=1) = 3 + // Fields: all fields from all fragments are counted + // (2 + 3) + (4 + 5) + (1) = 15 + // TODO: this is not correct, we should pick a maximum sum among types implementing union. + // 9 should be used instead of 15 + // Total: 5 * (3 + 15) + expectedStaticCost: 90, + }, + computeStaticCost(), + )) + + t.Run("union with weighted search field", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaUnion, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `{ + search(term: "test") { + ... on User { name } + ... on Post { title } + } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", + expectedPath: "/", + expectedBody: "", + sendResponseBody: `{"data":{"search":[{"__typename":"User","name":"John"}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: unionRootNodes, + ChildNodes: unionChildNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "User", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "Post", FieldName: "title"}: {HasWeight: true, Weight: 5}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "search"}: {AssumedSize: 3}, + }, + Types: map[string]int{ + "User": 6, + "Post": 10, + }, + }, + }, + unionCustomConfig, + ), + }, + fields: unionFieldConfig, + expectedResponse: `{"data":{"search":[{"name":"John"}]}}`, + // Query.search: max(User=10, Post=6) + // search listSize: 3 + // Union members: + // All fields from all fragments: User.name(2) + Post.title(5) + // Total: 3 * (10+2+5) + // TODO: we might correct this by counting only members of one implementing types + // of a union when fragments are used. + expectedStaticCost: 51, + }, + computeStaticCost(), + )) + }) + t.Run("custom scheme for listSize", func(t *testing.T) { listSchema := ` type Query { diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 2197602c16..a446e7c77f 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -28,6 +28,7 @@ A few things on the TBD list: import ( "fmt" + "strings" "github.com/wundergraph/astjson" ) @@ -478,3 +479,96 @@ func (c *CostCalculator) SetVariables(variables *astjson.Value) { func (c *CostCalculator) GetTotalCost() int { return c.tree.totalCost(c.costConfigs, c.variables) } + +// DebugPrint prints the cost tree structure for debugging purposes. +// It shows each node's field coordinate, costs, multipliers, and computed totals. +func (c *CostCalculator) DebugPrint() string { + if c.tree == nil || len(c.tree.children) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString("Cost Tree Debug:\n") + sb.WriteString("================\n") + c.tree.children[0].debugPrint(&sb, c.costConfigs, c.variables, 0) + fmt.Fprintf(&sb, "\nTotal Cost: %d\n", c.GetTotalCost()) + return sb.String() +} + +// debugPrint recursively prints a node and its children with indentation. +func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, depth int) { + if node == nil { + return + } + + indent := strings.Repeat(" ", depth) + + // Calculate costs for this node + node.setCostsAndMultiplier(configs, variables) + + // Field coordinate info + fieldInfo := fmt.Sprintf("%s.%s", node.fieldCoord.TypeName, node.fieldCoord.FieldName) + + // Build node info line + fmt.Fprintf(sb, "%s├ %s", indent, fieldInfo) + + // Add type info + if node.fieldTypeName != "" { + fmt.Fprintf(sb, " -> %s", node.fieldTypeName) + } + + // Add flags + var flags []string + if node.returnsListType { + flags = append(flags, "list") + } + if node.returnsAbstractType { + flags = append(flags, "abstract") + } + if node.returnsSimpleType { + flags = append(flags, "simple") + } + if len(flags) > 0 { + fmt.Fprintf(sb, " [%s]", strings.Join(flags, ",")) + } + sb.WriteString("\n") + + // Cost details + if node.fieldCost != 0 || node.argumentsCost != 0 || node.multiplier != 0 { + fmt.Fprintf(sb, "%s│ fieldCost=%d, argsCost=%d, multiplier=%d", + indent, node.fieldCost, node.argumentsCost, node.multiplier) + + // Show data sources + if len(node.dataSourceHashes) > 0 { + fmt.Fprintf(sb, ", dataSources=%d", len(node.dataSourceHashes)) + } + sb.WriteString("\n") + } + + // Arguments info + if len(node.arguments) > 0 { + var argStrs []string + for name, arg := range node.arguments { + if arg.hasVariable { + argStrs = append(argStrs, fmt.Sprintf("%s=$%s", name, arg.varName)) + } else if arg.isSimple { + argStrs = append(argStrs, fmt.Sprintf("%s=%d", name, arg.intValue)) + } else { + argStrs = append(argStrs, fmt.Sprintf("%s=", name)) + } + } + fmt.Fprintf(sb, "%s│ args: {%s}\n", indent, strings.Join(argStrs, ", ")) + } + + // Implementing types (for abstract types) + if len(node.implementingTypeNames) > 0 { + fmt.Fprintf(sb, "%s│ implements: [%s]\n", indent, strings.Join(node.implementingTypeNames, ", ")) + } + + subtreeCost := node.totalCost(configs, variables) + fmt.Fprintf(sb, "%s│ subCost=%d\n", indent, subtreeCost) + + // Print children + for _, child := range node.children { + child.debugPrint(sb, configs, variables, depth+1) + } +} From 64eb0bb73db542410d96f367fa03e64498496766 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 20 Jan 2026 16:28:14 +0200 Subject: [PATCH 29/43] add more slicing tests --- execution/engine/execution_engine_test.go | 186 +++++++++++++++++++--- v2/pkg/engine/plan/static_cost.go | 23 +-- 2 files changed, 170 insertions(+), 39 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 43d4093daf..a7e6468651 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -270,12 +270,12 @@ func TestExecutionEngine_Execute(t *testing.T) { engineConf.SetCustomResolveMap(testCase.customResolveMap) engineConf.plannerConfig.Debug = plan.DebugConfiguration{ - PrintOperationTransformations: true, - PrintPlanningPaths: true, + // PrintOperationTransformations: true, + // PrintPlanningPaths: true, // PrintNodeSuggestions: true, - PrintQueryPlans: true, - ConfigurationVisitor: true, - PlanningVisitor: true, + // PrintQueryPlans: true, + // ConfigurationVisitor: true, + // PlanningVisitor: true, // DatasourceVisitor: true, } @@ -338,8 +338,9 @@ func TestExecutionEngine_Execute(t *testing.T) { lastPlan := engine.lastPlan require.NotNil(t, lastPlan) costCalc := lastPlan.GetStaticCostCalculator() + gotCost := costCalc.GetTotalCost() fmt.Println(costCalc.DebugPrint()) - require.Equal(t, testCase.expectedStaticCost, costCalc.GetTotalCost()) + require.Equal(t, testCase.expectedStaticCost, gotCost) } } @@ -5549,7 +5550,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }) t.Run("static cost computation", func(t *testing.T) { - t.Run("star wars", func(t *testing.T) { + t.Run("common on star wars scheme", func(t *testing.T) { rootNodes := []plan.TypeField{ {TypeName: "Query", FieldNames: []string{"hero", "droid"}}, {TypeName: "Human", FieldNames: []string{"name", "height", "friends"}}, @@ -6094,24 +6095,24 @@ func TestExecutionEngine_Execute(t *testing.T) { text: String! } ` - schemaUnion, err := graphql.NewSchemaFromString(unionSchema) + schema, err := graphql.NewSchemaFromString(unionSchema) require.NoError(t, err) - unionRootNodes := []plan.TypeField{ + rootNodes := []plan.TypeField{ {TypeName: "Query", FieldNames: []string{"search"}}, {TypeName: "User", FieldNames: []string{"id", "name", "email"}}, {TypeName: "Post", FieldNames: []string{"id", "title", "body"}}, {TypeName: "Comment", FieldNames: []string{"id", "text"}}, } - unionChildNodes := []plan.TypeField{} - unionCustomConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ + childNodes := []plan.TypeField{} + customConfig := mustConfiguration(t, graphql_datasource.ConfigurationInput{ Fetch: &graphql_datasource.FetchConfiguration{ URL: "https://example.com/", Method: "GET", }, SchemaConfiguration: mustSchemaConfig(t, nil, unionSchema), }) - unionFieldConfig := []plan.FieldConfiguration{ + fieldConfig := []plan.FieldConfiguration{ { TypeName: "Query", FieldName: "search", @@ -6124,7 +6125,7 @@ func TestExecutionEngine_Execute(t *testing.T) { t.Run("union with all member types", runWithoutError( ExecutionEngineTestCase{ - schema: schemaUnion, + schema: schema, operation: func(t *testing.T) graphql.Request { return graphql.Request{ Query: `{ @@ -6148,8 +6149,8 @@ func TestExecutionEngine_Execute(t *testing.T) { }), ), &plan.DataSourceMetadata{ - RootNodes: unionRootNodes, - ChildNodes: unionChildNodes, + RootNodes: rootNodes, + ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ {TypeName: "User", FieldName: "name"}: {HasWeight: true, Weight: 2}, @@ -6168,10 +6169,10 @@ func TestExecutionEngine_Execute(t *testing.T) { }, }, }, - unionCustomConfig, + customConfig, ), }, - fields: unionFieldConfig, + fields: fieldConfig, expectedResponse: `{"data":{"search":[{"name":"John","email":"john@test.com"}]}}`, // search listSize: 10 // For each SearchResult, use max across all union members: @@ -6188,7 +6189,7 @@ func TestExecutionEngine_Execute(t *testing.T) { t.Run("union with weighted search field", runWithoutError( ExecutionEngineTestCase{ - schema: schemaUnion, + schema: schema, operation: func(t *testing.T) graphql.Request { return graphql.Request{ Query: `{ @@ -6211,8 +6212,8 @@ func TestExecutionEngine_Execute(t *testing.T) { }), ), &plan.DataSourceMetadata{ - RootNodes: unionRootNodes, - ChildNodes: unionChildNodes, + RootNodes: rootNodes, + ChildNodes: childNodes, CostConfig: &plan.DataSourceCostConfig{ Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ {TypeName: "User", FieldName: "name"}: {HasWeight: true, Weight: 2}, @@ -6227,10 +6228,10 @@ func TestExecutionEngine_Execute(t *testing.T) { }, }, }, - unionCustomConfig, + customConfig, ), }, - fields: unionFieldConfig, + fields: fieldConfig, expectedResponse: `{"data":{"search":[{"name":"John"}]}}`, // Query.search: max(User=10, Post=6) // search listSize: 3 @@ -6245,7 +6246,7 @@ func TestExecutionEngine_Execute(t *testing.T) { )) }) - t.Run("custom scheme for listSize", func(t *testing.T) { + t.Run("listSize", func(t *testing.T) { listSchema := ` type Query { items(first: Int, last: Int): [Item!] @@ -6380,6 +6381,145 @@ func TestExecutionEngine_Execute(t *testing.T) { }, computeStaticCost(), )) + t.Run("slicing argument not provided falls back to assumedSize", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query NoSlicingArg { + items { id } + }`, + // No slicing arguments provided - should fall back to assumedSize + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[{"id":"1"},{"id":"2"}]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 15, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[{"id":"1"},{"id":"2"}]}}`, + expectedStaticCost: 45, // Total: 15 * (2 + 1) + }, + computeStaticCost(), + )) + t.Run("zero slicing argument falls back to assumedSize", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query ZeroSlicing { + items(first: 0) { id } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 20, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[]}}`, + expectedStaticCost: 60, // 20 * (2 + 1) + }, + computeStaticCost(), + )) + t.Run("negative slicing argument falls back to assumedSize", runWithoutError( + ExecutionEngineTestCase{ + schema: schemaSlicing, + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + Query: `query NegativeSlicing { + items(first: -5) { id } + }`, + } + }, + dataSources: []plan.DataSource{ + mustGraphqlDataSourceConfiguration(t, "id", + mustFactory(t, + testNetHttpClient(t, roundTripperTestCase{ + expectedHost: "example.com", expectedPath: "/", expectedBody: "", + sendResponseBody: `{"data":{"items":[]}}`, + sendStatusCode: 200, + }), + ), + &plan.DataSourceMetadata{ + RootNodes: rootNodes, + ChildNodes: childNodes, + CostConfig: &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Item", FieldName: "id"}: {HasWeight: true, Weight: 1}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "items"}: { + AssumedSize: 25, + SlicingArguments: []string{"first", "last"}, + }, + }, + Types: map[string]int{ + "Item": 2, + }, + }, + }, + customConfig, + ), + }, + fields: fieldConfig, + expectedResponse: `{"data":{"items":[]}}`, + expectedStaticCost: 75, // 25 * (2 + 1) + }, + computeStaticCost(), + )) }) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index a446e7c77f..60f914659c 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -252,7 +252,6 @@ func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostC if listSize != nil { multiplier := listSize.multiplier(arguments, vars) if maxListSize == nil || multiplier > maxMultiplier { - fmt.Printf("found better multiplier for %v: %v\n", coord, multiplier) maxMultiplier = multiplier maxListSize = listSize } @@ -490,33 +489,27 @@ func (c *CostCalculator) DebugPrint() string { sb.WriteString("Cost Tree Debug:\n") sb.WriteString("================\n") c.tree.children[0].debugPrint(&sb, c.costConfigs, c.variables, 0) - fmt.Fprintf(&sb, "\nTotal Cost: %d\n", c.GetTotalCost()) return sb.String() } // debugPrint recursively prints a node and its children with indentation. func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, depth int) { + // implementation is a bit crude and redundant, we could skip calculating nodes all over again. + // but it should suffice for debugging tests. if node == nil { return } indent := strings.Repeat(" ", depth) - // Calculate costs for this node - node.setCostsAndMultiplier(configs, variables) - - // Field coordinate info fieldInfo := fmt.Sprintf("%s.%s", node.fieldCoord.TypeName, node.fieldCoord.FieldName) - // Build node info line - fmt.Fprintf(sb, "%s├ %s", indent, fieldInfo) + fmt.Fprintf(sb, "%s* %s", indent, fieldInfo) - // Add type info if node.fieldTypeName != "" { fmt.Fprintf(sb, " -> %s", node.fieldTypeName) } - // Add flags var flags []string if node.returnsListType { flags = append(flags, "list") @@ -532,9 +525,8 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da } sb.WriteString("\n") - // Cost details if node.fieldCost != 0 || node.argumentsCost != 0 || node.multiplier != 0 { - fmt.Fprintf(sb, "%s│ fieldCost=%d, argsCost=%d, multiplier=%d", + fmt.Fprintf(sb, "%s fieldCost=%d, argsCost=%d, multiplier=%d", indent, node.fieldCost, node.argumentsCost, node.multiplier) // Show data sources @@ -544,7 +536,6 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da sb.WriteString("\n") } - // Arguments info if len(node.arguments) > 0 { var argStrs []string for name, arg := range node.arguments { @@ -556,16 +547,16 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da argStrs = append(argStrs, fmt.Sprintf("%s=", name)) } } - fmt.Fprintf(sb, "%s│ args: {%s}\n", indent, strings.Join(argStrs, ", ")) + fmt.Fprintf(sb, "%s args: {%s}\n", indent, strings.Join(argStrs, ", ")) } // Implementing types (for abstract types) if len(node.implementingTypeNames) > 0 { - fmt.Fprintf(sb, "%s│ implements: [%s]\n", indent, strings.Join(node.implementingTypeNames, ", ")) + fmt.Fprintf(sb, "%s implements: [%s]\n", indent, strings.Join(node.implementingTypeNames, ", ")) } subtreeCost := node.totalCost(configs, variables) - fmt.Fprintf(sb, "%s│ subCost=%d\n", indent, subtreeCost) + fmt.Fprintf(sb, "%s subCost=%d\n", indent, subtreeCost) // Print children for _, child := range node.children { From 91473638a2855adc183893a29cdb0ae5f115c8cb Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 20 Jan 2026 17:29:15 +0200 Subject: [PATCH 30/43] prettify debug print --- v2/pkg/engine/plan/static_cost.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 60f914659c..42ba2e757f 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -486,8 +486,8 @@ func (c *CostCalculator) DebugPrint() string { return "" } var sb strings.Builder - sb.WriteString("Cost Tree Debug:\n") - sb.WriteString("================\n") + sb.WriteString("Cost Tree Debug\n") + sb.WriteString("===============\n") c.tree.children[0].debugPrint(&sb, c.costConfigs, c.variables, 0) return sb.String() } @@ -500,14 +500,14 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da return } - indent := strings.Repeat(" ", depth) + indent := strings.Repeat(" ", depth) fieldInfo := fmt.Sprintf("%s.%s", node.fieldCoord.TypeName, node.fieldCoord.FieldName) fmt.Fprintf(sb, "%s* %s", indent, fieldInfo) if node.fieldTypeName != "" { - fmt.Fprintf(sb, " -> %s", node.fieldTypeName) + fmt.Fprintf(sb, " : %s", node.fieldTypeName) } var flags []string @@ -556,7 +556,7 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da } subtreeCost := node.totalCost(configs, variables) - fmt.Fprintf(sb, "%s subCost=%d\n", indent, subtreeCost) + fmt.Fprintf(sb, "%s cost=%d\n", indent, subtreeCost) // Print children for _, child := range node.children { From 8b47cbb1d86420b6666f45b6178dd86aff9a48b8 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 20 Jan 2026 18:05:01 +0200 Subject: [PATCH 31/43] test costs in federation --- execution/engine/execution_engine_test.go | 148 +++++++++++++++++----- 1 file changed, 117 insertions(+), 31 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index a7e6468651..466ee06960 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -339,7 +339,7 @@ func TestExecutionEngine_Execute(t *testing.T) { require.NotNil(t, lastPlan) costCalc := lastPlan.GetStaticCostCalculator() gotCost := costCalc.GetTotalCost() - fmt.Println(costCalc.DebugPrint()) + // fmt.Println(costCalc.DebugPrint()) require.Equal(t, testCase.expectedStaticCost, gotCost) } @@ -4656,16 +4656,49 @@ func TestExecutionEngine_Execute(t *testing.T) { } ` - makeDataSource := func(t *testing.T, expectFetchReasons bool) []plan.DataSource { + type makeDataSourceOpts struct { + expectFetchReasons bool + includeCostConfig bool + } + + makeDataSource := func(t *testing.T, opts makeDataSourceOpts) []plan.DataSource { var expectedBody1 string var expectedBody2 string - if !expectFetchReasons { + if !opts.expectFetchReasons { expectedBody1 = `{"query":"{accounts {__typename ... on User {some {__typename id}} ... on Admin {some {__typename id}}}}"}` } else { expectedBody1 = `{"query":"{accounts {__typename ... on User {some {__typename id}} ... on Admin {some {__typename id}}}}","extensions":{"fetch_reasons":[{"typename":"Admin","field":"some","by_user":true},{"typename":"User","field":"id","by_subgraphs":["id-2"],"by_user":true,"is_key":true}]}}` } expectedBody2 = `{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {__typename title}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"3"}]}}` + // Cost config for DS1 (first subgraph): accounts service + var ds1CostConfig *plan.DataSourceCostConfig + if opts.includeCostConfig { + ds1CostConfig = &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "Query", FieldName: "accounts"}: {HasWeight: true, Weight: 5}, + {TypeName: "User", FieldName: "some"}: {HasWeight: true, Weight: 2}, + {TypeName: "Admin", FieldName: "some"}: {HasWeight: true, Weight: 3}, + }, + ListSizes: map[plan.FieldCoordinate]*plan.FieldListSize{ + {TypeName: "Query", FieldName: "accounts"}: {AssumedSize: 3}, + }, + } + } + + // Cost config for DS2 (second subgraph): extends User/Admin with title + var ds2CostConfig *plan.DataSourceCostConfig + if opts.includeCostConfig { + ds2CostConfig = &plan.DataSourceCostConfig{ + Weights: map[plan.FieldCoordinate]*plan.FieldWeight{ + {TypeName: "User", FieldName: "name"}: {HasWeight: true, Weight: 2}, + {TypeName: "User", FieldName: "title"}: {HasWeight: true, Weight: 4}, + {TypeName: "Admin", FieldName: "adminName"}: {HasWeight: true, Weight: 3}, + {TypeName: "Admin", FieldName: "title"}: {HasWeight: true, Weight: 5}, + }, + } + } + return []plan.DataSource{ mustGraphqlDataSourceConfiguration(t, "id-1", @@ -4701,6 +4734,7 @@ func TestExecutionEngine_Execute(t *testing.T) { FieldNames: []string{"id", "title", "some"}, }, }, + CostConfig: ds1CostConfig, FederationMetaData: plan.FederationMetaData{ Keys: plan.FederationFieldConfigurations{ { @@ -4751,6 +4785,7 @@ func TestExecutionEngine_Execute(t *testing.T) { FieldNames: []string{"id", "adminName", "title"}, }, }, + CostConfig: ds2CostConfig, FederationMetaData: plan.FederationMetaData{ Keys: plan.FederationFieldConfigurations{ { @@ -4793,24 +4828,24 @@ func TestExecutionEngine_Execute(t *testing.T) { return graphql.Request{ OperationName: "Accounts", Query: ` - query Accounts { - accounts { - ... on User { - some { - title + query Accounts { + accounts { + ... on User { + some { + title + } } - } - ... on Admin { - some { - __typename - id + ... on Admin { + some { + __typename + id + } } } - } - }`, + }`, } }, - dataSources: makeDataSource(t, false), + dataSources: makeDataSource(t, makeDataSourceOpts{expectFetchReasons: false}), expectedResponse: `{"data":{"accounts":[{"some":{"title":"User1"}},{"some":{"__typename":"User","id":"2"}},{"some":{"title":"User3"}}]}}`, })) @@ -4826,28 +4861,79 @@ func TestExecutionEngine_Execute(t *testing.T) { return graphql.Request{ OperationName: "Accounts", Query: ` - query Accounts { - accounts { - ... on User { - some { - title - } - } - ... on Admin { - some { - __typename - id + query Accounts { + accounts { + ... on User { + some { + title + } + } + ... on Admin { + some { + __typename + id + } + } } - } - } - }`, + }`, } }, - dataSources: makeDataSource(t, true), + dataSources: makeDataSource(t, makeDataSourceOpts{expectFetchReasons: true}), expectedResponse: `{"data":{"accounts":[{"some":{"title":"User1"}},{"some":{"__typename":"User","id":"2"}},{"some":{"title":"User3"}}]}}`, }, withFetchReasons(), )) + + t.Run("run with static cost computation", runWithoutError( + ExecutionEngineTestCase{ + schema: func(t *testing.T) *graphql.Schema { + t.Helper() + parseSchema, err := graphql.NewSchemaFromString(definition) + require.NoError(t, err) + return parseSchema + }(t), + operation: func(t *testing.T) graphql.Request { + return graphql.Request{ + OperationName: "Accounts", + Query: ` + query Accounts { + accounts { + ... on User { + some { + title + } + } + ... on Admin { + some { + __typename + id + } + } + } + }`, + } + }, + dataSources: makeDataSource(t, makeDataSourceOpts{includeCostConfig: true}), + expectedResponse: `{"data":{"accounts":[{"some":{"title":"User1"}},{"some":{"__typename":"User","id":"2"}},{"some":{"title":"User3"}}]}}`, + // Cost breakdown with federation: + // Query.accounts: fieldCost=5, multiplier=3 (listSize) + // accounts returns interface [Node!]! with implementing types [User, Admin] + // + // Children (per interface member type): + // User.some: User: fieldCost=3 (DS1:2 + DS2:1 summed) + // User.title: 4 (DS2, resolved via _entities federation) + // cost = 3 + 4 = 7 + // + // Admin.some: User: fieldCost=3 (DS1 only) + // cost = 3 + // + // Children total = 7 + 3 = 10 + // (is it possible to improve accuracy here by using the largest fragment instead of the sum?) + // Total = (5 + 10) * 3 = 45 + expectedStaticCost: 45, + }, + computeStaticCost(), + )) }) t.Run("validation of optional @requires dependencies", func(t *testing.T) { From 91da8275e01d2cda2674b78d2593cf4ff8950a89 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:40:50 +0200 Subject: [PATCH 32/43] add comment for a test --- execution/engine/execution_engine_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 466ee06960..bf1ff540e9 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -5945,6 +5945,11 @@ func TestExecutionEngine_Execute(t *testing.T) { }, expectedResponse: `{"data":{"hero":{"friends":[{"name":"Luke Skywalker","height":"12"},{"name":"R2DO","primaryFunction":"joke"}]}}}`, expectedStaticCost: 247, // Query.hero(max(7,5))+ 20 * (7+2+2+1) + // We pick maximum on every path independently. This is to reveal the upper boundary. + // Query.hero: picked maximum weight (Human=7) out of two types (Human, Droid) + // Query.hero.friends: the max possible weight (7) is for implementing class Human + // of the returned type of Character; the multiplier picked for the Droid since + // it is the maximum possible value - we considered the enclosing type that contains it. }, computeStaticCost(), )) From a1716120b50509628df88ad0331f78b1ceb2b556 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Thu, 22 Jan 2026 12:54:38 +0200 Subject: [PATCH 33/43] use static instead of total cost --- execution/engine/execution_engine_test.go | 2 +- v2/pkg/engine/plan/static_cost.go | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index bf1ff540e9..c4db96d922 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -338,7 +338,7 @@ func TestExecutionEngine_Execute(t *testing.T) { lastPlan := engine.lastPlan require.NotNil(t, lastPlan) costCalc := lastPlan.GetStaticCostCalculator() - gotCost := costCalc.GetTotalCost() + gotCost := costCalc.GetStaticCost() // fmt.Println(costCalc.DebugPrint()) require.Equal(t, testCase.expectedStaticCost, gotCost) } diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 42ba2e757f..7b3a8d6a02 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -260,8 +260,8 @@ func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostC return maxListSize } -// totalCost calculates the total cost of this node and all descendants -func (node *CostTreeNode) totalCost(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value) int { +// staticCost calculates the static cost of this node and all descendants +func (node *CostTreeNode) staticCost(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value) int { if node == nil { return 0 } @@ -271,7 +271,7 @@ func (node *CostTreeNode) totalCost(configs map[DSHash]*DataSourceCostConfig, va // Sum children (fields) costs var childrenCost int for _, child := range node.children { - childrenCost += child.totalCost(configs, variables) + childrenCost += child.staticCost(configs, variables) } // Apply multiplier to children cost (for list fields) @@ -474,9 +474,9 @@ func (c *CostCalculator) SetVariables(variables *astjson.Value) { c.variables = variables } -// GetTotalCost returns the calculated total cost. -func (c *CostCalculator) GetTotalCost() int { - return c.tree.totalCost(c.costConfigs, c.variables) +// GetStaticCost returns the calculated total static cost. +func (c *CostCalculator) GetStaticCost() int { + return c.tree.staticCost(c.costConfigs, c.variables) } // DebugPrint prints the cost tree structure for debugging purposes. @@ -555,7 +555,7 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da fmt.Fprintf(sb, "%s implements: [%s]\n", indent, strings.Join(node.implementingTypeNames, ", ")) } - subtreeCost := node.totalCost(configs, variables) + subtreeCost := node.staticCost(configs, variables) fmt.Fprintf(sb, "%s cost=%d\n", indent, subtreeCost) // Print children From da65037b6c4422a13e1860c92d3c8e5f855e968b Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Fri, 23 Jan 2026 12:32:16 +0200 Subject: [PATCH 34/43] use default cost config when nothing supplied --- v2/pkg/engine/plan/static_cost.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 7b3a8d6a02..5a43658c58 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -322,8 +322,9 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo for _, dsHash := range node.dataSourceHashes { dsCostConfig, ok := configs[dsHash] if !ok { - fmt.Printf("WARNING: no cost dsCostConfig for data source %v\n", dsHash) - continue + dsCostConfig = &DataSourceCostConfig{} + // Save it for later use by other fields: + configs[dsHash] = dsCostConfig } fieldWeight := dsCostConfig.Weights[node.fieldCoord] From f9b95044f173e5a291f6c99571b84a7993072915 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Fri, 23 Jan 2026 15:36:43 +0200 Subject: [PATCH 35/43] do not use globals to configure default weights --- execution/engine/execution_engine_test.go | 1 + v2/pkg/engine/plan/configuration.go | 1 + v2/pkg/engine/plan/planner.go | 2 +- v2/pkg/engine/plan/static_cost.go | 64 +++++++++++------------ 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index c4db96d922..ef7db37259 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -288,6 +288,7 @@ func TestExecutionEngine_Execute(t *testing.T) { engineConf.plannerConfig.BuildFetchReasons = opts.propagateFetchReasons engineConf.plannerConfig.ValidateRequiredExternalFields = opts.validateRequiredExternalFields engineConf.plannerConfig.ComputeStaticCost = opts.computeStaticCost + engineConf.plannerConfig.StaticCostDefaultListSize = 10 resolveOpts := resolve.ResolverOptions{ MaxConcurrency: 1024, ResolvableOptions: opts.resolvableOptions, diff --git a/v2/pkg/engine/plan/configuration.go b/v2/pkg/engine/plan/configuration.go index ead65c0596..5021cd215d 100644 --- a/v2/pkg/engine/plan/configuration.go +++ b/v2/pkg/engine/plan/configuration.go @@ -48,6 +48,7 @@ type Configuration struct { ValidateRequiredExternalFields bool ComputeStaticCost bool + StaticCostDefaultListSize int } type DebugConfiguration struct { diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index 426b0e5374..fc737a9751 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -208,7 +208,7 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string if p.config.ComputeStaticCost { // Initialize cost calculator and configure from data sources - costCalc := NewCostCalculator() + costCalc := NewCostCalculator(p.config.StaticCostDefaultListSize) for _, ds := range p.config.DataSources { if costConfig := ds.GetCostConfig(); costConfig != nil { costCalc.SetDataSourceCostConfig(ds.Hash(), costConfig) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 5a43658c58..5f426198bf 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -33,19 +33,11 @@ import ( "github.com/wundergraph/astjson" ) -// StaticCostDefaults contains default cost values when no specific costs are configured -var StaticCostDefaults = WeightDefaults{ - EnumScalar: 0, - Object: 1, - List: 10, // The assumed maximum size of a list for fields that return lists. -} +// We don't allow configuring default weights for enums, scalars and objects. +// But they could be in the future. -// WeightDefaults defines default cost values for different GraphQL elements -type WeightDefaults struct { - EnumScalar int - Object int - List int -} +const DefaultEnumScalarWeight = 0 +const DefaultObjectWeight = 1 // FieldWeight defines cost configuration for a specific field of an object or input object. type FieldWeight struct { @@ -84,9 +76,10 @@ type FieldListSize struct { // multiplier returns the multiplier based on arguments and variables. // It picks the maximum value among slicing arguments, otherwise it tries to use AssumedSize. +// If neither is available, it falls back to defaultListSize. // // Does not take into account the SizedFields; TBD later. -func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo, vars *astjson.Value) int { +func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo, vars *astjson.Value, defaultListSize int) int { multiplier := -1 for _, slicingArg := range ls.SlicingArguments { arg, ok := arguments[slicingArg] @@ -117,7 +110,7 @@ func (ls *FieldListSize) multiplier(arguments map[string]ArgumentInfo, vars *ast multiplier = ls.AssumedSize } if multiplier == -1 { - multiplier = StaticCostDefaults.List + multiplier = defaultListSize } return multiplier } @@ -164,18 +157,18 @@ func (c *DataSourceCostConfig) EnumScalarTypeWeight(enumName string) int { if cost, ok := c.Types[enumName]; ok { return cost } - return StaticCostDefaults.EnumScalar + return DefaultEnumScalarWeight } // ObjectTypeWeight returns the default object cost func (c *DataSourceCostConfig) ObjectTypeWeight(name string) int { if c == nil { - return StaticCostDefaults.Object + return DefaultObjectWeight } if cost, ok := c.Types[name]; ok { return cost } - return StaticCostDefaults.Object + return DefaultObjectWeight } // CostTreeNode represents a node in the cost calculation tree @@ -205,7 +198,7 @@ type CostTreeNode struct { // The data below is stored for deferred cost calculation. // We populate these fields in EnterField and use them as a source of truth in LeaveField. - // fieldRef is the AST field reference + // fieldRef is the AST field reference. Used by the visitor to build the tree. fieldRef int // Enclosing type name and field name @@ -242,7 +235,7 @@ func (node *CostTreeNode) maxWeightImplementingField(config *DataSourceCostConfi return maxWeight } -func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostConfig, fieldName string, arguments map[string]ArgumentInfo, vars *astjson.Value) *FieldListSize { +func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostConfig, fieldName string, arguments map[string]ArgumentInfo, vars *astjson.Value, defaultListSize int) *FieldListSize { var maxMultiplier int var maxListSize *FieldListSize for _, implTypeName := range node.implementingTypeNames { @@ -250,7 +243,7 @@ func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostC listSize := config.ListSizes[coord] if listSize != nil { - multiplier := listSize.multiplier(arguments, vars) + multiplier := listSize.multiplier(arguments, vars, defaultListSize) if maxListSize == nil || multiplier > maxMultiplier { maxMultiplier = multiplier maxListSize = listSize @@ -261,17 +254,17 @@ func (node *CostTreeNode) maxMultiplierImplementingField(config *DataSourceCostC } // staticCost calculates the static cost of this node and all descendants -func (node *CostTreeNode) staticCost(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value) int { +func (node *CostTreeNode) staticCost(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, defaultListSize int) int { if node == nil { return 0 } - node.setCostsAndMultiplier(configs, variables) + node.setCostsAndMultiplier(configs, variables, defaultListSize) // Sum children (fields) costs var childrenCost int for _, child := range node.children { - childrenCost += child.staticCost(configs, variables) + childrenCost += child.staticCost(configs, variables, defaultListSize) } // Apply multiplier to children cost (for list fields) @@ -308,7 +301,7 @@ func (node *CostTreeNode) staticCost(configs map[DSHash]*DataSourceCostConfig, v // // For the multiplier we pick the maximum field weight of implementing types and then // the maximum among slicing arguments. -func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value) { +func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, defaultListSize int) { if len(node.dataSourceHashes) <= 0 { // no data source is responsible for this field return @@ -345,7 +338,7 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo fieldWeight = parent.maxWeightImplementingField(dsCostConfig, node.fieldCoord.FieldName) // If this field has listSize defined, then do not look into implementing types. if listSize == nil && node.returnsListType { - listSize = parent.maxMultiplierImplementingField(dsCostConfig, node.fieldCoord.FieldName, node.arguments, variables) + listSize = parent.maxMultiplierImplementingField(dsCostConfig, node.fieldCoord.FieldName, node.arguments, variables, defaultListSize) } } @@ -399,7 +392,7 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo // Compute multiplier as the maximum of data sources. if listSize != nil { - multiplier := listSize.multiplier(node.arguments, variables) + multiplier := listSize.multiplier(node.arguments, variables, defaultListSize) // If this node returns a list of abstract types, then it could have listSize defined. // Spec allows defining listSize on the fields of interfaces. if multiplier > node.multiplier { @@ -410,7 +403,7 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo } if node.multiplier == 0 && node.returnsListType { - node.multiplier = StaticCostDefaults.List + node.multiplier = defaultListSize } } @@ -456,12 +449,15 @@ type CostCalculator struct { // variables are passed by the resolver's context. variables *astjson.Value + + defaultListSize int } // NewCostCalculator creates a new cost calculator -func NewCostCalculator() *CostCalculator { +func NewCostCalculator(defaultListSize int) *CostCalculator { c := CostCalculator{ - costConfigs: make(map[DSHash]*DataSourceCostConfig), + costConfigs: make(map[DSHash]*DataSourceCostConfig), + defaultListSize: defaultListSize, } return &c } @@ -477,7 +473,7 @@ func (c *CostCalculator) SetVariables(variables *astjson.Value) { // GetStaticCost returns the calculated total static cost. func (c *CostCalculator) GetStaticCost() int { - return c.tree.staticCost(c.costConfigs, c.variables) + return c.tree.staticCost(c.costConfigs, c.variables, c.defaultListSize) } // DebugPrint prints the cost tree structure for debugging purposes. @@ -489,12 +485,12 @@ func (c *CostCalculator) DebugPrint() string { var sb strings.Builder sb.WriteString("Cost Tree Debug\n") sb.WriteString("===============\n") - c.tree.children[0].debugPrint(&sb, c.costConfigs, c.variables, 0) + c.tree.children[0].debugPrint(&sb, c.costConfigs, c.variables, c.defaultListSize, 0) return sb.String() } // debugPrint recursively prints a node and its children with indentation. -func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, depth int) { +func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, defaultListSize int, depth int) { // implementation is a bit crude and redundant, we could skip calculating nodes all over again. // but it should suffice for debugging tests. if node == nil { @@ -556,11 +552,11 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da fmt.Fprintf(sb, "%s implements: [%s]\n", indent, strings.Join(node.implementingTypeNames, ", ")) } - subtreeCost := node.staticCost(configs, variables) + subtreeCost := node.staticCost(configs, variables, defaultListSize) fmt.Fprintf(sb, "%s cost=%d\n", indent, subtreeCost) // Print children for _, child := range node.children { - child.debugPrint(sb, configs, variables, depth+1) + child.debugPrint(sb, configs, variables, defaultListSize, depth+1) } } From 912d0da7fa627949c6836e786d7ffd8408be3ce6 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Fri, 23 Jan 2026 15:42:53 +0200 Subject: [PATCH 36/43] make lint --- v2/pkg/engine/plan/configuration.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/v2/pkg/engine/plan/configuration.go b/v2/pkg/engine/plan/configuration.go index 5021cd215d..cff8109235 100644 --- a/v2/pkg/engine/plan/configuration.go +++ b/v2/pkg/engine/plan/configuration.go @@ -47,7 +47,10 @@ type Configuration struct { // This option requires BuildFetchReasons set to true. ValidateRequiredExternalFields bool + // ComputeStaticCost enables static cost computation for operations. ComputeStaticCost bool + + // When the list size is unknown from directives, this value is used as a default for static cost. StaticCostDefaultListSize int } From 9d91d7042423d533469448d2e55de143ffecbb19 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Fri, 23 Jan 2026 16:14:25 +0200 Subject: [PATCH 37/43] floor the defaultListSize --- v2/pkg/engine/plan/static_cost.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 5f426198bf..6f203a7877 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -453,8 +453,11 @@ type CostCalculator struct { defaultListSize int } -// NewCostCalculator creates a new cost calculator +// NewCostCalculator creates a new cost calculator. The defaultListSize is floored to 1. func NewCostCalculator(defaultListSize int) *CostCalculator { + if defaultListSize < 1 { + defaultListSize = 1 + } c := CostCalculator{ costConfigs: make(map[DSHash]*DataSourceCostConfig), defaultListSize: defaultListSize, From f6504e4cbf30d3232776f27e343fc85b735a1f2d Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Fri, 23 Jan 2026 16:48:07 +0200 Subject: [PATCH 38/43] protect from nil --- v2/pkg/engine/plan/static_cost.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 6f203a7877..0371623053 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -314,7 +314,7 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo for _, dsHash := range node.dataSourceHashes { dsCostConfig, ok := configs[dsHash] - if !ok { + if !ok || dsCostConfig == nil { dsCostConfig = &DataSourceCostConfig{} // Save it for later use by other fields: configs[dsHash] = dsCostConfig From 5591872d9431362345ee94a1d5413cfcca675c32 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:20:38 +0200 Subject: [PATCH 39/43] fix the data-race for calculator for multiple requests --- execution/engine/execution_engine.go | 4 +- execution/engine/execution_engine_test.go | 4 +- execution/graphql/request.go | 12 ++ v2/pkg/engine/plan/planner.go | 10 +- v2/pkg/engine/plan/static_cost.go | 135 +++++++++++----------- v2/pkg/engine/plan/static_cost_visitor.go | 2 - 6 files changed, 83 insertions(+), 84 deletions(-) diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index c4a2695128..19d487d81a 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -215,7 +215,9 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques } e.lastPlan = cachedPlan if costCalculator != nil { - costCalculator.SetVariables(execContext.resolveContext.Variables) + operation.ComputerStaticCost(costCalculator, e.config.plannerConfig, execContext.resolveContext.Variables) + // Debugging of cost trees. Do not remove. + // fmt.Println(costCalculator.DebugPrint(e.config.plannerConfig, execContext.resolveContext.Variables)) } if execContext.resolveContext.TracingOptions.Enable && !execContext.resolveContext.TracingOptions.ExcludePlannerStats { diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index ef7db37259..3356580360 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -338,9 +338,7 @@ func TestExecutionEngine_Execute(t *testing.T) { if testCase.expectedStaticCost != 0 { lastPlan := engine.lastPlan require.NotNil(t, lastPlan) - costCalc := lastPlan.GetStaticCostCalculator() - gotCost := costCalc.GetStaticCost() - // fmt.Println(costCalc.DebugPrint()) + gotCost := operation.StaticCost() require.Equal(t, testCase.expectedStaticCost, gotCost) } diff --git a/execution/graphql/request.go b/execution/graphql/request.go index a3ab0888d0..585350e285 100644 --- a/execution/graphql/request.go +++ b/execution/graphql/request.go @@ -6,8 +6,10 @@ import ( "io" "net/http" + "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) @@ -42,6 +44,8 @@ type Request struct { request resolve.Request validForSchema map[uint64]ValidationResult + + staticCost int } func UnmarshalRequest(reader io.Reader, request *Request) error { @@ -189,3 +193,11 @@ func (r *Request) OperationType() (OperationType, error) { return OperationTypeUnknown, nil } + +func (r *Request) ComputerStaticCost(calc *plan.CostCalculator, config plan.Configuration, variables *astjson.Value) { + r.staticCost = calc.GetStaticCost(config, variables) +} + +func (r *Request) StaticCost() int { + return r.staticCost +} diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index fc737a9751..f25c3e02c3 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -207,15 +207,7 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string } if p.config.ComputeStaticCost { - // Initialize cost calculator and configure from data sources - costCalc := NewCostCalculator(p.config.StaticCostDefaultListSize) - for _, ds := range p.config.DataSources { - if costConfig := ds.GetCostConfig(); costConfig != nil { - costCalc.SetDataSourceCostConfig(ds.Hash(), costConfig) - } - } - // The root tree pointing to the costTreeNode is the ultimate result of costVisitor. - // Store is as part of this plan for later, should be part of the cached plan too. + costCalc := NewCostCalculator() costCalc.tree = p.costVisitor.finalCostTree() p.planningVisitor.plan.SetStaticCostCalculator(costCalc) } diff --git a/v2/pkg/engine/plan/static_cost.go b/v2/pkg/engine/plan/static_cost.go index 0371623053..805deaa8f7 100644 --- a/v2/pkg/engine/plan/static_cost.go +++ b/v2/pkg/engine/plan/static_cost.go @@ -179,19 +179,6 @@ type CostTreeNode struct { // dataSourceHashes identifies which data sources resolve this field. dataSourceHashes []DSHash - // fieldCost is the weight of this field or its returned type - fieldCost int - - // argumentsCost is the sum of argument weights and input fields used on this field. - argumentsCost int - - // Weights on directives ignored for now. - directivesCost int - - // multiplier is the list size multiplier from @listSize directive - // Applied to children costs for list fields - multiplier int - // children contain child field costs children []*CostTreeNode @@ -259,7 +246,7 @@ func (node *CostTreeNode) staticCost(configs map[DSHash]*DataSourceCostConfig, v return 0 } - node.setCostsAndMultiplier(configs, variables, defaultListSize) + fieldCost, argsCost, directivesCost, multiplier := node.costsAndMultiplier(configs, variables, defaultListSize) // Sum children (fields) costs var childrenCost int @@ -268,11 +255,10 @@ func (node *CostTreeNode) staticCost(configs map[DSHash]*DataSourceCostConfig, v } // Apply multiplier to children cost (for list fields) - multiplier := node.multiplier if multiplier == 0 { multiplier = 1 } - cost := node.argumentsCost + node.directivesCost + cost := argsCost + directivesCost if cost < 0 { // If arguments and directive weights decrease the field cost, floor it to zero. cost = 0 @@ -288,29 +274,33 @@ func (node *CostTreeNode) staticCost(configs map[DSHash]*DataSourceCostConfig, v // "A: [Obj] @cost(weight: 5)" means that the cost of the field is 5 for each object in the list. // "type Object @cost(weight: 5) { ... }" does exactly the same thing. // Weight defined on a field has priority over the weight defined on a type. - cost += (node.fieldCost + childrenCost) * multiplier + cost += (fieldCost + childrenCost) * multiplier return cost } -// setCostsAndMultiplier fills in the cost values for a node based on its data sources. +// costsAndMultiplier fills in the cost values for a node based on its data sources. // // For this node we sum weights of the field or its returned type for all the data sources. // Each data source can have its own cost configuration. If we plan field on two data sources, // it means more work for the router: we should sum the costs. // +// fieldCost is the weight of this field or its returned type +// argsCost is the sum of argument weights and input fields used on this field. +// Weights on directives ignored for now. // For the multiplier we pick the maximum field weight of implementing types and then // the maximum among slicing arguments. -func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, defaultListSize int) { +func (node *CostTreeNode) costsAndMultiplier(configs map[DSHash]*DataSourceCostConfig, variables *astjson.Value, defaultListSize int) (fieldCost, argsCost, directiveCost, multiplier int) { if len(node.dataSourceHashes) <= 0 { // no data source is responsible for this field return } parent := node.parent - node.fieldCost = 0 - node.argumentsCost = 0 - node.multiplier = 0 + fieldCost = 0 + argsCost = 0 + directiveCost = 0 + multiplier = 0 for _, dsHash := range node.dataSourceHashes { dsCostConfig, ok := configs[dsHash] @@ -343,12 +333,12 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo } if fieldWeight != nil && fieldWeight.HasWeight { - node.fieldCost += fieldWeight.Weight + fieldCost += fieldWeight.Weight } else { // Use the weight of the type returned by this field switch { case node.returnsSimpleType: - node.fieldCost += dsCostConfig.EnumScalarTypeWeight(node.fieldTypeName) + fieldCost += dsCostConfig.EnumScalarTypeWeight(node.fieldTypeName) case node.returnsAbstractType: // For the abstract field, find the max weight among all implementing types maxWeight := 0 @@ -358,16 +348,16 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo maxWeight = weight } } - node.fieldCost += maxWeight + fieldCost += maxWeight default: - node.fieldCost += dsCostConfig.ObjectTypeWeight(node.fieldTypeName) + fieldCost += dsCostConfig.ObjectTypeWeight(node.fieldTypeName) } } for argName, arg := range node.arguments { if fieldWeight != nil { if weight, ok := fieldWeight.ArgumentWeights[argName]; ok { - node.argumentsCost += weight + argsCost += weight continue } } @@ -375,11 +365,11 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo // If the argument definition itself does not have weight attached, // but the type of the argument does have weight attached to it. if arg.isSimple { - node.argumentsCost += dsCostConfig.EnumScalarTypeWeight(arg.typeName) + argsCost += dsCostConfig.EnumScalarTypeWeight(arg.typeName) } else if arg.isInputObject { // TODO: arguments should include costs of input object fields } else { - node.argumentsCost += dsCostConfig.ObjectTypeWeight(arg.typeName) + argsCost += dsCostConfig.ObjectTypeWeight(arg.typeName) } } @@ -392,19 +382,20 @@ func (node *CostTreeNode) setCostsAndMultiplier(configs map[DSHash]*DataSourceCo // Compute multiplier as the maximum of data sources. if listSize != nil { - multiplier := listSize.multiplier(node.arguments, variables, defaultListSize) + localMultiplier := listSize.multiplier(node.arguments, variables, defaultListSize) // If this node returns a list of abstract types, then it could have listSize defined. // Spec allows defining listSize on the fields of interfaces. - if multiplier > node.multiplier { - node.multiplier = multiplier + if localMultiplier > multiplier { + multiplier = localMultiplier } } } - if node.multiplier == 0 && node.returnsListType { - node.multiplier = defaultListSize + if multiplier == 0 && node.returnsListType { + multiplier = defaultListSize } + return } type ArgumentInfo struct { @@ -441,54 +432,57 @@ type ArgumentInfo struct { // CostCalculator manages cost calculation during AST traversal type CostCalculator struct { - // tree points to the root of the complete cost tree. + // tree points to the root of the complete cost tree. Calculator tree is built once per query, + // then it is cached as part of the plan cache and + // not supposed to change again even during lifetime of a process. tree *CostTreeNode - - // costConfigs maps data source hash to its cost configuration - costConfigs map[DSHash]*DataSourceCostConfig - - // variables are passed by the resolver's context. - variables *astjson.Value - - defaultListSize int } // NewCostCalculator creates a new cost calculator. The defaultListSize is floored to 1. -func NewCostCalculator(defaultListSize int) *CostCalculator { - if defaultListSize < 1 { - defaultListSize = 1 - } - c := CostCalculator{ - costConfigs: make(map[DSHash]*DataSourceCostConfig), - defaultListSize: defaultListSize, - } +func NewCostCalculator() *CostCalculator { + c := CostCalculator{} return &c } -// SetDataSourceCostConfig sets the cost config for a specific data source -func (c *CostCalculator) SetDataSourceCostConfig(dsHash DSHash, config *DataSourceCostConfig) { - c.costConfigs[dsHash] = config -} - -func (c *CostCalculator) SetVariables(variables *astjson.Value) { - c.variables = variables -} - // GetStaticCost returns the calculated total static cost. -func (c *CostCalculator) GetStaticCost() int { - return c.tree.staticCost(c.costConfigs, c.variables, c.defaultListSize) +// config should be static per process or instance. variables could change between requests. +func (c *CostCalculator) GetStaticCost(config Configuration, variables *astjson.Value) int { + // costConfigs maps data source hash to its cost configuration. At the runtime we do not change + // this at all. It could be set once per router process. + costConfigs := make(map[DSHash]*DataSourceCostConfig) + for _, ds := range config.DataSources { + if costConfig := ds.GetCostConfig(); costConfig != nil { + costConfigs[ds.Hash()] = costConfig + } + } + defaultListSize := config.StaticCostDefaultListSize + if defaultListSize < 1 { + defaultListSize = 1 + } + return c.tree.staticCost(costConfigs, variables, defaultListSize) + } // DebugPrint prints the cost tree structure for debugging purposes. // It shows each node's field coordinate, costs, multipliers, and computed totals. -func (c *CostCalculator) DebugPrint() string { +func (c *CostCalculator) DebugPrint(config Configuration, variables *astjson.Value) string { if c.tree == nil || len(c.tree.children) == 0 { return "" } var sb strings.Builder sb.WriteString("Cost Tree Debug\n") sb.WriteString("===============\n") - c.tree.children[0].debugPrint(&sb, c.costConfigs, c.variables, c.defaultListSize, 0) + costConfigs := make(map[DSHash]*DataSourceCostConfig) + for _, ds := range config.DataSources { + if costConfig := ds.GetCostConfig(); costConfig != nil { + costConfigs[ds.Hash()] = costConfig + } + } + defaultListSize := config.StaticCostDefaultListSize + if defaultListSize < 1 { + defaultListSize = 1 + } + c.tree.children[0].debugPrint(&sb, costConfigs, variables, defaultListSize, 0) return sb.String() } @@ -525,9 +519,11 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da } sb.WriteString("\n") - if node.fieldCost != 0 || node.argumentsCost != 0 || node.multiplier != 0 { - fmt.Fprintf(sb, "%s fieldCost=%d, argsCost=%d, multiplier=%d", - indent, node.fieldCost, node.argumentsCost, node.multiplier) + // Compute costs for this node to display in debug output + fieldCost, argsCost, dirsCost, multiplier := node.costsAndMultiplier(configs, variables, defaultListSize) + if fieldCost != 0 || argsCost != 0 || dirsCost != 0 || multiplier != 0 { + fmt.Fprintf(sb, "%s fieldCost=%d, argsCost=%d, directivesCost=%d, multiplier=%d", + indent, fieldCost, argsCost, dirsCost, multiplier) // Show data sources if len(node.dataSourceHashes) > 0 { @@ -550,15 +546,16 @@ func (node *CostTreeNode) debugPrint(sb *strings.Builder, configs map[DSHash]*Da fmt.Fprintf(sb, "%s args: {%s}\n", indent, strings.Join(argStrs, ", ")) } - // Implementing types (for abstract types) if len(node.implementingTypeNames) > 0 { fmt.Fprintf(sb, "%s implements: [%s]\n", indent, strings.Join(node.implementingTypeNames, ", ")) } + // This is somewhat redundant, but it should not be used in production. + // If there is a need to present cost tree to the user, + // printing should be embedded into the tree calculation process. subtreeCost := node.staticCost(configs, variables, defaultListSize) fmt.Fprintf(sb, "%s cost=%d\n", indent, subtreeCost) - // Print children for _, child := range node.children { child.debugPrint(sb, configs, variables, defaultListSize, depth+1) } diff --git a/v2/pkg/engine/plan/static_cost_visitor.go b/v2/pkg/engine/plan/static_cost_visitor.go index 3d8fbbb571..a566e6ae93 100644 --- a/v2/pkg/engine/plan/static_cost_visitor.go +++ b/v2/pkg/engine/plan/static_cost_visitor.go @@ -35,7 +35,6 @@ func NewStaticCostVisitor(walker *astvisitor.Walker, operation, definition *ast. stack := make([]*CostTreeNode, 0, 16) rootNode := CostTreeNode{ fieldCoord: FieldCoordinate{"_none", "_root"}, - multiplier: 1, } stack = append(stack, &rootNode) return &StaticCostVisitor{ @@ -92,7 +91,6 @@ func (v *StaticCostVisitor) EnterField(fieldRef int) { node := CostTreeNode{ fieldRef: fieldRef, fieldCoord: FieldCoordinate{typeName, fieldName}, - multiplier: 1, fieldTypeName: unwrappedTypeName, implementingTypeNames: implementingTypeNames, returnsListType: isListType, From 41087f073129b3baafa43f80a81142a2db0e7059 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:34:04 +0200 Subject: [PATCH 40/43] make lint --- execution/graphql/request.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/execution/graphql/request.go b/execution/graphql/request.go index 585350e285..d1cf90f857 100644 --- a/execution/graphql/request.go +++ b/execution/graphql/request.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" @@ -45,7 +46,7 @@ type Request struct { validForSchema map[uint64]ValidationResult - staticCost int + staticCost int } func UnmarshalRequest(reader io.Reader, request *Request) error { From 513375c7aaf2006af9966dd0c34a1c99b1af0a71 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:34:54 +0200 Subject: [PATCH 41/43] fix typo --- execution/engine/execution_engine.go | 2 +- execution/graphql/request.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index 19d487d81a..51c7dbb78c 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -215,7 +215,7 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques } e.lastPlan = cachedPlan if costCalculator != nil { - operation.ComputerStaticCost(costCalculator, e.config.plannerConfig, execContext.resolveContext.Variables) + operation.ComputeStaticCost(costCalculator, e.config.plannerConfig, execContext.resolveContext.Variables) // Debugging of cost trees. Do not remove. // fmt.Println(costCalculator.DebugPrint(e.config.plannerConfig, execContext.resolveContext.Variables)) } diff --git a/execution/graphql/request.go b/execution/graphql/request.go index d1cf90f857..e141e7673f 100644 --- a/execution/graphql/request.go +++ b/execution/graphql/request.go @@ -195,7 +195,7 @@ func (r *Request) OperationType() (OperationType, error) { return OperationTypeUnknown, nil } -func (r *Request) ComputerStaticCost(calc *plan.CostCalculator, config plan.Configuration, variables *astjson.Value) { +func (r *Request) ComputeStaticCost(calc *plan.CostCalculator, config plan.Configuration, variables *astjson.Value) { r.staticCost = calc.GetStaticCost(config, variables) } From 2db07ae63d8d5c85833ecd6fb0d2028781b2a830 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:49:18 +0200 Subject: [PATCH 42/43] simplify execution part --- execution/engine/execution_engine.go | 11 +++-------- execution/engine/execution_engine_test.go | 2 -- execution/graphql/request.go | 6 +++++- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index 51c7dbb78c..8b57b9c11d 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -62,8 +62,6 @@ type ExecutionEngine struct { resolver *resolve.Resolver executionPlanCache *lru.Cache apolloCompatibilityFlags apollocompatibility.Flags - // Holds the plan after Execute(). Used in testing. - lastPlan plan.Plan } type WebsocketBeforeStartHook interface { @@ -213,12 +211,9 @@ func (e *ExecutionEngine) Execute(ctx context.Context, operation *graphql.Reques if report.HasErrors() { return report } - e.lastPlan = cachedPlan - if costCalculator != nil { - operation.ComputeStaticCost(costCalculator, e.config.plannerConfig, execContext.resolveContext.Variables) - // Debugging of cost trees. Do not remove. - // fmt.Println(costCalculator.DebugPrint(e.config.plannerConfig, execContext.resolveContext.Variables)) - } + operation.ComputeStaticCost(costCalculator, e.config.plannerConfig, execContext.resolveContext.Variables) + // Debugging of cost trees. Do not remove. + // fmt.Println(costCalculator.DebugPrint(e.config.plannerConfig, execContext.resolveContext.Variables)) if execContext.resolveContext.TracingOptions.Enable && !execContext.resolveContext.TracingOptions.ExcludePlannerStats { planningTime := resolve.GetDurationNanoSinceTraceStart(execContext.resolveContext.Context()) - tracePlanStart diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 3356580360..9ac77cd131 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -336,8 +336,6 @@ func TestExecutionEngine_Execute(t *testing.T) { } if testCase.expectedStaticCost != 0 { - lastPlan := engine.lastPlan - require.NotNil(t, lastPlan) gotCost := operation.StaticCost() require.Equal(t, testCase.expectedStaticCost, gotCost) } diff --git a/execution/graphql/request.go b/execution/graphql/request.go index e141e7673f..85a7051d80 100644 --- a/execution/graphql/request.go +++ b/execution/graphql/request.go @@ -196,7 +196,11 @@ func (r *Request) OperationType() (OperationType, error) { } func (r *Request) ComputeStaticCost(calc *plan.CostCalculator, config plan.Configuration, variables *astjson.Value) { - r.staticCost = calc.GetStaticCost(config, variables) + if calc != nil { + r.staticCost = calc.GetStaticCost(config, variables) + } else { + r.staticCost = 0 + } } func (r *Request) StaticCost() int { From e885f66e9302e6f8605b662cf1879bee33f233c0 Mon Sep 17 00:00:00 2001 From: Yury Smolski <140245+ysmolski@users.noreply.github.com> Date: Tue, 27 Jan 2026 12:25:45 +0200 Subject: [PATCH 43/43] clarify comment --- v2/pkg/engine/plan/planner.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/plan/planner.go b/v2/pkg/engine/plan/planner.go index f25c3e02c3..73df967228 100644 --- a/v2/pkg/engine/plan/planner.go +++ b/v2/pkg/engine/plan/planner.go @@ -159,7 +159,10 @@ func (p *Planner) Plan(operation, definition *ast.Document, operationName string p.planningWalker.RegisterEnterDirectiveVisitor(p.planningVisitor) p.planningWalker.RegisterInlineFragmentVisitor(p.planningVisitor) - // Register cost visitor on the same walker (will be invoked after planningVisitor hooks) + // Register cost visitor on the same walker (will be invoked after planningVisitor hooks). + // We have to register it last in the walker, as it depends on the fieldPlanners field of the + // visitor. That field is populated in the AllowVisitor callback. Walker calls Enter* callbacks + // in the order they were registered, and Leave* callbacks in the reverse order. if p.config.ComputeStaticCost { p.costVisitor = NewStaticCostVisitor(p.planningWalker, operation, definition) p.costVisitor.planners = plannersConfigurations