Skip to content

Commit 73d4e73

Browse files
committed
basic pipelining PREPARE + EXECUTE when possible
1 parent 9908c10 commit 73d4e73

File tree

3 files changed

+106
-18
lines changed

3 files changed

+106
-18
lines changed

connection.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,17 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
208208
if mc.closed.Load() {
209209
return nil, driver.ErrBadConn
210210
}
211+
if mc.clientExtCapabilities&clientStmtBulkOperations != 0 {
212+
// can delay PREPARE
213+
stmt := &mysqlStmt{
214+
mc: mc,
215+
id: 0xffffffff,
216+
paramCount: -1,
217+
initialQuery: query,
218+
}
219+
return stmt, nil
220+
}
221+
211222
// Send command
212223
err := mc.writeCommandPacketStr(comStmtPrepare, query)
213224
if err != nil {
@@ -221,28 +232,32 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
221232
}
222233

223234
// Read Result
235+
err = mc.readPrepareResult(stmt)
236+
return stmt, err
237+
}
238+
239+
func (mc *mysqlConn) readPrepareResult(stmt *mysqlStmt) error {
224240
columnCount, err := stmt.readPrepareResultPacket()
225241
if err == nil {
226242
if stmt.paramCount > 0 {
227243
if err = mc.skipColumns(stmt.paramCount); err != nil {
228-
return nil, err
244+
return err
229245
}
230246
}
231247

232248
if columnCount > 0 {
233249
if mc.clientExtCapabilities&clientCacheMetadata != 0 {
234250
if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil {
235-
return nil, err
251+
return err
236252
}
237253
} else {
238254
if err = mc.skipColumns(int(columnCount)); err != nil {
239-
return nil, err
255+
return err
240256
}
241257
}
242258
}
243259
}
244-
245-
return stmt, err
260+
return err
246261
}
247262

248263
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
@@ -256,7 +271,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
256271
// can not take the buffer. Something must be wrong with the connection
257272
mc.cleanup()
258273
// interpolateParams would be called before sending any query.
259-
// So its safe to retry.
274+
// So it's safe to retry.
260275
return "", driver.ErrBadConn
261276
}
262277
buf = buf[:0]

