Skip to content

Commit c93c91c

Browse files
authored
Merge pull request #1418 from ydb-platform/pool-sema
refactoring of internal/pool
2 parents aa4bd13 + 3a98bcb commit c93c91c

File tree

3 files changed

+73
-184
lines changed

3 files changed

+73
-184
lines changed

internal/pool/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ import (
77
var (
88
errClosedPool = errors.New("closed pool")
99
errItemIsNotAlive = errors.New("item is not alive")
10+
errPoolIsOverflow = errors.New("pool is overflow")
1011
)

internal/pool/pool.go

Lines changed: 70 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,6 @@ type (
2020
IsAlive() bool
2121
Close(ctx context.Context) error
2222
}
23-
safeStats struct {
24-
mu xsync.RWMutex
25-
v Stats
26-
onChange func(Stats)
27-
}
28-
statsItemAddr struct {
29-
v *int
30-
onChange func(func())
31-
}
3223
Pool[PT Item[T], T any] struct {
3324
trace *Trace
3425
limit int
@@ -37,80 +28,18 @@ type (
3728
createTimeout time.Duration
3829
closeTimeout time.Duration
3930

40-
mu xsync.Mutex
31+
sema chan struct{}
32+
33+
mu xsync.RWMutex
4134
idle []PT
4235
index map[PT]struct{}
4336
done chan struct{}
4437

45-
stats *safeStats
38+
stats *Stats
4639
}
4740
option[PT Item[T], T any] func(p *Pool[PT, T])
4841
)
4942

50-
func (field statsItemAddr) Inc() {
51-
field.onChange(func() {
52-
*field.v++
53-
})
54-
}
55-
56-
func (field statsItemAddr) Dec() {
57-
field.onChange(func() {
58-
*field.v--
59-
})
60-
}
61-
62-
func (s *safeStats) Get() Stats {
63-
s.mu.RLock()
64-
defer s.mu.RUnlock()
65-
66-
return s.v
67-
}
68-
69-
func (s *safeStats) Index() statsItemAddr {
70-
s.mu.RLock()
71-
defer s.mu.RUnlock()
72-
73-
return statsItemAddr{
74-
v: &s.v.Index,
75-
onChange: func(f func()) {
76-
s.mu.WithLock(f)
77-
if s.onChange != nil {
78-
s.onChange(s.Get())
79-
}
80-
},
81-
}
82-
}
83-
84-
func (s *safeStats) Idle() statsItemAddr {
85-
s.mu.RLock()
86-
defer s.mu.RUnlock()
87-
88-
return statsItemAddr{
89-
v: &s.v.Idle,
90-
onChange: func(f func()) {
91-
s.mu.WithLock(f)
92-
if s.onChange != nil {
93-
s.onChange(s.Get())
94-
}
95-
},
96-
}
97-
}
98-
99-
func (s *safeStats) InUse() statsItemAddr {
100-
s.mu.RLock()
101-
defer s.mu.RUnlock()
102-
103-
return statsItemAddr{
104-
v: &s.v.InUse,
105-
onChange: func(f func()) {
106-
s.mu.WithLock(f)
107-
if s.onChange != nil {
108-
s.onChange(s.Get())
109-
}
110-
},
111-
}
112-
}
113-
11443
func WithCreateFunc[PT Item[T], T any](f func(ctx context.Context) (PT, error)) option[PT, T] {
11544
return func(p *Pool[PT, T]) {
11645
p.createItem = f
@@ -170,13 +99,10 @@ func New[PT Item[T], T any](
17099
}()
171100

172101
p.createItem = createItemWithTimeoutHandling(p.createItem, p)
173-
102+
p.sema = make(chan struct{}, p.limit)
174103
p.idle = make([]PT, 0, p.limit)
175104
p.index = make(map[PT]struct{}, p.limit)
176-
p.stats = &safeStats{
177-
v: Stats{Limit: p.limit},
178-
onChange: p.trace.OnChange,
179-
}
105+
p.stats = &Stats{Limit: p.limit}
180106

181107
return p
182108
}
@@ -263,7 +189,7 @@ func createItemWithContext[PT Item[T], T any](
263189
if len(p.index) < p.limit {
264190
p.idle = append(p.idle, newItem)
265191
p.index[newItem] = struct{}{}
266-
p.stats.Index().Inc()
192+
p.stats.Index++
267193
needCloseItem = false
268194
}
269195

@@ -276,10 +202,13 @@ func createItemWithContext[PT Item[T], T any](
276202
}
277203

278204
func (p *Pool[PT, T]) Stats() Stats {
279-
return p.stats.Get()
205+
p.mu.RLock()
206+
defer p.mu.RUnlock()
207+
208+
return *p.stats
280209
}
281210

282-
func (p *Pool[PT, T]) getItem(ctx context.Context) (_ PT, finalErr error) {
211+
func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) {
283212
onDone := p.trace.OnGet(&GetStartInfo{
284213
Context: &ctx,
285214
Call: stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/pool.(*Pool).getItem"),
@@ -290,56 +219,29 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (_ PT, finalErr error) {
290219
})
291220
}()
292221

293-
if err := ctx.Err(); err != nil {
294-
return nil, xerrors.WithStackTrace(err)
295-
}
222+
p.mu.Lock()
223+
defer p.mu.Unlock()
296224

297-
for {
298-
select {
299-
case <-p.done:
300-
return nil, xerrors.WithStackTrace(errClosedPool)
301-
case <-ctx.Done():
302-
return nil, xerrors.WithStackTrace(ctx.Err())
303-
default:
304-
var item PT
305-
p.mu.WithLock(func() {
306-
if len(p.idle) > 0 {
307-
item, p.idle = p.idle[0], p.idle[1:]
308-
p.stats.Idle().Dec()
309-
}
310-
})
225+
if len(p.idle) > 0 {
226+
item, p.idle = p.idle[0], p.idle[1:]
227+
p.stats.Idle--
311228

312-
if item != nil {
313-
if item.IsAlive() {
314-
return item, nil
315-
}
316-
_ = p.closeItem(ctx, item)
317-
p.mu.WithLock(func() {
318-
delete(p.index, item)
319-
})
320-
p.stats.Index().Dec()
321-
}
322-
var err error
323-
var newItem PT
324-
p.mu.WithLock(func() {
325-
if len(p.index) >= p.limit {
326-
return
327-
}
328-
newItem, err = p.createItem(ctx)
329-
if err != nil {
330-
return
331-
}
332-
p.index[newItem] = struct{}{}
333-
p.stats.Index().Inc()
334-
})
335-
if err != nil {
336-
return nil, xerrors.WithStackTrace(err)
337-
}
338-
if newItem != nil {
339-
return newItem, nil
340-
}
229+
if item.IsAlive() {
230+
return item, nil
341231
}
232+
233+
_ = p.closeItem(ctx, item)
342234
}
235+
236+
newItem, err := p.createItem(ctx)
237+
if err != nil {
238+
return nil, xerrors.WithStackTrace(xerrors.Retryable(err))
239+
}
240+
241+
p.stats.Index++
242+
p.index[newItem] = struct{}{}
243+
244+
return newItem, nil
343245
}
344246

345247
func (p *Pool[PT, T]) putItem(ctx context.Context, item PT) (finalErr error) {
@@ -353,37 +255,28 @@ func (p *Pool[PT, T]) putItem(ctx context.Context, item PT) (finalErr error) {
353255
})
354256
}()
355257

356-
if err := ctx.Err(); err != nil {
357-
return xerrors.WithStackTrace(err)
258+
if !item.IsAlive() {
259+
_ = p.closeItem(ctx, item)
260+
261+
return xerrors.WithStackTrace(errItemIsNotAlive)
358262
}
359263

360-
select {
361-
case <-p.done:
362-
return xerrors.WithStackTrace(errClosedPool)
363-
default:
364-
if !item.IsAlive() {
365-
_ = p.closeItem(ctx, item)
264+
p.mu.Lock()
265+
defer p.mu.Unlock()
366266

367-
p.mu.WithLock(func() {
368-
delete(p.index, item)
369-
})
370-
p.stats.Index().Dec()
267+
if len(p.idle) >= p.limit {
268+
_ = p.closeItem(ctx, item)
371269

372-
return xerrors.WithStackTrace(errItemIsNotAlive)
373-
}
270+
return xerrors.WithStackTrace(errPoolIsOverflow)
271+
}
374272

375-
p.mu.WithLock(func() {
376-
p.idle = append(p.idle, item)
377-
})
378-
p.stats.Idle().Inc()
273+
p.idle = append(p.idle, item)
274+
p.stats.Idle--
379275

380-
return nil
381-
}
276+
return nil
382277
}
383278

384279
func (p *Pool[PT, T]) closeItem(ctx context.Context, item PT) error {
385-
ctx = xcontext.ValueOnly(ctx)
386-
387280
var cancel context.CancelFunc
388281
if d := p.closeTimeout; d > 0 {
389282
ctx, cancel = xcontext.WithTimeout(ctx, d)
@@ -392,6 +285,13 @@ func (p *Pool[PT, T]) closeItem(ctx context.Context, item PT) error {
392285
}
393286
defer cancel()
394287

288+
defer func() {
289+
p.mu.WithLock(func() {
290+
delete(p.index, item)
291+
p.stats.Index--
292+
})
293+
}()
294+
395295
return item.Close(ctx)
396296
}
397297

@@ -406,6 +306,17 @@ func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item
406306
})
407307
}()
408308

309+
select {
310+
case <-p.done:
311+
return xerrors.WithStackTrace(errClosedPool)
312+
case <-ctx.Done():
313+
return xerrors.WithStackTrace(ctx.Err())
314+
case p.sema <- struct{}{}:
315+
defer func() {
316+
<-p.sema
317+
}()
318+
}
319+
409320
item, err := p.getItem(ctx)
410321
if err != nil {
411322
if xerrors.IsYdb(err) {
@@ -419,8 +330,14 @@ func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item
419330
_ = p.putItem(ctx, item)
420331
}()
421332

422-
p.stats.InUse().Inc()
423-
defer p.stats.InUse().Dec()
333+
p.mu.Lock()
334+
p.stats.InUse++
335+
p.mu.Unlock()
336+
defer func() {
337+
p.mu.Lock()
338+
p.stats.InUse--
339+
p.mu.Unlock()
340+
}()
424341

425342
err = f(ctx, item)
426343
if err != nil {

internal/pool/pool_test.go

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package pool
33
import (
44
"context"
55
"errors"
6-
"math/rand"
76
"sync"
87
"sync/atomic"
98
"testing"
@@ -287,7 +286,7 @@ func TestPool(t *testing.T) {
287286
require.NoError(t, err)
288287
}()
289288
wg.Wait()
290-
}, xtest.StopAfter(42*time.Second))
289+
}, xtest.StopAfter(14*time.Second))
291290
})
292291
t.Run("ParallelCreation", func(t *testing.T) {
293292
xtest.TestManyTimes(t, func(t testing.TB) {
@@ -313,34 +312,6 @@ func TestPool(t *testing.T) {
313312
require.Equal(t, DefaultLimit, stats.Limit)
314313
require.Equal(t, 0, stats.InUse)
315314
require.LessOrEqual(t, stats.Idle, DefaultLimit)
316-
}, xtest.StopAfter(30*time.Second))
315+
}, xtest.StopAfter(14*time.Second))
317316
})
318317
}
319-
320-
func TestSafeStatsRace(t *testing.T) {
321-
xtest.TestManyTimes(t, func(t testing.TB) {
322-
var (
323-
wg sync.WaitGroup
324-
s = &safeStats{}
325-
)
326-
wg.Add(1000)
327-
for range make([]struct{}, 1000) {
328-
go func() {
329-
defer wg.Done()
330-
require.NotPanics(t, func() {
331-
switch rand.Int31n(4) { //nolint:gosec
332-
case 0:
333-
s.Index().Inc()
334-
case 1:
335-
s.InUse().Inc()
336-
case 2:
337-
s.Idle().Inc()
338-
default:
339-
s.Get()
340-
}
341-
})
342-
}()
343-
}
344-
wg.Wait()
345-
}, xtest.StopAfter(5*time.Second))
346-
}

0 commit comments

Comments
 (0)