@@ -53,29 +53,37 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
53
53
# Process added requests.
54
54
for index , params , _ , _ in batch_update .added :
55
55
min_p = params .min_p
56
- if self .min_p_cpu [index ] != min_p :
56
+ min_p_before = self .min_p_cpu [index ]
57
+ if min_p_before != min_p :
57
58
needs_update = True
58
59
self .min_p_cpu [index ] = min_p
59
- if min_p :
60
- self .min_p_count += 1
60
+ if min_p and not min_p_before :
61
+ self .min_p_count += 1
62
+ elif not min_p and min_p_before :
63
+ self .min_p_count -= 1
61
64
62
65
if self .min_p_count :
63
66
# Process removed requests.
64
- needs_update |= bool (batch_update .removed )
65
- for index in batch_update .removed :
66
- if self .min_p_cpu [index ]:
67
- self .min_p_count -= 1
67
+ if batch_update .removed :
68
+ needs_update = True
69
+ for index in batch_update .removed :
70
+ if self .min_p_cpu [index ]:
71
+ self .min_p_cpu [index ] = 0
72
+ self .min_p_count -= 1
68
73
69
- # Process moved requests, unidirectional (a->b) and swap (a<->b)
74
+ # Process moved requests, unidirectional (a->b) and swap (a<->b).
70
75
for adx , bdx , direct in batch_update .moved :
71
- change = (min_p_a :=
72
- self .min_p_cpu [adx ]) != (min_p_b :=
73
- self .min_p_cpu [bdx ])
74
- needs_update |= change
75
- if change :
76
+ min_p_a , min_p_b = self .min_p_cpu [adx ], self .min_p_cpu [bdx ]
77
+ if min_p_a != min_p_b :
78
+ needs_update = True
76
79
self .min_p_cpu [bdx ] = min_p_a
77
80
if direct == MoveDirectionality .SWAP :
78
81
self .min_p_cpu [adx ] = min_p_b
82
+ if direct == MoveDirectionality .UNIDIRECTIONAL :
83
+ if min_p_a :
84
+ self .min_p_cpu [adx ] = 0
85
+ if min_p_b :
86
+ self .min_p_count -= 1
79
87
80
88
# Update tensors if needed.
81
89
size = batch_update .batch_size
0 commit comments