Skip to content

Commit 11b2900

Browse files
committed
fixed mapping to database/sql.ErrBadConn
1 parent 6fd1444 commit 11b2900

File tree

8 files changed

+50
-57
lines changed

8 files changed

+50
-57
lines changed

internal/xsql/conn.go

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -101,22 +101,8 @@ func newConn(c *Connector, s table.ClosableSession, opts ...connOption) *conn {
101101
return cc
102102
}
103103

104-
func (c *conn) checkClosed(err error) error {
105-
if err = badconn.Map(err); xerrors.Is(err, driver.ErrBadConn) {
106-
atomic.StoreUint32(&c.closed, 1)
107-
}
108-
return err
109-
}
110-
111104
func (c *conn) isReady() bool {
112-
if atomic.LoadUint32(&c.closed) == 1 {
113-
return true
114-
}
115-
if c.session.Status() != table.SessionReady {
116-
atomic.StoreUint32(&c.closed, 1)
117-
return true
118-
}
119-
return false
105+
return c.session.Status() == table.SessionReady
120106
}
121107

122108
func (conn) CheckNamedValue(v *driver.NamedValue) (err error) {
@@ -129,7 +115,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (_ driver.Stmt,
129115
onDone(err)
130116
}()
131117
if !c.isReady() {
132-
return nil, errNotReadyConn
118+
return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
133119
}
134120
return &stmt{
135121
conn: c,
@@ -162,38 +148,38 @@ func (c *conn) execContext(ctx context.Context, query string, args []driver.Name
162148
dataQueryOptions(ctx)...,
163149
)
164150
if err != nil {
165-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
151+
return nil, badconn.Map(xerrors.WithStackTrace(err))
166152
}
167153
defer func() {
168154
_ = res.Close()
169155
}()
170156
if err = res.NextResultSetErr(ctx); !xerrors.Is(err, nil, io.EOF) {
171-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
157+
return nil, badconn.Map(xerrors.WithStackTrace(err))
172158
}
173159
if err = res.Err(); err != nil {
174-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
160+
return nil, badconn.Map(xerrors.WithStackTrace(err))
175161
}
176162
return driver.ResultNoRows, nil
177163
case SchemeQueryMode:
178164
err = c.session.ExecuteSchemeQuery(ctx, query)
179165
if err != nil {
180-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
166+
return nil, badconn.Map(xerrors.WithStackTrace(err))
181167
}
182168
return driver.ResultNoRows, nil
183169
case ScriptingQueryMode:
184170
var res result.StreamResult
185171
res, err = c.connector.connection.Scripting().StreamExecute(ctx, query, toQueryParams(args))
186172
if err != nil {
187-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
173+
return nil, badconn.Map(xerrors.WithStackTrace(err))
188174
}
189175
defer func() {
190176
_ = res.Close()
191177
}()
192178
if err = res.NextResultSetErr(ctx); !xerrors.Is(err, nil, io.EOF) {
193-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
179+
return nil, badconn.Map(xerrors.WithStackTrace(err))
194180
}
195181
if err = res.Err(); err != nil {
196-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
182+
return nil, badconn.Map(xerrors.WithStackTrace(err))
197183
}
198184
return driver.ResultNoRows, nil
199185
default:
@@ -203,7 +189,7 @@ func (c *conn) execContext(ctx context.Context, query string, args []driver.Name
203189

204190
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (_ driver.Result, err error) {
205191
if !c.isReady() {
206-
return nil, errNotReadyConn
192+
return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
207193
}
208194
if c.currentTx != nil {
209195
return c.currentTx.ExecContext(ctx, query, args)
@@ -213,7 +199,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
213199

214200
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (_ driver.Rows, err error) {
215201
if !c.isReady() {
216-
return nil, errNotReadyConn
202+
return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
217203
}
218204
if c.currentTx != nil {
219205
return c.currentTx.QueryContext(ctx, query, args)
@@ -245,10 +231,10 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
245231
dataQueryOptions(ctx)...,
246232
)
247233
if err != nil {
248-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
234+
return nil, badconn.Map(xerrors.WithStackTrace(err))
249235
}
250236
if err = res.Err(); err != nil {
251-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
237+
return nil, badconn.Map(xerrors.WithStackTrace(err))
252238
}
253239
return &rows{
254240
conn: c,
@@ -262,10 +248,10 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
262248
scanQueryOptions(ctx)...,
263249
)
264250
if err != nil {
265-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
251+
return nil, badconn.Map(xerrors.WithStackTrace(err))
266252
}
267253
if err = res.Err(); err != nil {
268-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
254+
return nil, badconn.Map(xerrors.WithStackTrace(err))
269255
}
270256
return &rows{
271257
conn: c,
@@ -275,7 +261,7 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
275261
var exp table.DataQueryExplanation
276262
exp, err = c.session.Explain(ctx, query)
277263
if err != nil {
278-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
264+
return nil, badconn.Map(xerrors.WithStackTrace(err))
279265
}
280266
return &single{
281267
values: []sql.NamedArg{
@@ -287,10 +273,10 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
287273
var res result.StreamResult
288274
res, err = c.connector.connection.Scripting().StreamExecute(ctx, query, toQueryParams(args))
289275
if err != nil {
290-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
276+
return nil, badconn.Map(xerrors.WithStackTrace(err))
291277
}
292278
if err = res.Err(); err != nil {
293-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
279+
return nil, badconn.Map(xerrors.WithStackTrace(err))
294280
}
295281
return &rows{
296282
conn: c,
@@ -307,10 +293,10 @@ func (c *conn) Ping(ctx context.Context) (err error) {
307293
onDone(err)
308294
}()
309295
if !c.isReady() {
310-
return errNotReadyConn
296+
return badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
311297
}
312298
if err = c.session.KeepAlive(ctx); err != nil {
313-
return c.checkClosed(xerrors.WithStackTrace(err))
299+
return badconn.Map(xerrors.WithStackTrace(err))
314300
}
315301
return nil
316302
}
@@ -327,7 +313,7 @@ func (c *conn) Close() (err error) {
327313
}
328314
return nil
329315
}
330-
return errClosedConn
316+
return badconn.Map(xerrors.WithStackTrace(errClosedConn))
331317
}
332318

333319
func (c *conn) Prepare(string) (driver.Stmt, error) {
@@ -345,7 +331,7 @@ func (c *conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (_ drive
345331
onDone(transaction, err)
346332
}()
347333
if !c.isReady() {
348-
return nil, errNotReadyConn
334+
return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
349335
}
350336
if c.currentTx != nil {
351337
return nil, xerrors.WithStackTrace(
@@ -359,7 +345,7 @@ func (c *conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (_ drive
359345
}
360346
transaction, err = c.session.BeginTransaction(ctx, table.TxSettings(txc))
361347
if err != nil {
362-
return nil, c.checkClosed(xerrors.WithStackTrace(err))
348+
return nil, badconn.Map(xerrors.WithStackTrace(err))
363349
}
364350
c.currentTx = &tx{
365351
conn: c,

internal/xsql/errors.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@ import (
55
"errors"
66

77
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
8-
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
98
)
109

1110
var (
1211
ErrUnsupported = driver.ErrSkip
1312
errDeprecated = driver.ErrSkip
14-
errClosedConn = badconn.Map(xerrors.Retryable(errors.New("conn closed early"), xerrors.WithDeleteSession()))
15-
errNotReadyConn = badconn.Map(xerrors.Retryable(errors.New("conn not ready"), xerrors.WithDeleteSession()))
13+
errClosedConn = xerrors.Retryable(errors.New("conn closed early"), xerrors.WithDeleteSession())
14+
errNotReadyConn = xerrors.Retryable(errors.New("conn not ready"), xerrors.WithDeleteSession())
1615
)

internal/xsql/rows.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync"
99

1010
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
1112
"github.com/ydb-platform/ydb-go-sdk/v3/table/options"
1213
"github.com/ydb-platform/ydb-go-sdk/v3/table/result"
1314
"github.com/ydb-platform/ydb-go-sdk/v3/table/result/indexed"
@@ -61,10 +62,10 @@ func (r *rows) Next(dst []driver.Value) (err error) {
6162
err = r.result.NextResultSetErr(context.Background())
6263
})
6364
if err != nil {
64-
return r.conn.checkClosed(xerrors.WithStackTrace(err))
65+
return badconn.Map(xerrors.WithStackTrace(err))
6566
}
6667
if err = r.result.Err(); err != nil {
67-
return r.conn.checkClosed(xerrors.WithStackTrace(err))
68+
return badconn.Map(xerrors.WithStackTrace(err))
6869
}
6970
if !r.result.NextRow() {
7071
return io.EOF
@@ -74,13 +75,13 @@ func (r *rows) Next(dst []driver.Value) (err error) {
7475
values[i] = &valuer{}
7576
}
7677
if err = r.result.Scan(values...); err != nil {
77-
return r.conn.checkClosed(xerrors.WithStackTrace(err))
78+
return badconn.Map(xerrors.WithStackTrace(err))
7879
}
7980
for i := range values {
8081
dst[i] = values[i].(*valuer).Value()
8182
}
8283
if err = r.result.Err(); err != nil {
83-
return r.conn.checkClosed(xerrors.WithStackTrace(err))
84+
return badconn.Map(xerrors.WithStackTrace(err))
8485
}
8586
return nil
8687
}

internal/xsql/stmt.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"database/sql/driver"
66
"fmt"
77

8+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
9+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
810
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
911
)
1012

@@ -27,7 +29,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (_ dr
2729
onDone(err)
2830
}()
2931
if !s.conn.isReady() {
30-
return nil, errNotReadyConn
32+
return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
3133
}
3234
switch m := queryModeFromContext(ctx, s.conn.defaultQueryMode); m {
3335
case DataQueryMode:
@@ -43,7 +45,7 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (_ dri
4345
onDone(err)
4446
}()
4547
if !s.conn.isReady() {
46-
return nil, errNotReadyConn
48+
return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
4749
}
4850
switch m := queryModeFromContext(ctx, s.conn.defaultQueryMode); m {
4951
case DataQueryMode:

internal/xsql/tx.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
88
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
9+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
910
"github.com/ydb-platform/ydb-go-sdk/v3/table"
1011
"github.com/ydb-platform/ydb-go-sdk/v3/table/result"
1112
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
@@ -37,11 +38,11 @@ func (tx *tx) Commit() (err error) {
3738
tx.conn.currentTx = nil
3839
}()
3940
if !tx.conn.isReady() {
40-
return errNotReadyConn
41+
return badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
4142
}
4243
_, err = tx.tx.CommitTx(tx.ctx)
4344
if err != nil {
44-
return tx.conn.checkClosed(xerrors.WithStackTrace(err))
45+
return badconn.Map(xerrors.WithStackTrace(err))
4546
}
4647
return nil
4748
}
@@ -55,11 +56,11 @@ func (tx *tx) Rollback() (err error) {
5556
tx.conn.currentTx = nil
5657
}()
5758
if !tx.conn.isReady() {
58-
return errNotReadyConn
59+
return badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
5960
}
6061
err = tx.tx.Rollback(tx.ctx)
6162
if err != nil {
62-
return tx.conn.checkClosed(xerrors.WithStackTrace(err))
63+
return badconn.Map(xerrors.WithStackTrace(err))
6364
}
6465
return err
6566
}
@@ -76,10 +77,10 @@ func (tx *tx) QueryContext(ctx context.Context, query string, args []driver.Name
7677
dataQueryOptions(ctx)...,
7778
)
7879
if err != nil {
79-
return nil, tx.conn.checkClosed(xerrors.WithStackTrace(err))
80+
return nil, badconn.Map(xerrors.WithStackTrace(err))
8081
}
8182
if err = res.Err(); err != nil {
82-
return nil, tx.conn.checkClosed(xerrors.WithStackTrace(err))
83+
return nil, badconn.Map(xerrors.WithStackTrace(err))
8384
}
8485
return &rows{
8586
conn: tx.conn,
@@ -98,7 +99,7 @@ func (tx *tx) ExecContext(ctx context.Context, query string, args []driver.Named
9899
dataQueryOptions(ctx)...,
99100
)
100101
if err != nil {
101-
return nil, tx.conn.checkClosed(xerrors.WithStackTrace(err))
102+
return nil, badconn.Map(xerrors.WithStackTrace(err))
102103
}
103104
return driver.ResultNoRows, nil
104105
}

internal/xsql/unwrap.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func Unwrap(db *sql.DB) (connector *Connector, err error) {
1616
// hop with create session (connector.Connect()) helps to get ydb.Connection
1717
c, err := db.Conn(context.Background())
1818
if err != nil {
19-
return nil, xerrors.WithStackTrace(err)
19+
return nil, badconn.Map(xerrors.WithStackTrace(err))
2020
}
2121
if err = c.Raw(func(driverConn interface{}) error {
2222
if cc, ok := driverConn.(*conn); ok {
@@ -25,7 +25,7 @@ func Unwrap(db *sql.DB) (connector *Connector, err error) {
2525
}
2626
return xerrors.WithStackTrace(badconn.Map(fmt.Errorf("%+v is not a *conn", driverConn)))
2727
}); err != nil {
28-
return nil, xerrors.WithStackTrace(err)
28+
return nil, badconn.Map(xerrors.WithStackTrace(err))
2929
}
3030
return connector, nil
3131
}

internal/xsql/unwrap_go1.18.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99

1010
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
1112
)
1213

1314
func Unwrap[T *sql.DB | *sql.Conn](v T) (connector *Connector, err error) {
@@ -26,7 +27,7 @@ func Unwrap[T *sql.DB | *sql.Conn](v T) (connector *Connector, err error) {
2627
}
2728
return xerrors.WithStackTrace(fmt.Errorf("%T is not a *conn", driverConn))
2829
}); err != nil {
29-
return nil, xerrors.WithStackTrace(err)
30+
return nil, badconn.Map(xerrors.WithStackTrace(err))
3031
}
3132
return connector, nil
3233
default:

table/table_e2e_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2040,6 +2040,9 @@ func TestNullType(t *testing.T) {
20402040
}
20412041

20422042
func TestTypeToString(t *testing.T) {
2043+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
2044+
defer cancel()
2045+
20432046
db, err := sql.Open("ydb", os.Getenv("YDB_CONNECTION_STRING"))
20442047
if err != nil {
20452048
t.Fatal(err)
@@ -2112,7 +2115,7 @@ func TestTypeToString(t *testing.T) {
21122115
} {
21132116
t.Run(tt.Yql(), func(t *testing.T) {
21142117
var got string
2115-
err := retry.Do(context.Background(), db, func(ctx context.Context, cc *sql.Conn) error {
2118+
err := retry.Do(ctx, db, func(ctx context.Context, cc *sql.Conn) error {
21162119
row := cc.QueryRowContext(ctx,
21172120
fmt.Sprintf("SELECT FormatType(ParseType(\"%s\"))", tt.Yql()),
21182121
)

0 commit comments

Comments
 (0)