Skip to content

Commit 44a0a96

Browse files
authored
Merge pull request #1549 from ydb-platform/preferred-node-id
Preferred nodeID
2 parents bdcf0ff + cbe1c46 commit 44a0a96

File tree

4 files changed

+187
-9
lines changed

4 files changed

+187
-9
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Added `ydb.WithPreferredNodeID(ctx, nodeID)` context modifier for trying to execute queries on given nodeID
2+
13
## v3.90.2
24
* Set the `pick_first` balancer for short-lived grpc connection inside ydb cluster discovery attempt
35

context.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"time"
66

7+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
78
"github.com/ydb-platform/ydb-go-sdk/v3/internal/operation"
89
)
910

@@ -19,3 +20,8 @@ func WithOperationTimeout(ctx context.Context, operationTimeout time.Duration) c
1920
func WithOperationCancelAfter(ctx context.Context, operationCancelAfter time.Duration) context.Context {
2021
return operation.WithCancelAfter(ctx, operationCancelAfter)
2122
}
23+
24+
// WithPreferredNodeID allows to set preferred node to get session from
25+
func WithPreferredNodeID(ctx context.Context, nodeID uint32) context.Context {
26+
return endpoint.WithNodeID(ctx, nodeID)
27+
}

internal/pool/pool.go

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/jonboulle/clockwork"
1010

