Skip to content

Commit c63687c

Browse files
authored
Merge pull request #416 Fix race condition bugs for background worker
2 parents a35129b + 7842b2f commit c63687c

File tree

4 files changed

+253
-36
lines changed

4 files changed

+253
-36
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Fix internal race-condition bugs in internal background worker
2+
13
## v3.38.3
24
* Added retries to initial discovering
35

internal/background/worker.go

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"runtime/pprof"
77
"sync"
8-
"sync/atomic"
98

109
"github.com/ydb-platform/ydb-go-sdk/v3/internal/empty"
1110
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
@@ -24,13 +23,19 @@ type Worker struct {
2423
workers sync.WaitGroup
2524
onceInit sync.Once
2625

26+
tasksCompleted empty.Chan
27+
2728
m xsync.Mutex
2829

29-
closed uint32
30+
tasks chan backgroundTask
31+
32+
closed bool
3033
stop xcontext.CancelErrFunc
3134
closeReason error
3235
}
3336

37+
type CallbackFunc func(ctx context.Context)
38+
3439
func NewWorker(parent context.Context) *Worker {
3540
w := Worker{}
3641
w.ctx, w.stop = xcontext.WithErrCancel(parent)
@@ -44,54 +49,52 @@ func (b *Worker) Context() context.Context {
4449
return b.ctx
4550
}
4651

47-
func (b *Worker) Start(name string, f func(ctx context.Context)) {
48-
if atomic.LoadUint32(&b.closed) != 0 {
49-
f(b.ctx)
50-
return
51-
}
52-
52+
func (b *Worker) Start(name string, f CallbackFunc) {
5353
b.init()
5454

55-
b.m.Lock()
56-
defer b.m.Unlock()
57-
58-
if b.ctx.Err() != nil {
59-
return
60-
}
61-
62-
b.workers.Add(1)
63-
go func() {
64-
defer b.workers.Done()
55+
b.m.WithLock(func() {
56+
if b.closed {
57+
return
58+
}
6559

66-
pprof.Do(b.ctx, pprof.Labels("background", name), f)
67-
}()
60+
b.tasks <- backgroundTask{
61+
callback: f,
62+
name: name,
63+
}
64+
})
6865
}
6966

7067
func (b *Worker) Done() <-chan struct{} {
7168
b.init()
7269

73-
b.m.Lock()
74-
defer b.m.Unlock()
75-
7670
return b.ctx.Done()
7771
}
7872

7973
func (b *Worker) Close(ctx context.Context, err error) error {
80-
if !atomic.CompareAndSwapUint32(&b.closed, 0, 1) {
81-
return xerrors.WithStackTrace(ErrAlreadyClosed)
82-
}
83-
8474
b.init()
8575

86-
b.m.Lock()
87-
defer b.m.Unlock()
76+
var resErr error
77+
b.m.WithLock(func() {
78+
if b.closed {
79+
resErr = xerrors.WithStackTrace(ErrAlreadyClosed)
80+
return
81+
}
8882

89-
b.closeReason = err
90-
if b.closeReason == nil {
91-
b.closeReason = errClosedWithNilReason
83+
b.closed = true
84+
85+
close(b.tasks)
86+
b.closeReason = err
87+
if b.closeReason == nil {
88+
b.closeReason = errClosedWithNilReason
89+
}
90+
91+
b.stop(err)
92+
})
93+
if resErr != nil {
94+
return resErr
9295
}
9396

94-
b.stop(err)
97+
<-b.tasksCompleted
9598

9699
bgCompleted := make(empty.Chan)
97100

@@ -120,5 +123,27 @@ func (b *Worker) init() {
120123
if b.ctx == nil {
121124
b.ctx, b.stop = xcontext.WithErrCancel(context.Background())
122125
}
126+
b.tasks = make(chan backgroundTask)
127+
b.tasksCompleted = make(empty.Chan)
128+
go b.starterLoop()
123129
})
124130
}
131+
132+
func (b *Worker) starterLoop() {
133+
defer close(b.tasksCompleted)
134+
135+
for bgTask := range b.tasks {
136+
b.workers.Add(1)
137+
138+
go func(task backgroundTask) {
139+
defer b.workers.Done()
140+
141+
pprof.Do(b.ctx, pprof.Labels("background", task.name), task.callback)
142+
}(bgTask)
143+
}
144+
}
145+
146+
type backgroundTask struct {
147+
callback CallbackFunc
148+
name string
149+
}

internal/background/worker_test.go

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package background
2+
3+
import (
4+
"context"
5+
"runtime"
6+
"sync/atomic"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/empty"
13+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xatomic"
14+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
15+
)
16+
17+
func TestWorkerContext(t *testing.T) {
18+
t.Run("Empty", func(t *testing.T) {
19+
w := Worker{}
20+
require.NotNil(t, w.Context())
21+
require.NotNil(t, w.ctx)
22+
require.NotNil(t, w.stop)
23+
})
24+
25+
t.Run("Dedicated", func(t *testing.T) {
26+
type ctxkey struct{}
27+
ctx := context.WithValue(context.Background(), ctxkey{}, "2")
28+
w := NewWorker(ctx)
29+
require.Equal(t, "2", w.Context().Value(ctxkey{}))
30+
})
31+
32+
t.Run("Stop", func(t *testing.T) {
33+
w := Worker{}
34+
ctx := w.Context()
35+
require.NoError(t, ctx.Err())
36+
37+
_ = w.Close(context.Background(), nil)
38+
require.Error(t, ctx.Err())
39+
})
40+
}
41+
42+
func TestWorkerStart(t *testing.T) {
43+
t.Run("Started", func(t *testing.T) {
44+
w := NewWorker(xtest.Context(t))
45+
started := make(empty.Chan)
46+
w.Start("test", func(ctx context.Context) {
47+
close(started)
48+
})
49+
xtest.WaitChannelClosed(t, started)
50+
})
51+
t.Run("Stopped", func(t *testing.T) {
52+
ctx := xtest.Context(t)
53+
w := NewWorker(ctx)
54+
_ = w.Close(ctx, nil)
55+
56+
started := make(empty.Chan)
57+
w.Start("test", func(ctx context.Context) {
58+
close(started)
59+
})
60+
61+
// expected: no close channel
62+
time.Sleep(time.Second / 100)
63+
select {
64+
case <-started:
65+
t.Fatal()
66+
default:
67+
// pass
68+
}
69+
})
70+
}
71+
72+
func TestWorkerClose(t *testing.T) {
73+
t.Run("StopBackground", func(t *testing.T) {
74+
ctx := xtest.Context(t)
75+
w := NewWorker(ctx)
76+
77+
started := make(empty.Chan)
78+
stopped := xatomic.Bool{}
79+
w.Start("test", func(innerCtx context.Context) {
80+
close(started)
81+
<-innerCtx.Done()
82+
stopped.Store(true)
83+
})
84+
85+
xtest.WaitChannelClosed(t, started)
86+
require.NoError(t, w.Close(ctx, nil))
87+
require.True(t, stopped.Load())
88+
})
89+
90+
t.Run("DoubleClose", func(t *testing.T) {
91+
ctx := xtest.Context(t)
92+
w := NewWorker(ctx)
93+
require.NoError(t, w.Close(ctx, nil))
94+
require.Error(t, w.Close(ctx, nil))
95+
})
96+
}
97+
98+
func TestWorkerConcurrentStartAndClose(t *testing.T) {
99+
xtest.TestManyTimes(t, func(t testing.TB) {
100+
targetClose := int64(100)
101+
parallel := 10
102+
103+
var counter int64
104+
105+
ctx := xtest.Context(t)
106+
w := NewWorker(ctx)
107+
108+
closeIndex := int64(0)
109+
closed := make(empty.Chan)
110+
111+
go func() {
112+
defer close(closed)
113+
114+
xtest.SpinWaitCondition(t, nil, func() bool {
115+
return atomic.LoadInt64(&counter) > targetClose
116+
})
117+
118+
require.NoError(t, w.Close(ctx, nil))
119+
closeIndex = atomic.LoadInt64(&counter)
120+
}()
121+
122+
stopNewStarts := xatomic.Bool{}
123+
for i := 0; i < parallel; i++ {
124+
go func() {
125+
for {
126+
if stopNewStarts.Load() {
127+
return
128+
}
129+
130+
w.Start("test", func(ctx context.Context) {
131+
atomic.AddInt64(&counter, 1)
132+
})
133+
}
134+
}()
135+
}
136+
137+
xtest.WaitChannelClosed(t, closed)
138+
runtime.Gosched()
139+
require.Equal(t, closeIndex, atomic.LoadInt64(&counter))
140+
stopNewStarts.Store(true)
141+
})
142+
}
143+
144+
func TestWorkerStartCompletedWhileLongWait(t *testing.T) {
145+
xtest.TestManyTimes(t, func(t testing.TB) {
146+
ctx := xtest.Context(t)
147+
w := NewWorker(ctx)
148+
149+
allowStop := make(empty.Chan)
150+
closeStarted := make(empty.Chan)
151+
w.Start("test", func(ctx context.Context) {
152+
<-ctx.Done()
153+
close(closeStarted)
154+
155+
<-allowStop
156+
})
157+
158+
closed := make(empty.Chan)
159+
160+
callStartFinished := make(empty.Chan)
161+
go func() {
162+
defer close(callStartFinished)
163+
start := time.Now()
164+
165+
for time.Since(start) < time.Millisecond {
166+
w.Start("test2", func(ctx context.Context) {
167+
// pass
168+
})
169+
}
170+
}()
171+
172+
go func() {
173+
defer close(closed)
174+
175+
_ = w.Close(ctx, nil)
176+
}()
177+
178+
xtest.WaitChannelClosed(t, callStartFinished)
179+
runtime.Gosched()
180+
181+
select {
182+
case <-closed:
183+
t.Fatal()
184+
default:
185+
// pass
186+
}
187+
188+
close(allowStop)
189+
xtest.WaitChannelClosed(t, closed)
190+
})
191+
}

internal/topic/topicreaderinternal/stream_reconnector_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,13 @@ func TestTopicReaderReconnectorConnectionLoop(t *testing.T) {
222222
},
223223
{
224224
callback: func(ctx context.Context) (batchedStreamReader, error) {
225-
t.Error()
225+
t.Fatal()
226226
return nil, errors.New("unexpected call")
227227
},
228228
},
229229
}...)
230230

231-
reconnector.background.Start("test-reconnectionLoop", reconnector.reconnectionLoop)
232-
reconnector.reconnectFromBadStream <- nil
231+
reconnector.start()
233232

234233
<-stream1Ready
235234

0 commit comments

Comments
 (0)