Skip to content

Commit f882817

Browse files
authored
Merge pull request kubernetes#126359 from jpbetz/quantity-estimated-cost
Fix estimated cost for Kubernetes defined CEL types for equals
2 parents e9c9a27 + e5f207a commit f882817

File tree

5 files changed

+133
-11
lines changed

5 files changed

+133
-11
lines changed

staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/celcoststability_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,14 +2010,14 @@ func TestCelEstimatedCostStability(t *testing.T) {
20102010
`isQuantity(self.val2)`: 314575,
20112011
`isQuantity("200M")`: 1,
20122012
`isQuantity("20Mi")`: 1,
2013-
`quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`: uint64(3689348814741910532),
2014-
`quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`: uint64(5534023222112865798),
2013+
`quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`: uint64(6),
2014+
`quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`: uint64(9),
20152015
`quantity(self.val1).isLessThan(quantity(self.val2))`: 629151,
20162016
`quantity("50M").isLessThan(quantity("100M"))`: 3,
20172017
`quantity("50Mi").isGreaterThan(quantity("50M"))`: 3,
20182018
`quantity("200M").compareTo(quantity("0.2G")) == 0`: 4,
2019-
`quantity("50k").add(quantity("20")) == quantity("50.02k")`: uint64(1844674407370955268),
2020-
`quantity("50k").sub(20) == quantity("49980")`: uint64(1844674407370955267),
2019+
`quantity("50k").add(quantity("20")) == quantity("50.02k")`: uint64(5),
2020+
`quantity("50k").sub(20) == quantity("49980")`: uint64(4),
20212021
`quantity("50").isInteger()`: 2,
20222022
`quantity(self.val1).isInteger()`: 314576,
20232023
},

staging/src/k8s.io/apiserver/pkg/cel/library/cost.go

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package library
1919
import (
2020
"fmt"
2121
"math"
22+
"reflect"
2223

2324
"github.com/google/cel-go/checker"
2425
"github.com/google/cel-go/common"
@@ -27,6 +28,7 @@ import (
2728
"github.com/google/cel-go/common/types/ref"
2829
"github.com/google/cel-go/common/types/traits"
2930

31+
"k8s.io/apimachinery/pkg/util/sets"
3032
"k8s.io/apiserver/pkg/cel"
3133
)
3234

@@ -48,6 +50,22 @@ var knownUnhandledFunctions = map[string]bool{
4850
"strings.quote": true,
4951
}
5052

53+
// TODO: Replace this with a utility that extracts types from libraries.
54+
var knownKubernetesRuntimeTypes = sets.New[reflect.Type](
55+
reflect.ValueOf(cel.URL{}).Type(),
56+
reflect.ValueOf(cel.IP{}).Type(),
57+
reflect.ValueOf(cel.CIDR{}).Type(),
58+
reflect.ValueOf(&cel.Format{}).Type(),
59+
reflect.ValueOf(cel.Quantity{}).Type(),
60+
)
61+
var knownKubernetesCompilerTypes = sets.New[ref.Type](
62+
cel.CIDRType,
63+
cel.IPType,
64+
cel.FormatType,
65+
cel.QuantityType,
66+
cel.URLType,
67+
)
68+
5169
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
5270
type CostEstimator struct {
5371
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
@@ -235,6 +253,27 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
235253
// url accessors
236254
cost := uint64(1)
237255
return &cost
256+
case "_==_":
257+
if len(args) == 2 {
258+
unitCost := uint64(1)
259+
lhs := args[0]
260+
switch lhs.(type) {
261+
case cel.Quantity:
262+
return &unitCost
263+
case cel.IP:
264+
return &unitCost
265+
case cel.CIDR:
266+
return &unitCost
267+
case *cel.Format: // Formats have a small max size.
268+
return &unitCost
269+
case cel.URL: // TODO: Computing the actual cost is expensive, and changing this would be a breaking change
270+
return &unitCost
271+
default:
272+
if panicOnUnknown && knownKubernetesRuntimeTypes.Has(reflect.ValueOf(lhs).Type()) {
273+
panic(fmt.Errorf("CallCost: unhandled equality for Kubernetes type %T", lhs))
274+
}
275+
}
276+
}
238277
}
239278
if panicOnUnknown && !knownUnhandledFunctions[function] {
240279
panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args))
@@ -278,7 +317,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
278317
case "url":
279318
if len(args) == 1 {
280319
sz := l.sizeEstimate(args[0])
281-
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
320+
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
282321
}
283322
case "lowerAscii", "upperAscii", "substring", "trim":
284323
if target != nil {
@@ -475,6 +514,39 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
475514
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
476515
// url accessors
477516
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
517+
case "_==_":
518+
if len(args) == 2 {
519+
lhs := args[0]
520+
rhs := args[1]
521+
if lhs.Type().Equal(rhs.Type()) == types.True {
522+
t := lhs.Type()
523+
if t.Kind() == types.OpaqueKind {
524+
switch t.TypeName() {
525+
case cel.IPType.TypeName(), cel.CIDRType.TypeName():
526+
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
527+
}
528+
}
529+
if t.Kind() == types.StructKind {
530+
switch t {
531+
case cel.QuantityType: // O(1) cost equality checks
532+
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
533+
case cel.FormatType:
534+
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: cel.MaxFormatSize}.MultiplyByCostFactor(common.StringTraversalCostFactor)}
535+
case cel.URLType:
536+
size := checker.SizeEstimate{Min: 1, Max: 1}
537+
rhSize := rhs.ComputedSize()
538+
lhSize := rhs.ComputedSize()
539+
if rhSize != nil && lhSize != nil {
540+
size = rhSize.Union(*lhSize)
541+
}
542+
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: size.Max}.MultiplyByCostFactor(common.StringTraversalCostFactor)}
543+
}
544+
}
545+
if panicOnUnknown && knownKubernetesCompilerTypes.Has(t) {
546+
panic(fmt.Errorf("EstimateCallCost: unhandled equality for Kubernetes type %v", t))
547+
}
548+
}
549+
}
478550
}
479551
if panicOnUnknown && !knownUnhandledFunctions[function] {
480552
panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args))

staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,16 @@ func TestURLsCost(t *testing.T) {
206206
expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4},
207207
expectRuntimeCost: 4,
208208
},
209+
{
210+
ops: []string{" == url('https:://kubernetes.io/')"},
211+
expectEsimatedCost: checker.CostEstimate{Min: 7, Max: 9},
212+
expectRuntimeCost: 7,
213+
},
214+
{
215+
ops: []string{" == url('http://x.b')"},
216+
expectEsimatedCost: checker.CostEstimate{Min: 5, Max: 5},
217+
expectRuntimeCost: 5,
218+
},
209219
}
210220

211221
for _, tc := range cases {
@@ -245,6 +255,14 @@ func TestIPCost(t *testing.T) {
245255
},
246256
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
247257
},
258+
{
259+
ops: []string{" == ip('192.168.0.1')"},
260+
// For most other operations, the cost is expected to be the base + 1.
261+
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
262+
return c.Add(ipv4BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1})
263+
},
264+
expectRuntimeCost: func(c uint64) uint64 { return c + ipv4BaseRuntimeCost + 1 },
265+
},
248266
}
249267

250268
for _, tc := range testCases {
@@ -320,6 +338,14 @@ func TestCIDRCost(t *testing.T) {
320338
},
321339
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
322340
},
341+
{
342+
ops: []string{" == cidr('2001:db8::/32')"},
343+
// For most other operations, the cost is expected to be the base + 1.
344+
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
345+
return c.Add(ipv6BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1})
346+
},
347+
expectRuntimeCost: func(c uint64) uint64 { return c + ipv6BaseRuntimeCost + 1 },
348+
},
323349
}
324350

325351
//nolint:gocritic
@@ -708,19 +734,19 @@ func TestQuantityCost(t *testing.T) {
708734
{
709735
name: "equality_reflexivity",
710736
expr: `quantity("200M") == quantity("200M")`,
711-
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 1844674407370955266},
737+
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3},
712738
expectRuntimeCost: 3,
713739
},
714740
{
715741
name: "equality_symmetry",
716742
expr: `quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`,
717-
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3689348814741910532},
743+
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 6},
718744
expectRuntimeCost: 6,
719745
},
720746
{
721747
name: "equality_transitivity",
722748
expr: `quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`,
723-
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 5534023222112865798},
749+
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 9},
724750
expectRuntimeCost: 9,
725751
},
726752
{
@@ -744,19 +770,19 @@ func TestQuantityCost(t *testing.T) {
744770
{
745771
name: "add_quantity",
746772
expr: `quantity("50k").add(quantity("20")) == quantity("50.02k")`,
747-
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268},
773+
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5},
748774
expectRuntimeCost: 5,
749775
},
750776
{
751777
name: "sub_quantity",
752778
expr: `quantity("50k").sub(quantity("20")) == quantity("49.98k")`,
753-
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268},
779+
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5},
754780
expectRuntimeCost: 5,
755781
},
756782
{
757783
name: "sub_int",
758784
expr: `quantity("50k").sub(20) == quantity("49980")`,
759-
expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 1844674407370955267},
785+
expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 4},
760786
expectRuntimeCost: 4,
761787
},
762788
{
@@ -825,6 +851,18 @@ func TestNameFormatCost(t *testing.T) {
825851
expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34},
826852
expectRuntimeCost: 10,
827853
},
854+
{
855+
name: "format.dns1123label.validate",
856+
expr: `format.named("dns1123Label").value().validate("my-name")`,
857+
expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34},
858+
expectRuntimeCost: 10,
859+
},
860+
{
861+
name: "format.dns1123label.validate",
862+
expr: `format.named("dns1123Label").value() == format.named("dns1123Label").value()`,
863+
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 11},
864+
expectRuntimeCost: 5,
865+
},
828866
}
829867

830868
for _, tc := range cases {

staging/src/k8s.io/apiserver/pkg/cel/library/format_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222

2323
"github.com/google/cel-go/common/types"
2424
"github.com/google/cel-go/common/types/ref"
25+
26+
"k8s.io/apiserver/pkg/cel"
2527
"k8s.io/apiserver/pkg/cel/library"
2628
)
2729

@@ -228,3 +230,11 @@ func TestFormat(t *testing.T) {
228230
})
229231
}
230232
}
233+
234+
func TestSizeLimit(t *testing.T) {
235+
for name := range library.ConstantFormats {
236+
if len(name) > cel.MaxFormatSize {
237+
t.Fatalf("All formats must be <= %d chars in length", cel.MaxFormatSize)
238+
}
239+
}
240+
}

staging/src/k8s.io/apiserver/pkg/cel/limits.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,7 @@ const (
4848
// MinNumberSize is the length of literal 0
4949
MinNumberSize = 1
5050

51+
// MaxFormatSize is the maximum size we allow for format strings
52+
MaxFormatSize = 64
5153
MaxNameFormatRegexSize = 128
5254
)

0 commit comments

Comments
 (0)