Skip to content

Commit 7ea01d1

Browse files
authored
Merge pull request #89 from shogo82148/revert-84-main
2 parents 300c00d + 918628b commit 7ea01d1

File tree

3 files changed

+56
-45
lines changed

3 files changed

+56
-45
lines changed

conn.go

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ type Conn struct {
1717
// It will trigger PrePing, Ping, PostPing hooks.
1818
//
1919
// If the original connection does not satisfy "database/sql/driver".Pinger, it does nothing.
20-
func (conn *Conn) Ping(c context.Context) (err error) {
20+
func (conn *Conn) Ping(c context.Context) error {
21+
var err error
2122
var ctx interface{}
2223
hooks := conn.Proxy.getHooks(c)
2324

2425
if hooks != nil {
25-
defer func() { err = hooks.postPing(c, ctx, conn, err) }()
26+
defer func() { hooks.postPing(c, ctx, conn, err) }()
2627
if ctx, err = hooks.prePing(c, conn); err != nil {
2728
return err
2829
}
@@ -48,30 +49,31 @@ func (conn *Conn) Prepare(query string) (driver.Stmt, error) {
4849
}
4950

5051
// PrepareContext returns a prepared statement which is wrapped by Stmt.
51-
func (conn *Conn) PrepareContext(c context.Context, query string) (stmt driver.Stmt, err error) {
52+
func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt, error) {
5253
var ctx interface{}
53-
var stmtAux = &Stmt{
54+
var stmt = &Stmt{
5455
QueryString: query,
5556
Proxy: conn.Proxy,
5657
Conn: conn,
5758
}
59+
var err error
5860
hooks := conn.Proxy.getHooks(c)
5961
if hooks != nil {
60-
defer func() { err = hooks.postPrepare(c, ctx, stmtAux, err) }()
61-
if ctx, err = hooks.prePrepare(c, stmtAux); err != nil {
62+
defer func() { hooks.postPrepare(c, ctx, stmt, err) }()
63+
if ctx, err = hooks.prePrepare(c, stmt); err != nil {
6264
return nil, err
6365
}
6466
}
6567

6668
if connCtx, ok := conn.Conn.(driver.ConnPrepareContext); ok {
67-
stmtAux.Stmt, err = connCtx.PrepareContext(c, stmtAux.QueryString)
69+
stmt.Stmt, err = connCtx.PrepareContext(c, stmt.QueryString)
6870
} else {
69-
stmtAux.Stmt, err = conn.Conn.Prepare(stmtAux.QueryString)
71+
stmt.Stmt, err = conn.Conn.Prepare(stmt.QueryString)
7072
if err == nil {
7173
select {
7274
default:
7375
case <-c.Done():
74-
stmtAux.Stmt.Close()
76+
stmt.Stmt.Close()
7577
return nil, c.Err()
7678
}
7779
}
@@ -81,20 +83,21 @@ func (conn *Conn) PrepareContext(c context.Context, query string) (stmt driver.S
8183
}
8284

8385
if hooks != nil {
84-
if err = hooks.prepare(c, ctx, stmtAux); err != nil {
86+
if err = hooks.prepare(c, ctx, stmt); err != nil {
8587
return nil, err
8688
}
8789
}
88-
return stmtAux, nil
90+
return stmt, nil
8991
}
9092

9193
// Close calls the original Close method.
92-
func (conn *Conn) Close() (err error) {
94+
func (conn *Conn) Close() error {
9395
ctx := context.Background()
96+
var err error
9497
var myctx interface{}
9598

9699
if hooks := conn.Proxy.hooks; hooks != nil {
97-
defer func() { err = hooks.postClose(ctx, myctx, conn, err) }()
100+
defer func() { hooks.postClose(ctx, myctx, conn, err) }()
98101
if myctx, err = hooks.preClose(ctx, conn); err != nil {
99102
return err
100103
}
@@ -120,12 +123,14 @@ func (conn *Conn) Begin() (driver.Tx, error) {
120123

121124
// BeginTx starts and returns a new transaction which is wrapped by Tx.
122125
// It will trigger PreBegin, Begin, PostBegin hooks.
123-
func (conn *Conn) BeginTx(c context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
126+
func (conn *Conn) BeginTx(c context.Context, opts driver.TxOptions) (driver.Tx, error) {
124127
// set the hooks.
128+
var err error
125129
var ctx interface{}
130+
var tx driver.Tx
126131
hooks := conn.Proxy.getHooks(c)
127132
if hooks != nil {
128-
defer func() { err = hooks.postBegin(c, ctx, conn, err) }()
133+
defer func() { hooks.postBegin(c, ctx, conn, err) }()
129134
if ctx, err = hooks.preBegin(c, conn); err != nil {
130135
return nil, err
131136
}
@@ -188,7 +193,7 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error)
188193
// It will trigger PreExec, Exec, PostExec hooks.
189194
//
190195
// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
191-
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (drv driver.Result, err error) {
196+
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
192197
execer, exOk := conn.Conn.(driver.Execer)
193198
execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext)
194199
if !exOk && !exCtxOk {
@@ -202,17 +207,19 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
202207
Conn: conn,
203208
}
204209
var ctx interface{}
210+
var err error
211+
var result driver.Result
205212
hooks := conn.Proxy.getHooks(c)
206213
if hooks != nil {
207-
defer func() { err = hooks.postExec(c, ctx, stmt, args, drv, err) }()
214+
defer func() { hooks.postExec(c, ctx, stmt, args, result, err) }()
208215
if ctx, err = hooks.preExec(c, stmt, args); err != nil {
209216
return nil, err
210217
}
211218
}
212219

213220
// call the original method.
214221
if execerCtx != nil {
215-
drv, err = execerCtx.ExecContext(c, stmt.QueryString, args)
222+
result, err = execerCtx.ExecContext(c, stmt.QueryString, args)
216223
} else {
217224
select {
218225
default:
@@ -223,18 +230,19 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
223230
if err0 != nil {
224231
return nil, err0
225232
}
226-
drv, err = execer.Exec(stmt.QueryString, dargs)
233+
result, err = execer.Exec(stmt.QueryString, dargs)
227234
}
228235
if err != nil {
229236
return nil, err
230237
}
231238

232239
if hooks != nil {
233-
if err = hooks.exec(c, ctx, stmt, args, drv); err != nil {
240+
if err = hooks.exec(c, ctx, stmt, args, result); err != nil {
234241
return nil, err
235242
}
236243
}
237-
return drv, err
244+
245+
return result, nil
238246
}
239247

240248
// Query executes a query that may return rows.
@@ -250,7 +258,7 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error)
250258
// It wil trigger PreQuery, Query, PostQuery hooks.
251259
//
252260
// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error.
253-
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
261+
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
254262
queryer, qok := conn.Conn.(driver.Queryer)
255263
queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext)
256264
if !qok && !qCtxOk {
@@ -263,9 +271,11 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
263271
Conn: conn,
264272
}
265273
var ctx interface{}
274+
var err error
275+
var rows driver.Rows
266276
hooks := conn.Proxy.getHooks(c)
267277
if hooks != nil {
268-
defer func() { err = hooks.postQuery(c, ctx, stmt, args, rows, err) }()
278+
defer func() { hooks.postQuery(c, ctx, stmt, args, rows, err) }()
269279
if ctx, err = hooks.preQuery(c, stmt, args); err != nil {
270280
return nil, err
271281
}
@@ -333,12 +343,13 @@ type sessionResetter interface {
333343
}
334344

335345
// ResetSession resets the state of Conn.
336-
func (conn *Conn) ResetSession(ctx context.Context) (err error) {
346+
func (conn *Conn) ResetSession(ctx context.Context) error {
347+
var err error
337348
var myctx interface{}
338349
hooks := conn.Proxy.getHooks(ctx)
339350

340351
if hooks != nil {
341-
defer func() { err = hooks.postResetSession(ctx, myctx, conn, err) }()
352+
defer func() { hooks.postResetSession(ctx, myctx, conn, err) }()
342353
if myctx, err = hooks.preResetSession(ctx, conn); err != nil {
343354
return err
344355
}

hooks.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ func (h *HooksContext) ping(c context.Context, ctx interface{}, conn *Conn) erro
412412

413413
func (h *HooksContext) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error {
414414
if h == nil || h.PostPing == nil {
415-
return err
415+
return nil
416416
}
417417
return h.PostPing(c, ctx, conn, err)
418418
}
@@ -433,7 +433,7 @@ func (h *HooksContext) open(c context.Context, ctx interface{}, conn *Conn) erro
433433

434434
func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error {
435435
if h == nil || h.PostOpen == nil {
436-
return err
436+
return nil
437437
}
438438
return h.PostOpen(c, ctx, conn, err)
439439
}
@@ -454,7 +454,7 @@ func (h *HooksContext) prepare(c context.Context, ctx interface{}, stmt *Stmt) e
454454

455455
func (h *HooksContext) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
456456
if h == nil || h.PostPrepare == nil {
457-
return err
457+
return nil
458458
}
459459
return h.PostPrepare(c, ctx, stmt, err)
460460
}
@@ -475,7 +475,7 @@ func (h *HooksContext) exec(c context.Context, ctx interface{}, stmt *Stmt, args
475475

476476
func (h *HooksContext) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error {
477477
if h == nil || h.PostExec == nil {
478-
return err
478+
return nil
479479
}
480480
return h.PostExec(c, ctx, stmt, args, result, err)
481481
}
@@ -496,7 +496,7 @@ func (h *HooksContext) query(c context.Context, ctx interface{}, stmt *Stmt, arg
496496

497497
func (h *HooksContext) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error {
498498
if h == nil || h.PostQuery == nil {
499-
return err
499+
return nil
500500
}
501501
return h.PostQuery(c, ctx, stmt, args, rows, err)
502502
}
@@ -517,7 +517,7 @@ func (h *HooksContext) begin(c context.Context, ctx interface{}, conn *Conn) err
517517

518518
func (h *HooksContext) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error {
519519
if h == nil || h.PostBegin == nil {
520-
return err
520+
return nil
521521
}
522522
return h.PostBegin(c, ctx, conn, err)
523523
}
@@ -538,7 +538,7 @@ func (h *HooksContext) commit(c context.Context, ctx interface{}, tx *Tx) error
538538

539539
func (h *HooksContext) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error {
540540
if h == nil || h.PostCommit == nil {
541-
return err
541+
return nil
542542
}
543543
return h.PostCommit(c, ctx, tx, err)
544544
}
@@ -559,7 +559,7 @@ func (h *HooksContext) rollback(c context.Context, ctx interface{}, tx *Tx) erro
559559

560560
func (h *HooksContext) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error {
561561
if h == nil || h.PostRollback == nil {
562-
return err
562+
return nil
563563
}
564564
return h.PostRollback(c, ctx, tx, err)
565565
}
@@ -580,7 +580,7 @@ func (h *HooksContext) close(c context.Context, ctx interface{}, conn *Conn) err
580580

581581
func (h *HooksContext) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error {
582582
if h == nil || h.PostClose == nil {
583-
return err
583+
return nil
584584
}
585585
return h.PostClose(c, ctx, conn, err)
586586
}
@@ -601,7 +601,7 @@ func (h *HooksContext) resetSession(c context.Context, ctx interface{}, conn *Co
601601

602602
func (h *HooksContext) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error {
603603
if h == nil || h.PostResetSession == nil {
604-
return err
604+
return nil
605605
}
606606
return h.PostResetSession(c, ctx, conn, err)
607607
}

logging_hook_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (h *loggingHook) postPing(c context.Context, ctx interface{}, conn *Conn, e
3737
h.mu.Lock()
3838
defer h.mu.Unlock()
3939
fmt.Fprintln(h, "[PostPing]")
40-
return err
40+
return nil
4141
}
4242

4343
func (h *loggingHook) preOpen(c context.Context, name string) (interface{}, error) {
@@ -58,7 +58,7 @@ func (h *loggingHook) postOpen(c context.Context, ctx interface{}, conn *Conn, e
5858
h.mu.Lock()
5959
defer h.mu.Unlock()
6060
fmt.Fprintln(h, "[PostOpen]")
61-
return err
61+
return nil
6262
}
6363

6464
func (h *loggingHook) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
@@ -79,7 +79,7 @@ func (h *loggingHook) postPrepare(c context.Context, ctx interface{}, stmt *Stmt
7979
h.mu.Lock()
8080
defer h.mu.Unlock()
8181
fmt.Fprintln(h, "[PostPrepare]")
82-
return err
82+
return nil
8383
}
8484

8585
func (h *loggingHook) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
@@ -100,7 +100,7 @@ func (h *loggingHook) postExec(c context.Context, ctx interface{}, stmt *Stmt, a
100100
h.mu.Lock()
101101
defer h.mu.Unlock()
102102
fmt.Fprintln(h, "[PostExec]")
103-
return err
103+
return nil
104104
}
105105

106106
func (h *loggingHook) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
@@ -121,7 +121,7 @@ func (h *loggingHook) postQuery(c context.Context, ctx interface{}, stmt *Stmt,
121121
h.mu.Lock()
122122
defer h.mu.Unlock()
123123
fmt.Fprintln(h, "[PostQuery]")
124-
return err
124+
return nil
125125
}
126126

127127
func (h *loggingHook) preBegin(c context.Context, conn *Conn) (interface{}, error) {
@@ -142,7 +142,7 @@ func (h *loggingHook) postBegin(c context.Context, ctx interface{}, conn *Conn,
142142
h.mu.Lock()
143143
defer h.mu.Unlock()
144144
fmt.Fprintln(h, "[PostBegin]")
145-
return err
145+
return nil
146146
}
147147

148148
func (h *loggingHook) preCommit(c context.Context, tx *Tx) (interface{}, error) {
@@ -163,7 +163,7 @@ func (h *loggingHook) postCommit(c context.Context, ctx interface{}, tx *Tx, err
163163
h.mu.Lock()
164164
defer h.mu.Unlock()
165165
fmt.Fprintln(h, "[PostCommit]")
166-
return err
166+
return nil
167167
}
168168

169169
func (h *loggingHook) preRollback(c context.Context, tx *Tx) (interface{}, error) {
@@ -184,7 +184,7 @@ func (h *loggingHook) postRollback(c context.Context, ctx interface{}, tx *Tx, e
184184
h.mu.Lock()
185185
defer h.mu.Unlock()
186186
fmt.Fprintln(h, "[PostRollback]")
187-
return err
187+
return nil
188188
}
189189

190190
func (h *loggingHook) preClose(c context.Context, conn *Conn) (interface{}, error) {
@@ -205,7 +205,7 @@ func (h *loggingHook) postClose(c context.Context, ctx interface{}, conn *Conn,
205205
h.mu.Lock()
206206
defer h.mu.Unlock()
207207
fmt.Fprintln(h, "[PostClose]")
208-
return err
208+
return nil
209209
}
210210

211211
func (h *loggingHook) preResetSession(c context.Context, conn *Conn) (interface{}, error) {
@@ -217,7 +217,7 @@ func (h *loggingHook) resetSession(c context.Context, ctx interface{}, conn *Con
217217
}
218218

219219
func (h *loggingHook) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error {
220-
return err
220+
return nil
221221
}
222222

223223
func (h *loggingHook) preIsValid(conn *Conn) (interface{}, error) {

0 commit comments

Comments
 (0)