Skip to content

Commit 4a46a77

Browse files
committed
fix worker start after close
1 parent 88cebd8 commit 4a46a77

File tree

2 files changed

+166
-25
lines changed

2 files changed

+166
-25
lines changed

internal/background/worker.go

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ import (
55
"errors"
66
"runtime/pprof"
77
"sync"
8-
"sync/atomic"
98

109
"github.com/ydb-platform/ydb-go-sdk/v3/internal/empty"
10+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xatomic"
1111
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
1212
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
1313
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
@@ -26,11 +26,13 @@ type Worker struct {
2626

2727
m xsync.Mutex
2828

29-
closed uint32
29+
closed xatomic.Bool
3030
stop xcontext.CancelErrFunc
3131
closeReason error
3232
}
3333

34+
type CallbackFunc func(ctx context.Context)
35+
3436
func NewWorker(parent context.Context) *Worker {
3537
w := Worker{}
3638
w.ctx, w.stop = xcontext.WithErrCancel(parent)
@@ -44,58 +46,59 @@ func (b *Worker) Context() context.Context {
4446
return b.ctx
4547
}
4648

47-
func (b *Worker) Start(name string, f func(ctx context.Context)) {
48-
if atomic.LoadUint32(&b.closed) != 0 {
49-
f(b.ctx)
49+
func (b *Worker) Start(name string, f CallbackFunc) {
50+
if b.closed.Load() {
5051
return
5152
}
5253

5354
b.init()
5455

55-
b.m.Lock()
56-
defer b.m.Unlock()
57-
5856
if b.ctx.Err() != nil {
5957
return
6058
}
6159

62-
b.workers.Add(1)
63-
go func() {
64-
defer b.workers.Done()
60+
b.m.WithLock(func() {
61+
if b.closed.Load() {
62+
return
63+
}
64+
b.workers.Add(1)
6565

66-
pprof.Do(b.ctx, pprof.Labels("background", name), f)
67-
}()
66+
go func() {
67+
defer b.workers.Done()
68+
69+
pprof.Do(b.ctx, pprof.Labels("background", name), f)
70+
}()
71+
})
6872
}
6973

7074
func (b *Worker) Done() <-chan struct{} {
7175
b.init()
7276

73-
b.m.Lock()
74-
defer b.m.Unlock()
75-
7677
return b.ctx.Done()
7778
}
7879

7980
func (b *Worker) Close(ctx context.Context, err error) error {
80-
if !atomic.CompareAndSwapUint32(&b.closed, 0, 1) {
81+
if b.closed.Swap(true) {
8182
return xerrors.WithStackTrace(ErrAlreadyClosed)
8283
}
8384

8485
b.init()
8586

86-
b.m.Lock()
87-
defer b.m.Unlock()
88-
89-
b.closeReason = err
90-
if b.closeReason == nil {
91-
b.closeReason = errClosedWithNilReason
92-
}
87+
b.m.WithLock(func() {
88+
b.closeReason = err
89+
if b.closeReason == nil {
90+
b.closeReason = errClosedWithNilReason
91+
}
9392

94-
b.stop(err)
93+
b.stop(err)
94+
})
9595

9696
bgCompleted := make(empty.Chan)
9797

9898
go func() {
99+
b.m.Lock()
100+
defer b.m.Unlock()
101+
99102
b.workers.Wait()
100103
close(bgCompleted)
101104
}()

internal/background/worker_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package background
2+
3+
import (
4+
"context"
5+
"runtime"
6+
"sync/atomic"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/empty"
13+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xatomic"
14+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
15+
)
16+
17+
func TestWorkerContext(t *testing.T) {
18+
t.Run("Empty", func(t *testing.T) {
19+
w := Worker{}
20+
require.NotNil(t, w.Context())
21+
require.NotNil(t, w.ctx)
22+
require.NotNil(t, w.stop)
23+
})
24+
25+
t.Run("Dedicated", func(t *testing.T) {
26+
ctx := context.WithValue(context.Background(), "1", "2")
27+
w := NewWorker(ctx)
28+
require.Equal(t, "2", w.Context().Value("1"))
29+
})
30+
31+
t.Run("Stop", func(t *testing.T) {
32+
w := Worker{}
33+
ctx := w.Context()
34+
require.NoError(t, ctx.Err())
35+
36+
_ = w.Close(context.Background(), nil)
37+
require.Error(t, ctx.Err())
38+
})
39+
}
40+
41+
func TestWorkerStart(t *testing.T) {
42+
t.Run("Started", func(t *testing.T) {
43+
w := NewWorker(xtest.Context(t))
44+
started := make(empty.Chan)
45+
w.Start("test", func(ctx context.Context) {
46+
close(started)
47+
})
48+
xtest.WaitChannelClosed(t, started)
49+
})
50+
t.Run("Stopped", func(t *testing.T) {
51+
ctx := xtest.Context(t)
52+
w := NewWorker(ctx)
53+
_ = w.Close(ctx, nil)
54+
55+
started := make(empty.Chan)
56+
w.Start("test", func(ctx context.Context) {
57+
close(started)
58+
})
59+
60+
// expected: no close channel
61+
time.Sleep(time.Second / 100)
62+
select {
63+
case <-started:
64+
t.Fatal()
65+
default:
66+
// pass
67+
}
68+
})
69+
}
70+
71+
func TestWorkerClose(t *testing.T) {
72+
t.Run("StopBackground", func(t *testing.T) {
73+
ctx := xtest.Context(t)
74+
w := NewWorker(ctx)
75+
76+
started := make(empty.Chan)
77+
stopped := xatomic.Bool{}
78+
w.Start("test", func(innerCtx context.Context) {
79+
close(started)
80+
<-innerCtx.Done()
81+
stopped.Store(true)
82+
})
83+
84+
xtest.WaitChannelClosed(t, started)
85+
require.NoError(t, w.Close(ctx, nil))
86+
require.True(t, stopped.Load())
87+
})
88+
89+
t.Run("DoubleClose", func(t *testing.T) {
90+
ctx := xtest.Context(t)
91+
w := NewWorker(ctx)
92+
require.NoError(t, w.Close(ctx, nil))
93+
require.Error(t, w.Close(ctx, nil))
94+
})
95+
}
96+
97+
func TestWorkerConcurrentStartAndClose(t *testing.T) {
98+
targetClose := int64(10000)
99+
parallel := 10
100+
101+
var counter int64
102+
103+
ctx := xtest.Context(t)
104+
w := NewWorker(ctx)
105+
106+
closeIndex := int64(0)
107+
closed := make(empty.Chan)
108+
go func() {
109+
xtest.SpinWaitCondition(t, nil, func() bool {
110+
return atomic.LoadInt64(&counter) > targetClose
111+
})
112+
require.NoError(t, w.Close(ctx, nil))
113+
closeIndex = atomic.LoadInt64(&counter)
114+
close(closed)
115+
}()
116+
117+
stopNewStarts := xatomic.Bool{}
118+
for i := 0; i < parallel; i++ {
119+
go func() {
120+
for {
121+
if stopNewStarts.Load() {
122+
return
123+
}
124+
125+
go func() {
126+
w.Start("test", func(ctx context.Context) {
127+
atomic.AddInt64(&counter, 1)
128+
})
129+
}()
130+
}
131+
}()
132+
}
133+
134+
xtest.WaitChannelClosed(t, closed)
135+
runtime.Gosched()
136+
require.Equal(t, closeIndex, atomic.LoadInt64(&counter))
137+
stopNewStarts.Store(true)
138+
}

0 commit comments

Comments
 (0)