Skip to content

Commit 2ca7b28

Browse files
authored
feat(router): support input objects in Cost Control (#2716)
This PR adds support of input object fields passed as arguments. It handles nested input objects, recursive types, lists of input objects and list-typed arguments. Negative weights on input fields reduce the cost, but clamped to zero for each field.
1 parent 956e27e commit 2ca7b28

File tree

11 files changed

+171
-62
lines changed

11 files changed

+171
-62
lines changed

demo/pkg/subgraphs/employees/subgraph/generated/generated.go

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

demo/pkg/subgraphs/employees/subgraph/schema.graphqls

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,6 @@ type WorkSetup @shareable {
277277

278278
input FindEmployeeCriteria @oneOf {
279279
id: Int
280-
department: Department
281-
title: String
280+
department: Department @cost(weight: 17)
281+
title: String @cost(weight: -3) # totally made-up example for testing
282282
}

docs-website/router/security/cost-control.mdx

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,25 @@ type Query {
142142
}
143143
```
144144

145+
When specified **on input object fields** — weights on nested input fields are accumulated
146+
based on which fields the client provides in the query.
147+
Only the fields present in the input contribute to the cost:
148+
149+
```graphql
150+
input FindEmployeeCriteria {
151+
id: Int
152+
department: Department @cost(weight: 4)
153+
title: String @cost(weight: 3)
154+
}
155+
156+
type Query {
157+
findEmployeesBy(criteria: FindEmployeeCriteria): [Employee]
158+
}
159+
```
160+
161+
Input object weights are evaluated per request, not cached with the query plan.
162+
Two requests using the same query but different input fields produce different cost estimates.
163+
145164
When specified **on a field returning a list** — the list size multiplies the weight of this field:
146165

147166
```graphql
@@ -422,7 +441,5 @@ Use `@listSize` to provide realistic estimates for deeply nested structures.
422441

423442
## Features Not Yet Implemented
424443

425-
- `requireOneSlicingArgument` in `@listSize` for validation that exactly one slicing argument is provided
426-
- **Weights on input object fields** — when an argument accepts an input object, the `@cost` weights on its nested fields are not yet accumulated recursively
427444
- **Weights on directive arguments** — `@cost` placed on arguments of custom directives is not accounted for
428445

router-tests/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ require (
2828
github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e
2929
github.com/wundergraph/cosmo/router v0.0.0-20260319123623-f186a0f724f6
3030
github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e
31-
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.268
31+
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.269
3232
go.opentelemetry.io/otel v1.39.0
3333
go.opentelemetry.io/otel/sdk v1.39.0
3434
go.opentelemetry.io/otel/sdk/metric v1.39.0

router-tests/go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ github.com/wundergraph/astjson v1.1.0 h1:xORDosrZ87zQFJwNGe/HIHXqzpdHOFmqWgykCLV
357357
github.com/wundergraph/astjson v1.1.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw=
358358
github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc=
359359
github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw=
360-
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.268 h1:lP9kWLiPO2U3JuwpQ/WX7nTVfKeMtVab2G3DAFblVA0=
361-
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.268/go.mod h1:HjTAO/cuICpu31IfHY9qmSPygx6Gza7Wt9hTSReTI+A=
360+
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.269 h1:BFQ4/IFqucZsrmzs6vkqjHC5j2XV6rhnmoMLmtYMcp8=
361+
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.269/go.mod h1:HjTAO/cuICpu31IfHY9qmSPygx6Gza7Wt9hTSReTI+A=
362362
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg=
363363
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
364364
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=

router-tests/security/costs_test.go

Lines changed: 112 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package integration
33
import (
44
"context"
55
"net/http"
6-
"strconv"
76
"testing"
87

98
"github.com/stretchr/testify/require"
@@ -53,7 +52,6 @@ func TestOperationCost(t *testing.T) {
5352

5453
// @listSize(assumedSize: 50) overrides EstimatedListSize; cost = 50 * 2 = 100
5554
estimated := res.Response.Header.Get(core.CostEstimatedHeader)
56-
require.NotEmpty(t, estimated, "estimated cost header should be present")
5755
require.Equal(t, "100", estimated)
5856

5957
// the actual cost should not be calculated nor exposed
@@ -153,7 +151,6 @@ func TestOperationCost(t *testing.T) {
153151

154152
// @listSize(assumedSize: 50) on employees overrides EstimatedListSize(200)
155153
estimated := res.Response.Header.Get(core.CostEstimatedHeader)
156-
require.NotEmpty(t, estimated, "estimated cost header should be present")
157154
require.Equal(t, "50", estimated)
158155
})
159156
})
@@ -277,11 +274,9 @@ func TestOperationCost(t *testing.T) {
277274
// upc, repositoryURL, id: 0 (scalars)
278275
// total: (10 + 13) × 10 = 230
279276
estimated := res.Response.Header.Get(core.CostEstimatedHeader)
280-
require.NotEmpty(t, estimated, "estimated cost header should be present")
281277
require.Equal(t, "230", estimated)
282278

283279
actual := res.Response.Header.Get(core.CostActualHeader)
284-
require.NotEmpty(t, actual, "actual cost header should be present")
285280
require.Equal(t, "45", actual)
286281

287282
// Query 2: only employees-subgraph fields — Cosmo @cost(weight: 5) from employees applies
@@ -294,11 +289,60 @@ func TestOperationCost(t *testing.T) {
294289
require.Equal(t, "150", estimated2)
295290

296291
actual2 := res2.Response.Header.Get(core.CostActualHeader)
297-
require.NotEmpty(t, actual2, "actual cost header should be present")
298292
require.Equal(t, "21", actual2)
299293
})
300294
})
301295

