Skip to content

Commit a4838d4

Browse files
authored
test(sources/s3): fix infinite blocking and timeout issue in TestSource_Chunks (#4048)
This PR addresses two issues:

1. Duplicate test functions in s3_integration_test.go. The file currently contains two separate TestSourceChunksNoResumption functions. The only differences between them are the use of t.Parallel() in one and the bucket name. This PR merges both into a single, table-driven test.

2. Possible infinite block in TestSource_Chunks. As reported in #4069 (comment), the test TestSource_Chunks in s3_test.go may block indefinitely. This is due to the use of an unbuffered channel (chunksCh), combined with only a single receive operation (gotChunk := <-chunksCh). If the test bucket contains more than one chunk, the s.Chunks(ctx, chunksCh) call will block on the next send, causing the test to hang. See #4048 (comment) for a full explanation.

* test(sources/s3): remove duplicate integration test case
  Signed-off-by: Eng Zer Jun <[email protected]>
* Merge two TestSourceChunksNoResumption into table-driven test
  Signed-off-by: Eng Zer Jun <[email protected]>
* Fix infinite blocking and timeout issue in TestSource_Chunks
  Signed-off-by: Eng Zer Jun <[email protected]>
* Drain `chunksCh` channel
  Reference: #4048 (review)
  Signed-off-by: Eng Zer Jun <[email protected]>
* Add missing return for ctx.Done() case
  Signed-off-by: Eng Zer Jun <[email protected]>
---------
Signed-off-by: Eng Zer Jun <[email protected]>
1 parent cb93ce2 commit a4838d4

File tree

2 files changed

+73
-69
lines changed

2 files changed

+73
-69
lines changed

pkg/sources/s3/s3_integration_test.go

Lines changed: 40 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -82,37 +82,6 @@ func TestSource_ChunksLarge(t *testing.T) {
8282
assert.Equal(t, got, wantChunkCount)
8383
}
8484

85-
func TestSourceChunksNoResumption(t *testing.T) {
86-
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
87-
defer cancel()
88-
89-
s := Source{}
90-
connection := &sourcespb.S3{
91-
Credential: &sourcespb.S3_Unauthenticated{},
92-
Buckets: []string{"trufflesec-ahrav-test-2"},
93-
}
94-
conn, err := anypb.New(connection)
95-
if err != nil {
96-
t.Fatal(err)
97-
}
98-
99-
err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
100-
chunksCh := make(chan *sources.Chunk)
101-
go func() {
102-
defer close(chunksCh)
103-
err = s.Chunks(ctx, chunksCh)
104-
assert.Nil(t, err)
105-
}()
106-
107-
wantChunkCount := 19787
108-
got := 0
109-
110-
for range chunksCh {
111-
got++
112-
}
113-
assert.Equal(t, got, wantChunkCount)
114-
}
115-
11685
func TestSource_Validate(t *testing.T) {
11786
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
11887
defer cancel()
@@ -251,34 +220,50 @@ func TestSource_Validate(t *testing.T) {
251220
func TestSourceChunksNoResumption(t *testing.T) {
252221
t.Parallel()
253222

254-
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
255-
defer cancel()
256-
257-
s := Source{}
258-
connection := &sourcespb.S3{
259-
Credential: &sourcespb.S3_Unauthenticated{},
260-
Buckets: []string{"integration-resumption-tests"},
261-
}
262-
conn, err := anypb.New(connection)
263-
if err != nil {
264-
t.Fatal(err)
223+
tests := []struct {
224+
bucket string
225+
wantChunkCount int
226+
}{
227+
{
228+
bucket: "trufflesec-ahrav-test-2",
229+
wantChunkCount: 19787,
230+
},
231+
{
232+
bucket: "integration-resumption-tests",
233+
wantChunkCount: 19787,
234+
},
265235
}
266236

267-
err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
268-
chunksCh := make(chan *sources.Chunk)
269-
go func() {
270-
defer close(chunksCh)
271-
err = s.Chunks(ctx, chunksCh)
272-
assert.Nil(t, err)
273-
}()
274-
275-
wantChunkCount := 19787
276-
got := 0
237+
for _, tt := range tests {
238+
t.Run(tt.bucket, func(t *testing.T) {
239+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
240+
defer cancel()
241+
242+
s := Source{}
243+
connection := &sourcespb.S3{
244+
Credential: &sourcespb.S3_Unauthenticated{},
245+
Buckets: []string{tt.bucket},
246+
}
247+
conn, err := anypb.New(connection)
248+
if err != nil {
249+
t.Fatal(err)
250+
}
277251

278-
for range chunksCh {
279-
got++
252+
err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
253+
chunksCh := make(chan *sources.Chunk)
254+
go func() {
255+
defer close(chunksCh)
256+
err = s.Chunks(ctx, chunksCh)
257+
assert.Nil(t, err)
258+
}()
259+
260+
got := 0
261+
for range chunksCh {
262+
got++
263+
}
264+
assert.Equal(t, tt.wantChunkCount, got)
265+
})
280266
}
281-
assert.Equal(t, wantChunkCount, got)
282267
}
283268

284269
func TestSourceChunksResumption(t *testing.T) {

pkg/sources/s3/s3_test.go

Lines changed: 33 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@ import (
44
"encoding/base64"
55
"fmt"
66
"os"
7-
"sync"
87
"testing"
98
"time"
109

@@ -99,8 +98,7 @@ func TestSource_Chunks(t *testing.T) {
9998
for _, tt := range tests {
10099
t.Run(tt.name, func(t *testing.T) {
101100
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
102-
var cancelOnce sync.Once
103-
defer cancelOnce.Do(cancel)
101+
defer cancel()
104102

105103
for k, v := range tt.init.setEnv {
106104
t.Setenv(k, v)
@@ -117,26 +115,47 @@ func TestSource_Chunks(t *testing.T) {
117115
t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
118116
return
119117
}
120-
chunksCh := make(chan *sources.Chunk)
121-
var wg sync.WaitGroup
122-
wg.Add(1)
118+
chunksCh := make(chan *sources.Chunk, 1)
123119
go func() {
124-
defer wg.Done()
120+
defer close(chunksCh)
125121
err = s.Chunks(ctx, chunksCh)
126122
if (err != nil) != tt.wantErr {
127123
t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
128124
os.Exit(1)
129125
}
130126
}()
131-
gotChunk := <-chunksCh
132-
wantData, _ := base64.StdEncoding.DecodeString(tt.wantChunkData)
133127

134-
if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
135-
t.Errorf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
128+
waitFn := func() {
129+
receivedFirstChunk := false
130+
for {
131+
select {
132+
case <-ctx.Done():
133+
t.Errorf("TestSource_Chunks timed out: %v", ctx.Err())
134+
return
135+
case gotChunk, ok := <-chunksCh:
136+
if !ok {
137+
t.Logf("Source.Chunks() finished, channel closed")
138+
assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
139+
assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
140+
return
141+
}
142+
if receivedFirstChunk {
143+
// wantChunkData is the first chunk data. After the first chunk has
144+
// been received and matched below, we want to drain chunksCh
145+
// so Source.Chunks() can finish completely.
146+
continue
147+
}
148+
149+
receivedFirstChunk = true
150+
wantData, _ := base64.StdEncoding.DecodeString(tt.wantChunkData)
151+
152+
if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
153+
t.Logf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
154+
}
155+
}
156+
}
136157
}
137-
wg.Wait()
138-
assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
139-
assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
158+
waitFn()
140159
})
141160
}
142161
}

0 commit comments

Comments
 (0)