11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
1112
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
1213
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
1314
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
@@ -20,6 +21,7 @@ type (
2021
Item interface {
2122
IsAlive() bool
2223
Close(ctx context.Context) error
24+
NodeID() uint32
2325
}
2426
ItemConstraint[T any] interface {
2527
*T
@@ -442,7 +444,7 @@ func (p *Pool[PT, T]) Close(ctx context.Context) (finalErr error) {
442444
}
443445
}
444446

445-
// getWaitCh returns pointer to a channel of sessions.
447+
// getWaitCh returns pointer to a channel of items.
446448
//
447449
// Note that returning a pointer reduces allocations on sync.Pool usage –
448450
// sync.Client.Get() returns empty interface, which leads to allocation for
@@ -474,9 +476,25 @@ func (p *Pool[PT, T]) peekFirstIdle() (item PT, touched time.Time) {
474476
return item, info.lastUsage
475477
}
476478

477-
// removes first session from idle and resets the keepAliveCount
478-
// to prevent session from dying in the internalPoolGC after it was returned
479-
// to be used only in outgoing functions that make session busy.
479+
// p.mu must be held.
480+
func (p *Pool[PT, T]) peekFirstIdleByNodeID(nodeID uint32) (item PT, touched time.Time) {
481+
el := p.idle.Front()
482+
for el != nil && el.Value.NodeID() != nodeID {
483+
el = el.Next()
484+
}
485+
if el == nil {
486+
return
487+
}
488+
item = el.Value
489+
info, has := p.index[item]
490+
if !has || el != info.idle {
491+
panic(fmt.Sprintf("inconsistent index: (%v, %+v, %+v)", has, el, info.idle))
492+
}
493+
494+
return item, info.lastUsage
495+
}
496+
497+
// removes first item from idle to use only in outgoing functions that make item busy.
480498
// p.mu must be held.
481499
func (p *Pool[PT, T]) removeFirstIdle() PT {
482500
idle, _ := p.peekFirstIdle()
@@ -488,10 +506,22 @@ func (p *Pool[PT, T]) removeFirstIdle() PT {
488506
return idle
489507
}
490508

509+
// removes first item with preferred nodeID from idle to use only in outgoing functions that make item busy.
510+
// p.mu must be held.
511+
func (p *Pool[PT, T]) removeIdleByNodeID(nodeID uint32) PT {
512+
idle, _ := p.peekFirstIdleByNodeID(nodeID)
513+
if idle != nil {
514+
info := p.removeIdle(idle)
515+
p.index[idle] = info
516+
}
517+
518+
return idle
519+
}
520+
491521
// p.mu must be held.
492522
func (p *Pool[PT, T]) notifyAboutIdle(idle PT) (notified bool) {
493523
for el := p.waitQ.Front(); el != nil; el = p.waitQ.Front() {
494-
// Some goroutine is waiting for a session.
524+
// Some goroutine is waiting for a item.
495525
//
496526
// It could be in this states:
497527
// 1) Reached the select code and awaiting for a value in channel.
@@ -532,7 +562,7 @@ func (p *Pool[PT, T]) notifyAboutIdle(idle PT) (notified bool) {
532562
func (p *Pool[PT, T]) removeIdle(item PT) itemInfo[PT, T] {
533563
info, has := p.index[item]
534564
if !has || info.idle == nil {
535-
panic("inconsistent session client index")
565+
panic("inconsistent item client index")
536566
}
537567

538568
p.changeState(func() Stats {
@@ -585,6 +615,8 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { /
585615
}
586616
}
587617

618+
preferredNodeID, hasPreferredNodeID := endpoint.ContextNodeID(ctx)
619+
588620
for ; attempt < maxAttempts; attempt++ {
589621
select {
590622
case <-p.done:
@@ -593,6 +625,18 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { /
593625
}
594626

595627
if item := xsync.WithLock(&p.mu, func() PT { //nolint:nestif
628+
if hasPreferredNodeID {
629+
item := p.removeIdleByNodeID(preferredNodeID)
630+
if item != nil {
631+
return item
632+
}
633+
634+
if len(p.index)+p.createInProgress < p.config.limit {
635+
// for create item with preferred nodeID
636+
return nil
637+
}
638+
}
639+
596640
return p.removeFirstIdle()
597641
}); item != nil {
598642
if item.IsAlive() {
@@ -706,15 +750,15 @@ func (p *Pool[PT, T]) waitFromCh(ctx context.Context) (item PT, finalErr error)
706750

707751
case item, ok := <-*ch:
708752
// Note that race may occur and some goroutine may try to write
709-
// session into channel after it was enqueued but before it being
753+
// item into channel after it was enqueued but before it being
710754
// read here. In that case we will receive nil here and will retry.
711755
//
712-
// The same way will work when some session become deleted - the
756+
// The same way will work when some item become deleted - the
713757
// nil value will be sent into the channel.
714758
if ok {
715759
// Put only filled and not closed channel back to the Client.
716760
// That is, we need to avoid races on filling reused channel
717-
// for the next waiter – session could be lost for a long time.
761+
// for the next waiter – item could be lost for a long time.
718762
p.putWaitCh(ch)
719763
}
720764

internal/pool/pool_test.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
grpcStatus "google.golang.org/grpc/status"
2121

2222
"github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
23+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
2324
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
2425
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
2526
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
@@ -38,6 +39,7 @@ type (
3839

3940
onClose func() error
4041
onIsAlive func() bool
42+
onNodeID func() uint32
4143
}
4244
testWaitChPool struct {
4345
xsync.Pool[chan *testItem]
@@ -113,6 +115,14 @@ func (t *testItem) ID() string {
113115
return ""
114116
}
115117

118+
func (t *testItem) NodeID() uint32 {
119+
if t.onNodeID != nil {
120+
return t.onNodeID()
121+
}
122+
123+
return 0
124+
}
125+
116126
func (t *testItem) Close(context.Context) error {
117127
if t.closed.Len() > 0 {
118128
debug.PrintStack()
@@ -167,10 +177,126 @@ func TestPool(t *testing.T) { //nolint:gocyclo
167177
WithTrace[*testItem, testItem](defaultTrace),
168178
)
169179
err := p.With(rootCtx, func(ctx context.Context, testItem *testItem) error {
180+
require.EqualValues(t, 0, testItem.NodeID())
181+
170182
return nil
171183
})
172184
require.NoError(t, err)
173185
})
186+
t.Run("RequireNodeIdFromPool", func(t *testing.T) {
187+
nextNodeID := uint32(0)
188+
var newSessionCalled uint32
189+
p := New[*testItem, testItem](rootCtx,
190+
WithTrace[*testItem, testItem](defaultTrace),
191+
WithCreateItemFunc(func(ctx context.Context) (*testItem, error) {
192+
newSessionCalled++
193+
var (
194+
nodeID = nextNodeID
195+
v = testItem{
196+
v: 0,
197+
onNodeID: func() uint32 {
198+
return nodeID
199+
},
200+
}
201+
)
202+
203+
return &v, nil
204+
}),
205+
)
206+
207+
item := mustGetItem(t, p)
208+
require.EqualValues(t, 0, item.NodeID())
209+
require.EqualValues(t, true, item.IsAlive())
210+
mustPutItem(t, p, item)
211+
212+
nextNodeID = 32
213+
214+
item, err := p.getItem(endpoint.WithNodeID(context.Background(), 32))
215+
require.NoError(t, err)
216+
require.EqualValues(t, 32, item.NodeID())
217+
mustPutItem(t, p, item)
218+
219+
nextNodeID = 33
220+
221+
item, err = p.getItem(endpoint.WithNodeID(context.Background(), 33))
222+
require.NoError(t, err)
223+
require.EqualValues(t, 33, item.NodeID())
224+
mustPutItem(t, p, item)
225+
226+
item, err = p.getItem(endpoint.WithNodeID(context.Background(), 32))
227+
require.NoError(t, err)
228+
require.EqualValues(t, 32, item.NodeID())
229+
mustPutItem(t, p, item)
230+
231+
item, err = p.getItem(endpoint.WithNodeID(context.Background(), 33))
232+
require.NoError(t, err)
233+
require.EqualValues(t, 33, item.NodeID())
234+
mustPutItem(t, p, item)
235+
236+
item, err = p.getItem(endpoint.WithNodeID(context.Background(), 32))
237+
require.NoError(t, err)
238+
item2, err := p.getItem(endpoint.WithNodeID(context.Background(), 33))
239+
require.NoError(t, err)
240+
require.EqualValues(t, 32, item.NodeID())
241+
require.EqualValues(t, 33, item2.NodeID())
242+
mustPutItem(t, p, item2)
243+
mustPutItem(t, p, item)
244+
245+
item, err = p.getItem(endpoint.WithNodeID(context.Background(), 32))
246+
require.NoError(t, err)
247+
item2, err = p.getItem(endpoint.WithNodeID(context.Background(), 33))
248+
require.NoError(t, err)
249+
require.EqualValues(t, 32, item.NodeID())
250+
require.EqualValues(t, 33, item2.NodeID())
251+
mustPutItem(t, p, item)
252+
mustPutItem(t, p, item2)
253+
254+
item, err = p.getItem(endpoint.WithNodeID(context.Background(), 32))
255+
require.NoError(t, err)
256+
item2, err = p.getItem(endpoint.WithNodeID(context.Background(), 33))
257+
require.NoError(t, err)
258+
item3, err := p.getItem(context.Background())
259+
require.NoError(t, err)
260+
require.EqualValues(t, 32, item.NodeID())
261+
require.EqualValues(t, 33, item2.NodeID())
262+
require.EqualValues(t, 0, item3.NodeID())
263+
mustPutItem(t, p, item)
264+
mustPutItem(t, p, item2)
265+
mustPutItem(t, p, item3)
266+
267+
require.EqualValues(t, 3, newSessionCalled)
268+
})
269+
t.Run("CreateSessionOnGivenNode", func(t *testing.T) {
270+
var newSessionCalled uint32
271+
p := New[*testItem, testItem](rootCtx,
272+
WithTrace[*testItem, testItem](defaultTrace),
273+
WithCreateItemFunc(func(ctx context.Context) (*testItem, error) {
274+
newSessionCalled++
275+
v := testItem{
276+
v: 0,
277+
onNodeID: func() uint32 {
278+
nodeID, _ := endpoint.ContextNodeID(ctx)
279+
280+
return nodeID
281+
},
282+
}
283+
284+
return &v, nil
285+
}),
286+
)
287+
288+
item, err := p.getItem(endpoint.WithNodeID(context.Background(), 32))
289+
require.NoError(t, err)
290+
require.EqualValues(t, 32, item.NodeID())
291+
require.EqualValues(t, true, item.IsAlive())
292+
mustPutItem(t, p, item)
293+
294+
item = mustGetItem(t, p)
295+
require.EqualValues(t, 32, item.NodeID())
296+
mustPutItem(t, p, item)
297+
298+
require.EqualValues(t, 1, newSessionCalled)
299+
})
174300
t.Run("WithLimit", func(t *testing.T) {
175301
p := New[*testItem, testItem](rootCtx, WithLimit[*testItem, testItem](1),
176302
WithTrace[*testItem, testItem](defaultTrace),

0 commit comments

Comments
 (0)