296+
t.Run("input object field cost weight on department", func(t *testing.T) {
297+
t.Parallel()
298+
testenv.Run(t, &testenv.Config{
299+
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
300+
securityConfiguration.CostControl = &config.CostControl{
301+
Enabled: true,
302+
Mode: config.CostControlModeMeasure,
303+
MaxEstimatedLimit: 10000,
304+
EstimatedListSize: 10,
305+
ExposeHeaders: true,
306+
}
307+
},
308+
}, func(t *testing.T, xEnv *testenv.Environment) {
309+
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
310+
Query: `{ findEmployeesBy(criteria: { department: ENGINEERING }) { id } }`,
311+
})
312+
require.Contains(t, res.Body, `"data":`)
313+
314+
// 10*1 + 17
315+
require.Equal(t, "27", res.Response.Header.Get(core.CostEstimatedHeader))
316+
// 7*1 + 17
317+
require.Equal(t, "24", res.Response.Header.Get(core.CostActualHeader))
318+
})
319+
})
320+
321+
t.Run("input object field cost weight on title", func(t *testing.T) {
322+
t.Parallel()
323+
testenv.Run(t, &testenv.Config{
324+
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
325+
securityConfiguration.CostControl = &config.CostControl{
326+
Enabled: true,
327+
Mode: config.CostControlModeMeasure,
328+
MaxEstimatedLimit: 10000,
329+
EstimatedListSize: 10,
330+
ExposeHeaders: true,
331+
}
332+
},
333+
}, func(t *testing.T, xEnv *testenv.Environment) {
334+
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
335+
Query: `{ findEmployeesBy(criteria: { title: "Founder" }) { id } }`,
336+
})
337+
require.Contains(t, res.Body, `"data":`)
338+
339+
// 10 * 1 - 3
340+
require.Equal(t, "7", res.Response.Header.Get(core.CostEstimatedHeader))
341+
// 1 * 1 - 3
342+
require.Equal(t, "0", res.Response.Header.Get(core.CostActualHeader))
343+
})
344+
})
345+
302346
t.Run("slicingArguments controls list size estimation", func(t *testing.T) {
303347
t.Parallel()
304348
testenv.Run(t, &testenv.Config{
@@ -714,27 +758,16 @@ func TestOperationCost(t *testing.T) {
714758
// 1st request – plan cache MISS
715759
res1 := xEnv.MakeGraphQLRequestOK(query)
716760
require.Contains(t, res1.Body, `"data":`)
717-
718-
estimated1 := res1.Response.Header.Get(core.CostEstimatedHeader)
719-
actual1 := res1.Response.Header.Get(core.CostActualHeader)
720-
require.NotEmpty(t, estimated1, "first request should have estimated cost header")
721-
require.NotEmpty(t, actual1, "first request should have actual cost header")
761+
require.Equal(t, "MISS", res1.Response.Header.Get(core.ExecutionPlanCacheHeader))
762+
require.Equal(t, "8", res1.Response.Header.Get(core.CostEstimatedHeader))
763+
require.Equal(t, "8", res1.Response.Header.Get(core.CostActualHeader))
722764

723765
// 2nd request – plan cache HIT
724766
res2 := xEnv.MakeGraphQLRequestOK(query)
725767
require.Contains(t, res2.Body, `"data":`)
726-
727-
estimated2 := res2.Response.Header.Get(core.CostEstimatedHeader)
728-
actual2 := res2.Response.Header.Get(core.CostActualHeader)
729-
require.NotEmpty(t, estimated2, "second request should have estimated cost header")
730-
require.NotEmpty(t, actual2, "second request should have actual cost header")
731-
732-
require.Equal(t, estimated1, estimated2,
733-
"estimated cost differs between cache miss (%s) and cache hit (%s) ",
734-
estimated1, estimated2)
735-
require.Equal(t, actual1, actual2,
736-
"actual cost differs between cache miss (%s) and cache hit (%s) ",
737-
actual1, actual2)
768+
require.Equal(t, "HIT", res2.Response.Header.Get(core.ExecutionPlanCacheHeader))
769+
require.Equal(t, "8", res2.Response.Header.Get(core.CostEstimatedHeader))
770+
require.Equal(t, "8", res2.Response.Header.Get(core.CostActualHeader))
738771
})
739772
})
740773

@@ -759,8 +792,8 @@ func TestOperationCost(t *testing.T) {
759792

760793
estimated1 := res1.Response.Header.Get(core.CostEstimatedHeader)
761794
actual1 := res1.Response.Header.Get(core.CostActualHeader)
762-
require.NotEmpty(t, estimated1, "first request should have estimated cost header")
763-
require.NotEmpty(t, actual1, "first request should have actual cost header")
795+
require.Equal(t, "8", estimated1)
796+
require.Equal(t, "8", actual1)
764797

765798
// 2nd request – plan cache HIT
766799
query2 := testenv.GraphQLRequest{
@@ -771,15 +804,8 @@ func TestOperationCost(t *testing.T) {
771804

772805
estimated2 := res2.Response.Header.Get(core.CostEstimatedHeader)
773806
actual2 := res2.Response.Header.Get(core.CostActualHeader)
774-
require.NotEmpty(t, estimated2, "second request should have estimated cost header")
775-
require.NotEmpty(t, actual2, "second request should have actual cost header")
776-
777-
require.Equal(t, estimated1, estimated2,
778-
"estimated cost differs between cache miss (%s) and cache hit (%s) ",
779-
estimated1, estimated2)
780-
require.Equal(t, actual1, actual2,
781-
"actual cost differs between cache miss (%s) and cache hit (%s) ",
782-
actual1, actual2)
807+
require.Equal(t, "8", estimated2)
808+
require.Equal(t, "8", actual2)
783809
})
784810
})
785811

