Skip to content

Commit 3521fcd

Browse files
committed
fix(zstd decompression): limit concurrency to 1 to prevent deadlock in zstd library
1 parent 2cfec8c commit 3521fcd

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

encoding/codecv7_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@ import (
44
"encoding/hex"
55
"encoding/json"
66
"fmt"
7+
"log"
78
"math/big"
89
"math/rand"
10+
"net/http"
11+
_ "net/http"
12+
_ "net/http/pprof"
913
"strings"
1014
"testing"
1115

@@ -17,6 +21,32 @@ import (
1721
"github.com/stretchr/testify/require"
1822
)
1923

24+
// TestDecodeAllDeadlock tests the decompression of random bytes to trigger deadlock in zstd library.
25+
26+
func TestDecodeAllDeadlock(t *testing.T) {
27+
//t.Skip("Skip test that triggers deadlock in zstd library")
28+
29+
go func() {
30+
log.Println(http.ListenAndServe("localhost:6060", nil))
31+
}()
32+
33+
// generate some random bytes
34+
randomBytes := make([]byte, maxBlobBytes)
35+
rand.Read(randomBytes)
36+
37+
c := NewDACodecV8()
38+
39+
compressed, err := c.CompressScrollBatchBytes(randomBytes)
40+
require.NoError(t, err)
41+
42+
// repeatedly decompress the bytes to trigger deadlock in zstd library
43+
for i := 0; i < 100000; i++ {
44+
uncompressed, err := decompressV7Bytes(compressed)
45+
require.NoError(t, err)
46+
require.Equal(t, randomBytes, uncompressed)
47+
}
48+
}
49+
2050
// TestCodecV7DABlockEncodeDecode tests the encoding and decoding of daBlockV7.
2151
func TestCodecV7DABlockEncodeDecode(t *testing.T) {
2252
codecV7, err := CodecFromVersion(CodecV7)

encoding/codecv7_types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ func decompressV7Bytes(compressedBytes []byte) ([]byte, error) {
482482

483483
compressedBytes = append(zstdMagicNumber, compressedBytes...)
484484
r := bytes.NewReader(compressedBytes)
485-
zr, err := zstd.NewReader(r)
485+
zr, err := zstd.NewReader(r, zstd.WithDecoderConcurrency(1))
486486
if err != nil {
487487
return nil, fmt.Errorf("failed to create zstd reader: %w", err)
488488
}

0 commit comments

Comments
 (0)