packets.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ func (mc *mysqlConn) initClientCapabilities(serverCapabilities capabilityFlag, c
310310
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string) error {
311311
// Adjust client flags based on server support
312312
mc.clientCapabilities = mc.initClientCapabilities(serverCapabilities, mc.cfg)
313-
mc.clientExtCapabilities = clientCacheMetadata & serverExtendedCapabilities
313+
mc.clientExtCapabilities = (clientCacheMetadata | clientStmtBulkOperations) & serverExtendedCapabilities
314314

315315
sendConnectAttrs := mc.clientCapabilities&clientConnectAttrs != 0
316316

@@ -1016,7 +1016,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
10161016
// Execute Prepared Statement
10171017
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
10181018
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
1019-
if len(args) != stmt.paramCount {
1019+
if stmt.paramCount != -1 && len(args) != stmt.paramCount {
10201020
return fmt.Errorf(
10211021
"argument count mismatch (got: %d; has: %d)",
10221022
len(args),
@@ -1028,7 +1028,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10281028
mc := stmt.mc
10291029

10301030
// Determine threshold dynamically to avoid packet size shortage.
1031-
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
1031+
longDataSize := mc.maxAllowedPacket / (len(args) + 1)
10321032
if longDataSize < 64 {
10331033
longDataSize = 64
10341034
}
@@ -1066,7 +1066,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10661066

10671067
var nullMask []byte
10681068
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
1069-
// buffer has to be extended but we don't know by how much so
1069+
// buffer has to be extended, but we don't know by how much so
10701070
// we depend on append after all data with known sizes fit.
10711071
// We stop at that because we deal with a lot of columns here
10721072
// which makes the required allocation size hard to guess.
@@ -1118,7 +1118,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11181118
case uint64:
11191119
paramTypes[i+i] = byte(fieldTypeLongLong)
11201120
paramTypes[i+i+1] = 0x80 // type is unsigned
1121-
paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
1121+
paramValues = binary.LittleEndian.AppendUint64(paramValues, v)
11221122

11231123
case float64:
11241124
paramTypes[i+i] = byte(fieldTypeDouble)

statement.go

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ import (
1717
)
1818

1919
type mysqlStmt struct {
20-
mc *mysqlConn
21-
id uint32
22-
paramCount int
23-
columns []mysqlField
20+
mc *mysqlConn
21+
id uint32
22+
paramCount int
23+
columns []mysqlField
24+
initialQuery string
2425
}
2526

2627
func (stmt *mysqlStmt) Close() error {
@@ -52,16 +53,53 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
5253
}
5354

5455
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
55-
if stmt.mc.closed.Load() {
56+
mc := stmt.mc
57+
if mc.closed.Load() {
5658
return nil, driver.ErrBadConn
5759
}
60+
61+
var prepareSequence uint8
62+
var prepareCompressSequence uint8
63+
64+
if mc.clientExtCapabilities&clientStmtBulkOperations != 0 {
65+
// Send command
66+
err := mc.writeCommandPacketStr(comStmtPrepare, stmt.initialQuery)
67+
if err != nil {
68+
mc.log(err)
69+
return nil, driver.ErrBadConn
70+
}
71+
prepareSequence = mc.sequence
72+
prepareCompressSequence = mc.compressSequence
73+
}
74+
5875
// Send command
5976
err := stmt.writeExecutePacket(args)
6077
if err != nil {
6178
return nil, stmt.mc.markBadConn(err)
6279
}
6380

64-
mc := stmt.mc
81+
if stmt.mc.clientExtCapabilities&clientStmtBulkOperations != 0 {
82+
// Read Prepare Result
83+
var executeSequence uint8
84+
var executeCompressSequence uint8
85+
executeSequence = mc.sequence
86+
executeCompressSequence = mc.compressSequence
87+
88+
mc.sequence = prepareSequence
89+
mc.compressSequence = prepareCompressSequence
90+
91+
err = mc.readPrepareResult(stmt)
92+
93+
mc.sequence = executeSequence
94+
mc.compressSequence = executeCompressSequence
95+
if err != nil {
96+
// skip executeResult (will return an error)
97+
handleOk := stmt.mc.clearResult()
98+
_, _, _ = handleOk.readResultSetHeaderPacket()
99+
return nil, err
100+
}
101+
}
102+
65103
handleOk := stmt.mc.clearResult()
66104

67105
// Read Result
@@ -98,13 +136,48 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
98136
if stmt.mc.closed.Load() {
99137
return nil, driver.ErrBadConn
100138
}
139+
140+
mc := stmt.mc
141+
var prepareSequence uint8
142+
var prepareCompressSequence uint8
143+
144+
if mc.clientExtCapabilities&clientStmtBulkOperations != 0 {
145+
// Send command
146+
err := mc.writeCommandPacketStr(comStmtPrepare, stmt.initialQuery)
147+
if err != nil {
148+
mc.log(err)
149+
return nil, driver.ErrBadConn
150+
}
151+
prepareSequence = mc.sequence
152+
prepareCompressSequence = mc.compressSequence
153+
}
101154
// Send command
102155
err := stmt.writeExecutePacket(args)
103156
if err != nil {
104157
return nil, stmt.mc.markBadConn(err)
105158
}
106159

107-
mc := stmt.mc
160+
if stmt.mc.clientExtCapabilities&clientStmtBulkOperations != 0 {
161+
// Read Prepare Result
162+
var executeSequence uint8
163+
var executeCompressSequence uint8
164+
executeSequence = mc.sequence
165+
executeCompressSequence = mc.compressSequence
166+
167+
mc.sequence = prepareSequence
168+
mc.compressSequence = prepareCompressSequence
169+
170+
err = mc.readPrepareResult(stmt)
171+
172+
mc.sequence = executeSequence
173+
mc.compressSequence = executeCompressSequence
174+
if err != nil {
175+
// skip executeResult (will return an error)
176+
handleOk := stmt.mc.clearResult()
177+
_, _, _ = handleOk.readResultSetHeaderPacket()
178+
return nil, err
179+
}
180+
}
108181

109182
// Read Result
110183
handleOk := stmt.mc.clearResult()

0 commit comments

Comments
 (0)