From afd54e6aba85574c72ecbb5e583b47b6a794898c Mon Sep 17 00:00:00 2001 From: Yuya Tanaka Date: Thu, 1 Apr 2021 03:19:20 +0900 Subject: [PATCH] Fix NamedValueChecker implemented on conn not used --- connection.go | 2 +- statement.go | 6 +++++- statement_test.go | 20 ++++++++++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/connection.go b/connection.go index e55432b..3bb07e9 100644 --- a/connection.go +++ b/connection.go @@ -261,7 +261,7 @@ func (c *connection) statement(stmt driver.Stmt, err error, id, query string) (d return stmt, err } - return &statement{Stmt: stmt, query: query, logger: c.logger, connID: c.id, id: id}, nil + return &statement{Stmt: stmt, query: query, logger: c.logger, connID: c.id, conn: c.Conn, id: id}, nil } func (c *connection) rows(res driver.Rows, err error, query string, args []driver.Value) (driver.Rows, error) { diff --git a/statement.go b/statement.go index e576c25..f57a628 100644 --- a/statement.go +++ b/statement.go @@ -17,6 +17,7 @@ type statement struct { query string logger *logger id string + conn driver.Conn connID string } @@ -115,7 +116,10 @@ func (s *statement) QueryContext(ctx context.Context, args []driver.NamedValue) func (s *statement) CheckNamedValue(nm *driver.NamedValue) error { checker, ok := s.Stmt.(driver.NamedValueChecker) if !ok { - return driver.ErrSkip + checker, ok = s.conn.(driver.NamedValueChecker) + if !ok { + return driver.ErrSkip + } } lvl, start := LevelTrace, time.Now() diff --git a/statement_test.go b/statement_test.go index ebcb89b..9402547 100644 --- a/statement_test.go +++ b/statement_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql/driver" "encoding/json" + "errors" "testing" "github.com/stretchr/testify/assert" @@ -376,6 +377,17 @@ func TestStatement_CheckNamedValue(t *testing.T) { assert.NotEmpty(t, stmtOutput.Data[testOpts.connIDFieldname]) }) + t.Run("Implement driver.NamedValueChecker on conn but stmt", func(t *testing.T) { + stmtMock := &statementMock{} + connMock := &driverConnNamedValueCheckerMock{} + mockErr := errors.New("mock") + connMock.On("CheckNamedValue", mock.Anything).Return(mockErr) + + stmt := &statement{Stmt: stmtMock, logger: testLogger, id: testLogger.opt.uidGenerator.UniqueID(), connID: testLogger.opt.uidGenerator.UniqueID(), conn: connMock} + err := stmt.CheckNamedValue(&driver.NamedValue{Name: "", Ordinal: 0, Value: "testid"}) + assert.Equal(t, mockErr, err) + }) + t.Run("Not implement driver.NamedValueChecker", func(t *testing.T) { stmtMock := &statementMock{} @@ -458,6 +470,14 @@ func (m *statementNamedValueCheckerMock) CheckNamedValue(nm *driver.NamedValue) return m.Called().Error(0) } +type driverConnNamedValueCheckerMock struct { + driverConnMock +} + +func (c *driverConnNamedValueCheckerMock) CheckNamedValue(nm *driver.NamedValue) error { + return c.Called().Error(0) +} + type statementValueConverterMock struct { statementMock }