Skip to content

Commit b6aa471

Browse files
Michaelhuebnerr
andauthored
Rollback severity seems to end the connection (#64)
* Rollback seems to end the connection Check for ROLLBACK severity and guard against it in Close Fixes #61 Fixes #48 * ROLLBACK kills connection too Ping would try to receive a message during the session reset which won't ever happen. Co-authored-by: huebnerr <45869321+huebnerr@users.noreply.github.com>
1 parent 2fff12d commit b6aa471

File tree

5 files changed

+75
-32
lines changed

5 files changed

+75
-32
lines changed

connection.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ type connection struct {
7171
scratch [512]byte
7272
sessionID string
7373
serverTZOffset string
74+
dead bool // used if a ROLLBACK severity error is encountered
7475
sessMutex sync.Mutex
7576
}
7677

@@ -153,6 +154,9 @@ func (v *connection) Ping(ctx context.Context) error {
153154
// ResetSession implements the SessionResetter interface for connection. This allows the sql
154155
// package to evaluate the connection state when managing the connection pool.
155156
func (v *connection) ResetSession(ctx context.Context) error {
157+
if v.dead {
158+
return driver.ErrBadConn
159+
}
156160
return v.Ping(ctx)
157161
}
158162

driver_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,30 @@ func TestConnectionClosure(t *testing.T) {
801801
rows.Close()
802802
}
803803

804+
func TestConcurrentStatementQuery(t *testing.T) {
805+
connDB := openConnection(t, "test_stmt_ordering_threads_pre")
806+
defer closeConnection(t, connDB, "test_stmt_ordering_threads_post")
807+
stmt, err := connDB.PrepareContext(ctx, "SELECT a FROM stmt_thread_test")
808+
assertNoErr(t, err)
809+
wg := new(sync.WaitGroup)
810+
wg.Add(3)
811+
for i := 0; i < 3; i++ {
812+
go func() {
813+
defer wg.Done()
814+
_, err := stmt.QueryContext(ctx)
815+
assertNoErr(t, err)
816+
}()
817+
}
818+
wg.Wait()
819+
}
820+
821+
func TestInvalidDDLStatement(t *testing.T) {
822+
connDB := openConnection(t)
823+
defer closeConnection(t, connDB)
824+
_, err := connDB.Exec("DROP VIEW DOESNOTEXISTVIEW")
825+
assertErr(t, err, "does not exist")
826+
}
827+
804828
var verticaUserName = flag.String("user", "dbadmin", "the user name to connect to Vertica")
805829
var verticaPassword = flag.String("password", os.Getenv("VERTICA_TEST_PASSWORD"), "Vertica password for this user")
806830
var verticaHostPort = flag.String("locator", "localhost:5433", "Vertica's host and port")

msgs/beerrormsg.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ type BEErrorMsg struct {
4646
Routine string
4747
}
4848

49-
// InitFromMsgBody docs
49+
// CreateFromMsgBody docs
5050
func (b *BEErrorMsg) CreateFromMsgBody(buf *msgBuffer) (BackEndMsg, error) {
5151

5252
res := &BEErrorMsg{}

msgs/msg.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,30 @@ import (
3636
"fmt"
3737
)
3838

39+
// CmdTargetType describes the target of a command
3940
type CmdTargetType byte
4041

41-
var (
42+
// Possible command targets
43+
const (
4244
CmdTargetTypePortal CmdTargetType = 'P'
4345
CmdTargetTypeStatement CmdTargetType = 'S'
4446
)
4547

46-
// FrontEndMsg docs
48+
// FrontEndMsg is sent from the adapter to the database
4749
type FrontEndMsg interface {
4850
Flatten() ([]byte, byte)
4951
String() string
5052
}
5153

52-
// BackEndMsg docs
54+
// BackEndMsg is received from the database
5355
type BackEndMsg interface {
5456
CreateFromMsgBody(*msgBuffer) (BackEndMsg, error)
5557
String() string
5658
}
5759

60+
// backEndMsgTypeMap is a global map of message descriptor bytes to instances
61+
// of that message. The instances are not used directly, but instead are used to
62+
// construct new values of that message type. This is populated on init.
5863
var backEndMsgTypeMap = make(map[byte]BackEndMsg)
5964

6065
func registerBackEndMsgType(msgType byte, bem BackEndMsg) {

stmt.go

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ type stmt struct {
7272
posArgCnt int
7373
paramTypes []common.ParameterType
7474
lastRowDesc *msgs.BERowDescMsg
75+
// set if Vertica issues an error of ROLLBACK severity
76+
rolledBack bool
7577
}
7678

7779
func newStmt(connection *connection, command string) (*stmt, error) {
@@ -99,40 +101,41 @@ func (s *stmt) pushNamed(name string) {
99101

100102
// Close closes this statement.
101103
func (s *stmt) Close() error {
102-
if s.parseState == parseStateParsed {
103-
closeMsg := &msgs.FECloseMsg{TargetType: msgs.CmdTargetTypeStatement, TargetName: s.preparedName}
104+
if s.parseState != parseStateParsed {
105+
return nil
106+
}
107+
if s.rolledBack {
108+
s.parseState = parseStateUnparsed
109+
s.conn.dead = true
110+
return nil
111+
}
112+
closeMsg := &msgs.FECloseMsg{TargetType: msgs.CmdTargetTypeStatement, TargetName: s.preparedName}
104113

105-
//s.conn.lockSessionMutex()
106-
//defer s.conn.unlockSessionMutex()
114+
if err := s.conn.sendMessage(closeMsg); err != nil {
115+
return err
116+
}
107117

108-
if err := s.conn.sendMessage(closeMsg); err != nil {
109-
return err
110-
}
118+
if err := s.conn.sendMessage(&msgs.FEFlushMsg{}); err != nil {
119+
return err
120+
}
121+
122+
for {
123+
bMsg, err := s.conn.recvMessage()
111124

112-
if err := s.conn.sendMessage(&msgs.FEFlushMsg{}); err != nil {
125+
if err != nil {
113126
return err
114127
}
115128

116-
for {
117-
bMsg, err := s.conn.recvMessage()
118-
119-
if err != nil {
120-
return err
121-
}
122-
123-
switch bMsg.(type) {
124-
case *msgs.BECloseCompleteMsg:
125-
s.parseState = parseStateUnparsed
126-
return nil
127-
case *msgs.BECmdDescriptionMsg:
128-
continue
129-
default:
130-
s.conn.defaultMessageHandler(bMsg)
131-
}
129+
switch bMsg.(type) {
130+
case *msgs.BECloseCompleteMsg:
131+
s.parseState = parseStateUnparsed
132+
return nil
133+
case *msgs.BECmdDescriptionMsg:
134+
continue
135+
default:
136+
s.conn.defaultMessageHandler(bMsg)
132137
}
133138
}
134-
135-
return nil
136139
}
137140

138141
// NumInput is used by database/sql to sanity check the number of arguments given
@@ -280,7 +283,7 @@ func (s *stmt) QueryContextRaw(ctx context.Context, baseArgs []driver.NamedValue
280283
case *msgs.BECmdCompleteMsg:
281284
break
282285
case *msgs.BEErrorMsg:
283-
return newEmptyRows(), msg.ToErrorType()
286+
return newEmptyRows(), s.evaluateErrorMsg(msg)
284287
case *msgs.BEEmptyQueryResponseMsg:
285288
return newEmptyRows(), nil
286289
case *msgs.BEReadyForQueryMsg, *msgs.BEPortalSuspendedMsg:
@@ -388,6 +391,13 @@ func (s *stmt) interpolate(args []driver.NamedValue) (string, error) {
388391
return result, nil
389392
}
390393

394+
func (s *stmt) evaluateErrorMsg(msg *msgs.BEErrorMsg) error {
395+
if msg.Severity == "ROLLBACK" {
396+
s.rolledBack = true
397+
}
398+
return msg.ToErrorType()
399+
}
400+
391401
func (s *stmt) prepareAndDescribe() error {
392402

393403
parseMsg := &msgs.FEParseMsg{
@@ -493,7 +503,7 @@ func (s *stmt) collectResults(ctx context.Context) (*rows, error) {
493503
s.lastRowDesc = msg
494504
rows = newRows(ctx, s.lastRowDesc, s.conn.serverTZOffset)
495505
case *msgs.BEErrorMsg:
496-
return newEmptyRows(), msg.ToErrorType()
506+
return newEmptyRows(), s.evaluateErrorMsg(msg)
497507
case *msgs.BEEmptyQueryResponseMsg:
498508
return newEmptyRows(), nil
499509
case *msgs.BEBindCompleteMsg, *msgs.BECmdDescriptionMsg:

0 commit comments

Comments
 (0)