Skip to content

Commit 1c41925

Browse files
committed
readPacket optimization
Since configuration options doesn't change at runtime, after connection is established, use dedicated function, in order to avoid multiple test test compress, checking ReadTimeout configuration option
1 parent 73d4e73 commit 1c41925

File tree

7 files changed

+103
-34
lines changed

7 files changed

+103
-34
lines changed

benchmark_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,47 @@ func benchmarkQueryHelper(b *testing.B, compr bool) {
113113
}
114114
}
115115

116+
func BenchmarkSelect10000rows(b *testing.B) {
117+
db := initDB(b, false)
118+
defer db.Close()
119+
120+
// Check if we're using MariaDB
121+
var version string
122+
err := db.QueryRow("SELECT @@version").Scan(&version)
123+
if err != nil {
124+
b.Fatalf("Failed to get server version: %v", err)
125+
}
126+
127+
if !strings.Contains(strings.ToLower(version), "mariadb") {
128+
b.Skip("Skipping benchmark as it requires MariaDB sequence table")
129+
return
130+
}
131+
132+
b.StartTimer()
133+
stmt, err := db.Prepare("SELECT * FROM seq_1_to_10000")
134+
if err != nil {
135+
b.Fatalf("Failed to prepare statement: %v", err)
136+
}
137+
defer stmt.Close()
138+
for n := 0; n < b.N; n++ {
139+
rows, err := stmt.Query()
140+
if err != nil {
141+
b.Fatalf("Failed to query 10000rows: %v", err)
142+
}
143+
144+
var id int64
145+
for rows.Next() {
146+
err = rows.Scan(&id)
147+
if err != nil {
148+
rows.Close()
149+
b.Fatalf("Failed to scan row: %v", err)
150+
}
151+
}
152+
rows.Close()
153+
}
154+
b.StopTimer()
155+
}
156+
116157
func BenchmarkExec(b *testing.B) {
117158
tb := (*TB)(b)
118159
b.StopTimer()

compress_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []by
4040
conn := new(mockConn)
4141
conn.data = compressedPacket
4242
mc.netConn = conn
43+
mc.readNextFunc = mc.compIO.readNext
44+
mc.readFunc = conn.Read
4345

4446
uncompressedPacket, err := mc.readPacket()
4547
if err != nil {

connection.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type mysqlConn struct {
4040
compressSequence uint8
4141
parseTime bool
4242
compress bool
43+
readFunc func([]byte) (int, error)
44+
readNextFunc func(int, readerFunc) ([]byte, error)
4345

4446
// for context support (Go 1.8+)
4547
watching bool
@@ -65,16 +67,6 @@ func (mc *mysqlConn) log(v ...any) {
6567
mc.cfg.Logger.Print(v...)
6668
}
6769

68-
func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) {
69-
to := mc.cfg.ReadTimeout
70-
if to > 0 {
71-
if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil {
72-
return 0, err
73-
}
74-
}
75-
return mc.netConn.Read(b)
76-
}
77-
7870
func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) {
7971
to := mc.cfg.WriteTimeout
8072
if to > 0 {

connection_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@ import (
1818
)
1919

2020
func TestInterpolateParams(t *testing.T) {
21+
buf := newBuffer()
22+
nc := &net.TCPConn{}
2123
mc := &mysqlConn{
22-
buf: newBuffer(),
24+
buf: buf,
25+
netConn: nc,
2326
maxAllowedPacket: maxPacketSize,
2427
cfg: &Config{
2528
InterpolateParams: true,
2629
},
30+
readNextFunc: buf.readNext,
31+
readFunc: nc.Read,
2732
}
2833

2934
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})

connector.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"os"
1717
"strconv"
1818
"strings"
19+
"time"
1920
)
2021

2122
type connector struct {
@@ -130,6 +131,22 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
130131

131132
mc.buf = newBuffer()
132133

134+
// setting readNext/read functions
135+
mc.readNextFunc = mc.buf.readNext
136+
137+
// Initialize read function based on configuration
138+
if mc.cfg.ReadTimeout > 0 {
139+
mc.readFunc = func(b []byte) (int, error) {
140+
deadline := time.Now().Add(mc.cfg.ReadTimeout)
141+
if err := mc.netConn.SetReadDeadline(deadline); err != nil {
142+
return 0, err
143+
}
144+
return mc.netConn.Read(b)
145+
}
146+
} else {
147+
mc.readFunc = mc.netConn.Read
148+
}
149+
133150
// Reading Handshake Initialization Packet
134151
authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket()
135152
if err != nil {
@@ -170,6 +187,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
170187
if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 {
171188
mc.compress = true
172189
mc.compIO = newCompIO(mc)
190+
mc.readNextFunc = mc.compIO.readNext
173191
}
174192
if mc.cfg.MaxAllowedPacket > 0 {
175193
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket

packets.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
3030
var prevData []byte
3131
invalidSequence := false
3232

33-
readNext := mc.buf.readNext
34-
if mc.compress {
35-
readNext = mc.compIO.readNext
36-
}
37-
3833
for {
3934
// read packet header
40-
data, err := readNext(4, mc.readWithTimeout)
35+
data, err := mc.readNextFunc(4, mc.readFunc)
4136
if err != nil {
4237
mc.close()
4338
if cerr := mc.canceled.Value(); cerr != nil {
@@ -85,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
8580
}
8681

8782
// read packet body [pktLen bytes]
88-
data, err = readNext(pktLen, mc.readWithTimeout)
83+
data, err = mc.readNextFunc(pktLen, mc.readFunc)
8984
if err != nil {
9085
mc.close()
9186
if cerr := mc.canceled.Value(); cerr != nil {
@@ -390,6 +385,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, serverCapabil
390385
return err
391386
}
392387
mc.netConn = tlsConn
388+
mc.readFunc = mc.netConn.Read
393389
}
394390

395391
// User [null terminated string]

packets_test.go

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,24 +97,30 @@ var _ net.Conn = new(mockConn)
9797
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
9898
conn := new(mockConn)
9999
connector := newConnector(NewConfig())
100+
buf := newBuffer()
100101
mc := &mysqlConn{
101-
buf: newBuffer(),
102+
buf: buf,
102103
cfg: connector.cfg,
103104
connector: connector,
104105
netConn: conn,
105106
closech: make(chan struct{}),
106107
maxAllowedPacket: defaultMaxAllowedPacket,
107108
sequence: sequence,
109+
readNextFunc: buf.readNext,
110+
readFunc: conn.Read,
108111
}
109112
return conn, mc
110113
}
111114

112115
func TestReadPacketSingleByte(t *testing.T) {
113116
conn := new(mockConn)
117+
buf := newBuffer()
114118
mc := &mysqlConn{
115-
netConn: conn,
116-
buf: newBuffer(),
117-
cfg: NewConfig(),
119+
netConn: conn,
120+
buf: buf,
121+
cfg: NewConfig(),
122+
readNextFunc: buf.readNext,
123+
readFunc: conn.Read,
118124
}
119125

120126
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
@@ -165,10 +171,13 @@ func TestReadPacketWrongSequenceID(t *testing.T) {
165171

166172
func TestReadPacketSplit(t *testing.T) {
167173
conn := new(mockConn)
174+
buf := newBuffer()
168175
mc := &mysqlConn{
169-
netConn: conn,
170-
buf: newBuffer(),
171-
cfg: NewConfig(),
176+
netConn: conn,
177+
buf: buf,
178+
cfg: NewConfig(),
179+
readNextFunc: buf.readNext,
180+
readFunc: conn.Read,
172181
}
173182

174183
data := make([]byte, maxPacketSize*2+4*3)
@@ -272,11 +281,14 @@ func TestReadPacketSplit(t *testing.T) {
272281

273282
func TestReadPacketFail(t *testing.T) {
274283
conn := new(mockConn)
284+
buf := newBuffer()
275285
mc := &mysqlConn{
276-
netConn: conn,
277-
buf: newBuffer(),
278-
closech: make(chan struct{}),
279-
cfg: NewConfig(),
286+
netConn: conn,
287+
buf: buf,
288+
closech: make(chan struct{}),
289+
cfg: NewConfig(),
290+
readNextFunc: buf.readNext,
291+
readFunc: conn.Read,
280292
}
281293

282294
// illegal empty (stand-alone) packet
@@ -317,12 +329,15 @@ func TestReadPacketFail(t *testing.T) {
317329
// not-NUL terminated plugin_name in init packet
318330
func TestRegression801(t *testing.T) {
319331
conn := new(mockConn)
332+
buf := newBuffer()
320333
mc := &mysqlConn{
321-
netConn: conn,
322-
buf: newBuffer(),
323-
cfg: new(Config),
324-
sequence: 42,
325-
closech: make(chan struct{}),
334+
netConn: conn,
335+
buf: buf,
336+
cfg: new(Config),
337+
sequence: 42,
338+
closech: make(chan struct{}),
339+
readNextFunc: buf.readNext,
340+
readFunc: conn.Read,
326341
}
327342

328343
conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,

0 commit comments

Comments
 (0)