Skip to content
This repository was archived by the owner on Jan 5, 2023. It is now read-only.

Commit 5501ab9

Browse files
committed
update to radon upstream
1 parent 54730e6 commit 5501ab9

35 files changed

+11569
-5429
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ go:
44
- 1.x
55

66
before_install:
7+
- go get github.com/shopspring/decimal
78
- go get github.com/pierrre/gotestcover
89
- go get github.com/stretchr/testify/assert
910

driver/mock.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ func (th *TestHandler) ServerVersion() string {
170170
return "FakeDB"
171171
}
172172

173+
// SetServerVersion implements the interface.
174+
func (th *TestHandler) SetServerVersion() {
175+
return
176+
}
177+
173178
// NewSession implements the interface.
174179
func (th *TestHandler) NewSession(s *Session) {
175180
th.mu.Lock()

driver/server.go

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ import (
1919

2020
"github.com/xelabs/go-mysqlstack/proto"
2121
"github.com/xelabs/go-mysqlstack/sqldb"
22-
"github.com/xelabs/go-mysqlstack/xlog"
23-
2422
"github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
2523
querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
2624
"github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
25+
"github.com/xelabs/go-mysqlstack/xlog"
2726
)
2827

2928
// Handler interface.
3029
type Handler interface {
3130
ServerVersion() string
31+
SetServerVersion()
3232
NewSession(session *Session)
3333
SessionInc(session *Session)
3434
SessionDec(session *Session)
@@ -54,8 +54,6 @@ type Listener struct {
5454

5555
// Incrementing ID for connection id.
5656
connectionID uint32
57-
58-
serverVersion string
5957
}
6058

6159
// NewListener creates a new Listener.
@@ -66,12 +64,11 @@ func NewListener(log *xlog.Log, address string, handler Handler) (*Listener, err
6664
}
6765

6866
return &Listener{
69-
log: log,
70-
address: address,
71-
handler: handler,
72-
listener: listener,
73-
connectionID: 1,
74-
serverVersion: handler.ServerVersion(),
67+
log: log,
68+
address: address,
69+
handler: handler,
70+
listener: listener,
71+
connectionID: 1,
7572
}, nil
7673
}
7774

@@ -86,7 +83,7 @@ func (l *Listener) Accept() {
8683
}
8784
ID := l.connectionID
8885
l.connectionID++
89-
go l.handle(conn, ID, l.serverVersion)
86+
go l.handle(conn, ID)
9087
}
9188
}
9289

