Skip to content

Commit c3a5492

Browse files
committed
Fix PodTopologySpread matching pods counts for constraints with the same topologyKey
1 parent a499fac commit c3a5492

File tree

5 files changed

+670
-652
lines changed

5 files changed

+670
-652
lines changed

pkg/scheduler/framework/plugins/podtopologyspread/common.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ import (
2727
"k8s.io/utils/ptr"
2828
)
2929

30-
type topologyPair struct {
31-
key string
32-
value string
33-
}
34-
3530
// topologySpreadConstraint is an internal version for v1.TopologySpreadConstraint
3631
// and where the selector is parsed.
3732
// Fields are exported for comparison during testing.

pkg/scheduler/framework/plugins/podtopologyspread/filtering.go

Lines changed: 64 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package podtopologyspread
1919
import (
2020
"context"
2121
"fmt"
22+
"maps"
2223
"math"
2324

2425
v1 "k8s.io/api/core/v1"
@@ -31,37 +32,31 @@ import (
3132
const preFilterStateKey = "PreFilter" + Name
3233

3334
// preFilterState computed at PreFilter and used at Filter.
34-
// It combines TpKeyToCriticalPaths and TpPairToMatchNum to represent:
35+
// It combines CriticalPaths and TpValueToMatchNum to represent:
3536
// (1) critical paths where the least pods are matched on each spread constraint.
3637
// (2) number of pods matched on each spread constraint.
3738
// A nil preFilterState denotes it's not set at all (in PreFilter phase);
3839
// An empty preFilterState object denotes it's a legit state and is set in PreFilter phase.
3940
// Fields are exported for comparison during testing.
4041
type preFilterState struct {
4142
Constraints []topologySpreadConstraint
42-
// We record 2 critical paths instead of all critical paths here.
43-
// criticalPaths[0].MatchNum always holds the minimum matching number.
44-
// criticalPaths[1].MatchNum is always greater or equal to criticalPaths[0].MatchNum, but
43+
// CriticalPaths is a slice indexed by constraint index.
44+
// Per each entry, we record 2 critical paths instead of all critical paths.
45+
// CriticalPaths[i][0].MatchNum always holds the minimum matching number.
46+
// CriticalPaths[i][1].MatchNum is always greater or equal to CriticalPaths[i][0].MatchNum, but
4547
// it's not guaranteed to be the 2nd minimum match number.
46-
TpKeyToCriticalPaths map[string]*criticalPaths
47-
// TpKeyToDomainsNum is keyed with topologyKey, and valued with the number of domains.
48-
TpKeyToDomainsNum map[string]int
49-
// TpPairToMatchNum is keyed with topologyPair, and valued with the number of matching pods.
50-
TpPairToMatchNum map[topologyPair]int
48+
CriticalPaths []*criticalPaths
49+
// TpValueToMatchNum is a slice indexed by constraint index.
50+
// Each entry is keyed with topology value, and valued with the number of matching pods.
51+
TpValueToMatchNum []map[string]int
5152
}
5253

5354
// minMatchNum returns the global minimum for the calculation of skew while taking MinDomains into account.
54-
func (s *preFilterState) minMatchNum(tpKey string, minDomains int32) (int, error) {
55-
paths, ok := s.TpKeyToCriticalPaths[tpKey]
56-
if !ok {
57-
return 0, fmt.Errorf("failed to retrieve path by topology key")
58-
}
55+
func (s *preFilterState) minMatchNum(constraintID int, minDomains int32) (int, error) {
56+
paths := s.CriticalPaths[constraintID]
5957

6058
minMatchNum := paths[0].MatchNum
61-
domainsNum, ok := s.TpKeyToDomainsNum[tpKey]
62-
if !ok {
63-
return 0, fmt.Errorf("failed to retrieve the number of domains by topology key")
64-
}
59+
domainsNum := len(s.TpValueToMatchNum[constraintID])
6560

6661
if domainsNum < int(minDomains) {
6762
// When the number of eligible domains with matching topology keys is less than `minDomains`,
@@ -79,17 +74,15 @@ func (s *preFilterState) Clone() framework.StateData {
7974
}
8075
copy := preFilterState{
8176
// Constraints are shared because they don't change.
82-
Constraints: s.Constraints,
83-
TpKeyToCriticalPaths: make(map[string]*criticalPaths, len(s.TpKeyToCriticalPaths)),
84-
// The number of domains does not change as a result of AddPod/RemovePod methods on PreFilter Extensions
85-
TpKeyToDomainsNum: s.TpKeyToDomainsNum,
86-
TpPairToMatchNum: make(map[topologyPair]int, len(s.TpPairToMatchNum)),
77+
Constraints: s.Constraints,
78+
CriticalPaths: make([]*criticalPaths, len(s.CriticalPaths)),
79+
TpValueToMatchNum: make([]map[string]int, len(s.TpValueToMatchNum)),
8780
}
88-
for tpKey, paths := range s.TpKeyToCriticalPaths {
89-
copy.TpKeyToCriticalPaths[tpKey] = &criticalPaths{paths[0], paths[1]}
81+
for i, paths := range s.CriticalPaths {
82+
copy.CriticalPaths[i] = &criticalPaths{paths[0], paths[1]}
9083
}
91-
for tpPair, matchNum := range s.TpPairToMatchNum {
92-
copy.TpPairToMatchNum[tpPair] = matchNum
84+
for i, tpMap := range s.TpValueToMatchNum {
85+
copy.TpValueToMatchNum[i] = maps.Clone(tpMap)
9386
}
9487
return &copy
9588
}
@@ -200,7 +193,7 @@ func (pl *PodTopologySpread) updateWithPod(s *preFilterState, updatedPod, preemp
200193
}
201194

202195
podLabelSet := labels.Set(updatedPod.Labels)
203-
for _, constraint := range s.Constraints {
196+
for i, constraint := range s.Constraints {
204197
if !constraint.Selector.Matches(podLabelSet) {
205198
continue
206199
}
@@ -210,10 +203,9 @@ func (pl *PodTopologySpread) updateWithPod(s *preFilterState, updatedPod, preemp
210203
continue
211204
}
212205

213-
k, v := constraint.TopologyKey, node.Labels[constraint.TopologyKey]
214-
pair := topologyPair{key: k, value: v}
215-
s.TpPairToMatchNum[pair] += delta
216-
s.TpKeyToCriticalPaths[k].update(v, s.TpPairToMatchNum[pair])
206+
v := node.Labels[constraint.TopologyKey]
207+
s.TpValueToMatchNum[i][v] += delta
208+
s.CriticalPaths[i].update(v, s.TpValueToMatchNum[i][v])
217209
}
218210
}
219211

@@ -232,6 +224,12 @@ func getPreFilterState(cycleState *framework.CycleState) (*preFilterState, error
232224
return s, nil
233225
}
234226

227+
type topologyCount struct {
228+
topologyValue string
229+
constraintID int
230+
count int
231+
}
232+
235233
// calPreFilterState computes preFilterState describing how pods are spread on topologies.
236234
func (pl *PodTopologySpread) calPreFilterState(ctx context.Context, pod *v1.Pod) (*preFilterState, error) {
237235
constraints, err := pl.getConstraints(pod)
@@ -248,15 +246,18 @@ func (pl *PodTopologySpread) calPreFilterState(ctx context.Context, pod *v1.Pod)
248246
}
249247

250248
s := preFilterState{
251-
Constraints: constraints,
252-
TpKeyToCriticalPaths: make(map[string]*criticalPaths, len(constraints)),
253-
TpPairToMatchNum: make(map[topologyPair]int, sizeHeuristic(len(allNodes), constraints)),
249+
Constraints: constraints,
250+
CriticalPaths: make([]*criticalPaths, len(constraints)),
251+
TpValueToMatchNum: make([]map[string]int, len(constraints)),
252+
}
253+
for i := 0; i < len(constraints); i++ {
254+
s.TpValueToMatchNum[i] = make(map[string]int, sizeHeuristic(len(allNodes), constraints[i]))
254255
}
255256

256-
tpCountsByNode := make([]map[topologyPair]int, len(allNodes))
257+
tpCountsByNode := make([][]topologyCount, len(allNodes))
257258
requiredNodeAffinity := nodeaffinity.GetRequiredNodeAffinity(pod)
258-
processNode := func(i int) {
259-
nodeInfo := allNodes[i]
259+
processNode := func(n int) {
260+
nodeInfo := allNodes[n]
260261
node := nodeInfo.Node()
261262

262263
if !pl.enableNodeInclusionPolicyInPodTopologySpread {
@@ -272,38 +273,39 @@ func (pl *PodTopologySpread) calPreFilterState(ctx context.Context, pod *v1.Pod)
272273
return
273274
}
274275

275-
tpCounts := make(map[topologyPair]int, len(constraints))
276-
for _, c := range constraints {
276+
tpCounts := make([]topologyCount, 0, len(constraints))
277+
for i, c := range constraints {
277278
if pl.enableNodeInclusionPolicyInPodTopologySpread &&
278279
!c.matchNodeInclusionPolicies(pod, node, requiredNodeAffinity) {
279280
continue
280281
}
281282

282-
pair := topologyPair{key: c.TopologyKey, value: node.Labels[c.TopologyKey]}
283+
value := node.Labels[c.TopologyKey]
283284
count := countPodsMatchSelector(nodeInfo.Pods, c.Selector, pod.Namespace)
284-
tpCounts[pair] = count
285+
tpCounts = append(tpCounts, topologyCount{
286+
topologyValue: value,
287+
constraintID: i,
288+
count: count,
289+
})
285290
}
286-
tpCountsByNode[i] = tpCounts
291+
tpCountsByNode[n] = tpCounts
287292
}
288293
pl.parallelizer.Until(ctx, len(allNodes), processNode, pl.Name())
289294

290295
for _, tpCounts := range tpCountsByNode {
291-
for tp, count := range tpCounts {
292-
s.TpPairToMatchNum[tp] += count
296+
// tpCounts might not hold all the constraints, so index can't be used here as constraintID.
297+
for _, tpCount := range tpCounts {
298+
s.TpValueToMatchNum[tpCount.constraintID][tpCount.topologyValue] += tpCount.count
293299
}
294300
}
295-
s.TpKeyToDomainsNum = make(map[string]int, len(constraints))
296-
for tp := range s.TpPairToMatchNum {
297-
s.TpKeyToDomainsNum[tp.key]++
298-
}
299301

300-
// calculate min match for each topology pair
302+
// calculate min match for each constraint and topology value
301303
for i := 0; i < len(constraints); i++ {
302-
key := constraints[i].TopologyKey
303-
s.TpKeyToCriticalPaths[key] = newCriticalPaths()
304-
}
305-
for pair, num := range s.TpPairToMatchNum {
306-
s.TpKeyToCriticalPaths[pair.key].update(pair.value, num)
304+
s.CriticalPaths[i] = newCriticalPaths()
305+
306+
for value, num := range s.TpValueToMatchNum[i] {
307+
s.CriticalPaths[i].update(value, num)
308+
}
307309
}
308310

309311
return &s, nil
@@ -325,7 +327,7 @@ func (pl *PodTopologySpread) Filter(ctx context.Context, cycleState *framework.C
325327

326328
logger := klog.FromContext(ctx)
327329
podLabelSet := labels.Set(pod.Labels)
328-
for _, c := range s.Constraints {
330+
for i, c := range s.Constraints {
329331
tpKey := c.TopologyKey
330332
tpVal, ok := node.Labels[c.TopologyKey]
331333
if !ok {
@@ -335,9 +337,9 @@ func (pl *PodTopologySpread) Filter(ctx context.Context, cycleState *framework.C
335337

336338
// judging criteria:
337339
// 'existing matching num' + 'if self-match (1 or 0)' - 'global minimum' <= 'maxSkew'
338-
minMatchNum, err := s.minMatchNum(tpKey, c.MinDomains)
340+
minMatchNum, err := s.minMatchNum(i, c.MinDomains)
339341
if err != nil {
340-
logger.Error(err, "Internal error occurred while retrieving value precalculated in PreFilter", "topologyKey", tpKey, "paths", s.TpKeyToCriticalPaths)
342+
logger.Error(err, "Internal error occurred while retrieving value precalculated in PreFilter", "topologyKey", tpKey, "paths", s.CriticalPaths[i])
341343
continue
342344
}
343345

@@ -346,11 +348,7 @@ func (pl *PodTopologySpread) Filter(ctx context.Context, cycleState *framework.C
346348
selfMatchNum = 1
347349
}
348350

349-
pair := topologyPair{key: tpKey, value: tpVal}
350-
matchNum := 0
351-
if tpCount, ok := s.TpPairToMatchNum[pair]; ok {
352-
matchNum = tpCount
353-
}
351+
matchNum := s.TpValueToMatchNum[i][tpVal]
354352
skew := matchNum + selfMatchNum - minMatchNum
355353
if skew > int(c.MaxSkew) {
356354
logger.V(5).Info("Node failed spreadConstraint: matchNum + selfMatchNum - minMatchNum > maxSkew", "node", klog.KObj(node), "topologyKey", tpKey, "matchNum", matchNum, "selfMatchNum", selfMatchNum, "minMatchNum", minMatchNum, "maxSkew", c.MaxSkew)
@@ -361,11 +359,9 @@ func (pl *PodTopologySpread) Filter(ctx context.Context, cycleState *framework.C
361359
return nil
362360
}
363361

364-
func sizeHeuristic(nodes int, constraints []topologySpreadConstraint) int {
365-
for _, c := range constraints {
366-
if c.TopologyKey == v1.LabelHostname {
367-
return nodes
368-
}
362+
func sizeHeuristic(nodes int, constraint topologySpreadConstraint) int {
363+
if constraint.TopologyKey == v1.LabelHostname {
364+
return nodes
369365
}
370366
return 0
371367
}

0 commit comments

Comments
 (0)