@@ -885,6 +911,57 @@ func TestOperationCost(t *testing.T) {
885911
require.Equal(t, int64(24), totalSum, "total estimated cost sum should be 3×8=24")
886912
})
887913
})
914+
915+
t.Run("input object field costs are consistent across cache hits for different queries", func(t *testing.T) {
916+
t.Parallel()
917+
testenv.Run(t, &testenv.Config{
918+
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
919+
securityConfiguration.CostControl = &config.CostControl{
920+
Enabled: true,
921+
Mode: config.CostControlModeMeasure,
922+
EstimatedListSize: 10,
923+
ExposeHeaders: true,
924+
}
925+
},
926+
}, func(t *testing.T, xEnv *testenv.Environment) {
927+
// 1st request – plan cache MISS
928+
resDept1 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
929+
Query: `{ findEmployeesBy(criteria: { department: ENGINEERING }) { id } }`,
930+
})
931+
require.Contains(t, resDept1.Body, `"data":`)
932+
require.Equal(t, "MISS", resDept1.Response.Header.Get(core.ExecutionPlanCacheHeader))
933+
require.Equal(t, "27", resDept1.Response.Header.Get(core.CostEstimatedHeader))
934+
require.Equal(t, "24", resDept1.Response.Header.Get(core.CostActualHeader))
935+
936+
// 2nd request – plan cache HIT (same normalized query, different input field)
937+
// Cost is recalculated per request based on actual input field values
938+
resTitle1 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
939+
Query: `{ findEmployeesBy(criteria: { title: "Founder" }) { id } }`,
940+
})
941+
require.Contains(t, resTitle1.Body, `"data":`)
942+
require.Equal(t, "HIT", resTitle1.Response.Header.Get(core.ExecutionPlanCacheHeader))
943+
require.Equal(t, "7", resTitle1.Response.Header.Get(core.CostEstimatedHeader))
944+
require.Equal(t, "0", resTitle1.Response.Header.Get(core.CostActualHeader))
945+
946+
// 3rd request – cache HIT, same input field as 1st, different value
947+
resDept2 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
948+
Query: `{ findEmployeesBy(criteria: { department: MARKETING }) { id } }`,
949+
})
950+
require.Contains(t, resDept2.Body, `"data":`)
951+
require.Equal(t, "HIT", resDept2.Response.Header.Get(core.ExecutionPlanCacheHeader))
952+
require.Equal(t, "27", resDept2.Response.Header.Get(core.CostEstimatedHeader))
953+
require.Equal(t, "20", resDept2.Response.Header.Get(core.CostActualHeader))
954+
955+
// 4th request – cache HIT, same input field as 2nd, different value
956+
resTitle2 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
957+
Query: `{ findEmployeesBy(criteria: { title: "Director" }) { id } }`,
958+
})
959+
require.Contains(t, resTitle2.Body, `"data":`)
960+
require.Equal(t, "HIT", resTitle2.Response.Header.Get(core.ExecutionPlanCacheHeader))
961+
require.Equal(t, "7", resTitle2.Response.Header.Get(core.CostEstimatedHeader))
962+
require.Equal(t, "0", resTitle2.Response.Header.Get(core.CostActualHeader))
963+
})
964+
})
888965
})
889966

890967
t.Run("negative weights", func(t *testing.T) {
@@ -972,12 +1049,7 @@ func TestOperationCost(t *testing.T) {
9721049
})
9731050
require.Contains(t, res.Body, `"data":`)
9741051

975-
estimated := res.Response.Header.Get(core.CostEstimatedHeader)
976-
require.NotEmpty(t, estimated)
977-
estimatedVal, err := strconv.Atoi(estimated)
978-
require.NoError(t, err)
979-
require.Equal(t, estimatedVal, 8, "estimated cost must not be negative")
980-
require.Equal(t, estimatedVal, 8, "negative type weight should reduce cost below baseline of 18")
1052+
require.Equal(t, "8", res.Response.Header.Get(core.CostEstimatedHeader))
9811053
})
9821054
})
9831055

0 commit comments

Comments
 (0)