@@ -123,16 +120,21 @@ func (l *Listener) parserComStatementExecute(data []byte, session *Session) (*St
123120
if err != nil {
124121
return nil, err
125122
}
126-
protoStmt, err := proto.UnPackStatementExecute(data[1:], stmt.ParamCount, sqltypes.ParseMySQLValues)
127-
if err != nil {
123+
124+
protoStmt := &proto.Statement{
125+
ID: stmt.ID,
126+
ParamCount: stmt.ParamCount,
127+
ParamsType: stmt.ParamsType,
128+
BindVars: stmt.BindVars,
129+
}
130+
if err = proto.UnPackStatementExecute(data[1:], protoStmt, sqltypes.ParseMySQLValues); err != nil {
128131
return nil, err
129132
}
130-
stmt.BindVars = protoStmt.BindVars
131133
return stmt, nil
132134
}
133135

134136
// handle is called in a go routine for each client connection.
135-
func (l *Listener) handle(conn net.Conn, ID uint32, serverVersion string) {
137+
func (l *Listener) handle(conn net.Conn, ID uint32) {
136138
var err error
137139
var data []byte
138140
var authPkt []byte
@@ -146,7 +148,11 @@ func (l *Listener) handle(conn net.Conn, ID uint32, serverVersion string) {
146148
log.Error("server.handle.panic:\n%v\n%s", x, debug.Stack())
147149
}
148150
}()
149-
session := newSession(log, ID, l.serverVersion, conn)
151+
152+
// set server version if backend MySQL version is different.
153+
l.handler.SetServerVersion()
154+
155+
session := newSession(log, ID, l.handler.ServerVersion(), conn)
150156
// Session check.
151157
if err = l.handler.SessionCheck(session); err != nil {
152158
log.Warning("session[%v].check.failed.error:%+v", ID, err)
@@ -252,6 +258,7 @@ func (l *Listener) handle(conn net.Conn, ID uint32, serverVersion string) {
252258
ID: id,
253259
PrepareStmt: query,
254260
ParamCount: paramCount,
261+
ParamsType: make([]int32, paramCount),
255262
BindVars: make(map[string]*querypb.BindVariable, paramCount),
256263
}
257264
for i := uint16(0); i < paramCount; i++ {
@@ -274,6 +281,7 @@ func (l *Listener) handle(conn net.Conn, ID uint32, serverVersion string) {
274281
return
275282
}
276283
}
284+
277285
if err = l.handler.ComQuery(session, stmt.PrepareStmt, sqltypes.CopyBindVariables(stmt.BindVars), func(qr *sqltypes.Result) error {
278286
return session.writeBinaryRows(qr)
279287
}); err != nil {

driver/statement.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ package driver
1212
import (
1313
"github.com/xelabs/go-mysqlstack/proto"
1414
"github.com/xelabs/go-mysqlstack/sqldb"
15-
1615
querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
1716
"github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
1817
)
@@ -23,6 +22,7 @@ type Statement struct {
2322
ID uint32
2423
ParamCount uint16
2524
PrepareStmt string
25+
ParamsType []int32
2626
ColumnNames []string
2727
BindVars map[string]*querypb.BindVariable
2828
}

makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ fmt:
55
go vet ./...
66

77
test:
8+
go get github.com/shopspring/decimal
89
go get github.com/stretchr/testify/assert
910
@echo "--> Testing..."
1011
@$(MAKE) testxlog

proto/auth.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ func (a *Auth) UnPack(payload []byte) error {
9797
return fmt.Errorf("auth.unpack: can't read authResponse")
9898
}
9999
} else {
100-
if a.authResponse, err = buf.ReadBytesNUL(); err != nil {
100+
if a.authResponse, err = buf.ReadBytes(20); err != nil {
101+
return fmt.Errorf("auth.unpack: can't read authResponse")
102+
}
103+
if err = buf.ReadZero(1); err != nil {
101104
return fmt.Errorf("auth.unpack: can't read authResponse")
102105
}
103106
}

proto/auth_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,19 +179,19 @@ func TestAuthWithoutSecure(t *testing.T) {
179179
want.authResponseLen = 20
180180
want.clientFlags = DefaultClientCapability &^ sqldb.CLIENT_SECURE_CONNECTION &^ sqldb.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
181181
want.clientFlags |= sqldb.CLIENT_CONNECT_WITH_DB
182-
want.authResponse = nativePassword("sbtest", DefaultSalt)
183-
want.user = "sbtest"
184-
want.database = "sbtest"
182+
want.authResponse = nativePassword("password", DefaultSalt)
183+
want.user = "root"
184+
want.database = "test_db"
185185
want.pluginName = DefaultAuthPluginName
186186

187187
got := NewAuth()
188188
err := got.UnPack(want.Pack(
189189
DefaultClientCapability&^sqldb.CLIENT_SECURE_CONNECTION,
190190
0x02,
191-
"sbtest",
192-
"sbtest",
191+
"root",
192+
"password",
193193
DefaultSalt,
194-
"sbtest",
194+
"test_db",
195195
))
196196
got.authResponseLen = 20
197197
assert.Nil(t, err)

proto/statement.go

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"fmt"
1414

1515
"github.com/xelabs/go-mysqlstack/sqldb"
16-
1716
"github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
1817
querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
1918
"github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
@@ -26,6 +25,7 @@ type Statement struct {
2625
ColumnCount uint16
2726
ParamCount uint16
2827
Warnings uint16
28+
ParamsType []int32
2929
ColumnNames []string
3030

3131
BindVars map[string]*querypb.BindVariable
@@ -156,87 +156,79 @@ func PackStatementExecute(stmtID uint32, parameters []sqltypes.Value) ([]byte, e
156156
}
157157

158158
// UnPackStatementExecute -- unpack the stmt-execute packet from client.
159-
func UnPackStatementExecute(data []byte, paramsCount uint16, parseValueFn func(*common.Buffer, querypb.Type) (interface{}, error)) (*Statement, error) {
159+
func UnPackStatementExecute(data []byte, prepare *Statement, parseValueFn func(*common.Buffer, querypb.Type) (interface{}, error)) error {
160160
var err error
161-
var paramsType []int32
162-
163-
stmt := &Statement{}
164161
bitMap := make([]byte, 0)
165162
buf := common.ReadBuffer(data)
166163

167-
// statement ID
168-
if stmt.ID, err = buf.ReadU32(); err != nil {
169-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading statement ID failed")
164+
if _, err = buf.ReadU32(); err != nil {
165+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading statement ID failed")
170166
}
171167

172168
// cursor type flags
173169
if _, err = buf.ReadU8(); err != nil {
174-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading cursor type flags failed")
170+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading cursor type flags failed")
175171
}
176172

177173
// iteration count
178174
var itercount uint32
179175
if itercount, err = buf.ReadU32(); err != nil {
180-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading iteration count failed")
176+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading iteration count failed")
181177
}
182178
if itercount != 1 {
183-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "iteration count is not equal to 1")
179+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "iteration count is not equal to 1")
184180
}
185181

186-
if paramsCount > 0 {
187-
// Init.
188-
paramsType = make([]int32, paramsCount)
189-
stmt.BindVars = make(map[string]*querypb.BindVariable)
190-
191-
if bitMap, err = buf.ReadBytes(int((paramsCount + 7) / 8)); err != nil {
192-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading NULL-bitmap failed")
182+
if prepare.ParamCount > 0 {
183+
if bitMap, err = buf.ReadBytes(int((prepare.ParamCount + 7) / 8)); err != nil {
184+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading NULL-bitmap failed")
193185
}
194186

195187
var newParamsBoundFlag byte
196188
if newParamsBoundFlag, err = buf.ReadU8(); err != nil {
197-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading NULL-bitmap failed")
189+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading NULL-bitmap failed")
198190
}
199191
if newParamsBoundFlag == 0x01 {
200192
var mysqlType, flags byte
201-
for i := uint16(0); i < paramsCount; i++ {
193+
for i := uint16(0); i < prepare.ParamCount; i++ {
202194
if mysqlType, err = buf.ReadU8(); err != nil {
203-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading parameter type failed")
195+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading parameter type failed")
204196
}
205197

206198
if flags, err = buf.ReadU8(); err != nil {
207-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading parameter flags failed")
199+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "reading parameter flags failed")
208200
}
209201
// Convert MySQL type to Vitess type.
210202
valType, err := sqltypes.MySQLToType(int64(mysqlType), int64(flags))
211203
if err != nil {
212-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, fmt.Sprintf("MySQLToType(%v,%v) failed: %v", mysqlType, flags, err))
204+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, fmt.Sprintf("MySQLToType(%v,%v) failed: %v", mysqlType, flags, err))
213205
}
214-
paramsType[i] = int32(valType)
206+
prepare.ParamsType[i] = int32(valType)
215207
}
216208
}
217209

218-
for i := uint16(0); i < paramsCount; i++ {
210+
for i := uint16(0); i < prepare.ParamCount; i++ {
219211
var val interface{}
220-
if paramsType[i] == int32(sqltypes.Text) || paramsType[i] == int32(sqltypes.Blob) {
212+
if prepare.ParamsType[i] == int32(sqltypes.Text) || prepare.ParamsType[i] == int32(sqltypes.Blob) {
221213
continue
222214
}
223215

224216
if (bitMap[i/8] & (1 << uint(i%8))) > 0 {
225217
val, err = parseValueFn(buf, sqltypes.Null)
226218
} else {
227-
val, err = parseValueFn(buf, querypb.Type(paramsType[i]))
219+
val, err = parseValueFn(buf, querypb.Type(prepare.ParamsType[i]))
228220
}
229221
if err != nil {
230-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, fmt.Sprintf("decoding parameter value failed(%v) failed: %v", paramsType[i], err))
222+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, fmt.Sprintf("decoding parameter value failed(%v) failed: %v", prepare.ParamsType[i], err))
231223
}
232224

233225
// If value is nil, must set bind variables to nil.
234226
bv, err := sqltypes.BuildBindVariable(val)
235227
if err != nil {
236-
return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, fmt.Sprintf("build converted parameters value failed: %v", err))
228+
return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, fmt.Sprintf("build converted parameters value failed: %v", err))
237229
}
238-
stmt.BindVars[fmt.Sprintf("v%d", i+1)] = bv
230+
prepare.BindVars[fmt.Sprintf("v%d", i+1)] = bv
239231
}
240232
}
241-
return stmt, nil
233+
return nil
242234
}

proto/statement_test.go

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,15 @@ func TestStatementExecute(t *testing.T) {
9393
parseFn := func(*common.Buffer, querypb.Type) (interface{}, error) {
9494
return nil, nil
9595
}
96-
got, err := UnPackStatementExecute(datas, 4, parseFn)
96+
97+
protoStmt := &Statement{
98+
ID: id,
99+
ParamCount: uint16(len(values)),
100+
ParamsType: make([]int32, len(values)),
101+
BindVars: make(map[string]*querypb.BindVariable, len(values)),
102+
}
103+
err = UnPackStatementExecute(datas, protoStmt, parseFn)
97104
assert.Nil(t, err)
98-
assert.NotNil(t, got)
99105
}
100106

101107
func TestStatementExecuteUnPackError(t *testing.T) {
@@ -140,8 +146,44 @@ func TestStatementExecuteUnPackError(t *testing.T) {
140146
buff := common.NewBuffer(32)
141147
fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6}
142148
for i := 0; i < len(fs); i++ {
143-
_, err := UnPackStatementExecute(buff.Datas(), 1, parseFn)
149+
150+
protoStmt := &Statement{
151+
ID: 1,
152+
ParamCount: 2,
153+
ParamsType: make([]int32, 2),
154+
BindVars: make(map[string]*querypb.BindVariable, 2),
155+
}
156+
157+
err := UnPackStatementExecute(buff.Datas(), protoStmt, parseFn)
144158
assert.NotNil(t, err)
145159
fs[i](buff)
146160
}
147161
}
162+
163+
// issue 462.
164+
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
165+
// test about new-params-bound-flag about 0 1
166+
func TestStatementExecuteBatchUnPackStatementExecute(t *testing.T) {
167+
data := []byte{ /*23,*/ 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 1, 1, 128, 1}
168+
data2 := []byte{ /*23,*/ 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 0, 1, 128, 1}
169+
170+
var dataBatch [][]byte
171+
dataBatch = append(dataBatch, data)
172+
dataBatch = append(dataBatch, data2)
173+
174+
parseFn := func(*common.Buffer, querypb.Type) (interface{}, error) {
175+
return nil, nil
176+
}
177+
178+
protoStmt := &Statement{
179+
ID: 23,
180+
ParamCount: 1,
181+
ParamsType: make([]int32, 1),
182+
BindVars: make(map[string]*querypb.BindVariable, 1),
183+
}
184+
err := UnPackStatementExecute(dataBatch[0], protoStmt, parseFn)
185+
assert.Nil(t, err)
186+
187+
err = UnPackStatementExecute(dataBatch[1], protoStmt, parseFn)
188+
assert.Nil(t, err)
189+
}

0 commit comments

Comments
 (0)