Skip to content

Commit dc2aed2

Browse files
huebnerrSiting Ren
authored andcommitted
Fix #17 - panic() after reusing prepared stmt (#18)
1 parent 81454c5 commit dc2aed2

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

driver.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ type Driver struct{}
4747

4848
const (
4949
driverName string = "vertica-sql-go"
50-
driverVersion string = "0.1.2"
50+
driverVersion string = "0.1.3"
5151
protocolVersion uint32 = 0x00030008
5252
)
5353

driver_test.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func assertErr(t *testing.T, err error, errorSubstring string) {
122122
return
123123
}
124124

125-
t.Fatalf("expected an error, but it was '%s' instead of containing '%s'", errStr, errorSubstring)
125+
t.Fatalf("expected an error containing '%s', but found '%s'", errorSubstring, errStr)
126126
}
127127

128128
func assertNext(t *testing.T, rows *sql.Rows) {
@@ -477,6 +477,38 @@ func TestValueTypes(t *testing.T) {
477477
assertNoErr(t, rows.Close())
478478
}
479479

480+
func TestStmtReuseBug(t *testing.T) {
481+
connDB := openConnection(t)
482+
defer closeConnection(t, connDB)
483+
484+
var res bool
485+
486+
stmt, err := connDB.PrepareContext(ctx, "SELECT true AS res")
487+
assertNoErr(t, err)
488+
489+
// first call
490+
rows, err := stmt.QueryContext(ctx)
491+
assertNoErr(t, err)
492+
493+
defer rows.Close()
494+
495+
assertNext(t, rows)
496+
assertNoErr(t, rows.Scan(&res))
497+
assertEqual(t, res, true)
498+
assertNoNext(t, rows)
499+
500+
// second call
501+
rows, err = stmt.QueryContext(ctx)
502+
assertNoErr(t, err)
503+
504+
defer rows.Close()
505+
506+
assertNext(t, rows)
507+
assertNoErr(t, rows.Scan(&res))
508+
assertEqual(t, res, true)
509+
assertNoNext(t, rows)
510+
}
511+
480512
func init() {
481513
logger.SetLogLevel(logger.INFO)
482514

stmt.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ type stmt struct {
6767
preparedName string
6868
parseState parseState
6969
paramTypes []common.ParameterType
70+
lastRowDesc *msgs.BERowDescMsg
7071
}
7172

7273
func newStmt(connection *connection, command string) (*stmt, error) {
@@ -176,7 +177,7 @@ func (s *stmt) QueryContextRaw(ctx context.Context, args []driver.NamedValue) (*
176177
return s.collectResults()
177178
}
178179

179-
// We aren't a prepared statement, manually interpolate and do as a simpe query.
180+
// We aren't a prepared statement, manually interpolate and do as a simple query.
180181
cmd, err = s.interpolate(args)
181182

182183
if err != nil {
@@ -196,9 +197,12 @@ func (s *stmt) QueryContextRaw(ctx context.Context, args []driver.NamedValue) (*
196197

197198
switch msg := bMsg.(type) {
198199
case *msgs.BEDataRowMsg:
200+
if rows == emptyRowSet {
201+
rows = newRows(s.lastRowDesc, s.conn.serverTZOffset)
202+
}
199203
rows.addRow(msg)
200204
case *msgs.BERowDescMsg:
201-
rows = newRows(msg, s.conn.serverTZOffset)
205+
s.lastRowDesc = msg
202206
case *msgs.BECmdCompleteMsg:
203207
break
204208
case *msgs.BEErrorMsg:
@@ -341,9 +345,12 @@ func (s *stmt) collectResults() (*rows, error) {
341345

342346
switch msg := bMsg.(type) {
343347
case *msgs.BEDataRowMsg:
348+
if rows == emptyRowSet {
349+
rows = newRows(s.lastRowDesc, s.conn.serverTZOffset)
350+
}
344351
rows.addRow(msg)
345352
case *msgs.BERowDescMsg:
346-
rows = newRows(msg, s.conn.serverTZOffset)
353+
s.lastRowDesc = msg
347354
case *msgs.BEErrorMsg:
348355
return emptyRowSet, msg.ToErrorType()
349356
case *msgs.BEEmptyQueryResponseMsg:

0 commit comments

Comments
 (0)