Skip to content

Commit 98b3648

Browse files
committed
* Supported the *sql.Conn as input type ydb.Unwrap helper for go's 1.18
1 parent f9af1b9 commit 98b3648

File tree

8 files changed

+123
-31
lines changed

8 files changed

+123
-31
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Supported the `*sql.Conn` as input type `ydb.Unwrap` helper for go's 1.18
2+
13
## v3.36.2
24
* Changed output of `sugar.GenerateDeclareSection` (added error as second result)
35
* Specified `sugar.GenerateDeclareSection` for `go1.18` (supports input types `*table.QueryParameters` `[]table.ParameterOption` or `[]sql.NamedArg`)

internal/xsql/conn.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -325,21 +325,3 @@ func (c *conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (_ drive
325325
}
326326
return c.currentTx, nil
327327
}
328-
329-
func Unwrap(db *sql.DB) (connector *Connector, err error) {
330-
// hop with create session (connector.Connect()) helps to get ydb.Connection
331-
c, err := db.Conn(context.Background())
332-
if err != nil {
333-
return nil, xerrors.WithStackTrace(err)
334-
}
335-
if err = c.Raw(func(driverConn interface{}) error {
336-
if cc, ok := driverConn.(*conn); ok {
337-
connector = cc.connector
338-
return nil
339-
}
340-
return xerrors.WithStackTrace(badconn.Map(fmt.Errorf("%+v is not a *conn", driverConn)))
341-
}); err != nil {
342-
return nil, xerrors.WithStackTrace(err)
343-
}
344-
return connector, nil
345-
}

internal/xsql/connector.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,13 @@ func (c *Connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
129129
}
130130

131131
func (c *Connector) Driver() driver.Driver {
132-
return c.driver
132+
return &driverWrapper{c: c}
133+
}
134+
135+
type driverWrapper struct {
136+
c *Connector
137+
}
138+
139+
func (d *driverWrapper) Open(name string) (driver.Conn, error) {
140+
return d.c.driver.Open(name)
133141
}

internal/xsql/unwrap.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//go:build !go1.18
2+
// +build !go1.18
3+
4+
package xsql
5+
6+
import (
7+
"context"
8+
"database/sql"
9+
"fmt"
10+
11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
12+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
13+
)
14+
15+
func Unwrap(db *sql.DB) (connector *Connector, err error) {
16+
// hop with create session (connector.Connect()) helps to get ydb.Connection
17+
c, err := db.Conn(context.Background())
18+
if err != nil {
19+
return nil, xerrors.WithStackTrace(err)
20+
}
21+
if err = c.Raw(func(driverConn interface{}) error {
22+
if cc, ok := driverConn.(*conn); ok {
23+
connector = cc.connector
24+
return nil
25+
}
26+
return xerrors.WithStackTrace(badconn.Map(fmt.Errorf("%+v is not a *conn", driverConn)))
27+
}); err != nil {
28+
return nil, xerrors.WithStackTrace(err)
29+
}
30+
return connector, nil
31+
}

internal/xsql/unwrap_go1.18.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//go:build go1.18
2+
// +build go1.18
3+
4+
package xsql
5+
6+
import (
7+
"database/sql"
8+
"fmt"
9+
10+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
11+
)
12+
13+
func Unwrap[T *sql.DB | *sql.Conn](v T) (connector *Connector, err error) {
14+
switch vv := any(v).(type) {
15+
case *sql.DB:
16+
d := vv.Driver()
17+
if dw, ok := d.(*driverWrapper); ok {
18+
return dw.c, nil
19+
}
20+
return nil, xerrors.WithStackTrace(fmt.Errorf("%T is not a *driverWrapper", d))
21+
case *sql.Conn:
22+
if err = vv.Raw(func(driverConn interface{}) error {
23+
if cc, ok := driverConn.(*conn); ok {
24+
connector = cc.connector
25+
return nil
26+
}
27+
return xerrors.WithStackTrace(fmt.Errorf("%T is not a *conn", driverConn))
28+
}); err != nil {
29+
return nil, xerrors.WithStackTrace(err)
30+
}
31+
return connector, nil
32+
default:
33+
return nil, xerrors.WithStackTrace(fmt.Errorf("unknown type %T for Unwrap", vv))
34+
}
35+
}

sql.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7-
"fmt"
87

98
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
109
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql"
@@ -128,14 +127,3 @@ func Connector(db Connection, opts ...ConnectorOption) (*xsql.Connector, error)
128127
}
129128
return c, nil
130129
}
131-
132-
func Unwrap(db *sql.DB) (Connection, error) {
133-
c, err := xsql.Unwrap(db)
134-
if err != nil {
135-
return nil, xerrors.WithStackTrace(err)
136-
}
137-
if cc, ok := c.Connection().(Connection); ok {
138-
return cc, nil
139-
}
140-
return nil, xerrors.WithStackTrace(fmt.Errorf("%+v is not a ydb.Nonnection", c.Connection()))
141-
}

sql_unwrap.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//go:build !go1.18
2+
// +build !go1.18
3+
4+
package ydb
5+
6+
import (
7+
"database/sql"
8+
"fmt"
9+
10+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql"
12+
)
13+
14+
func Unwrap(db *sql.DB) (Connection, error) {
15+
c, err := xsql.Unwrap(db)
16+
if err != nil {
17+
return nil, xerrors.WithStackTrace(err)
18+
}
19+
if cc, ok := c.Connection().(Connection); ok {
20+
return cc, nil
21+
}
22+
return nil, xerrors.WithStackTrace(fmt.Errorf("%+v is not a ydb.Connection", c.Connection()))
23+
}

sql_unwrap_go1.18.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//go:build go1.18
2+
// +build go1.18
3+
4+
package ydb
5+
6+
import (
7+
"database/sql"
8+
"fmt"
9+
10+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql"
12+
)
13+
14+
func Unwrap[T *sql.DB | *sql.Conn](v T) (Connection, error) {
15+
c, err := xsql.Unwrap(v)
16+
if err != nil {
17+
return nil, xerrors.WithStackTrace(err)
18+
}
19+
if cc, ok := c.Connection().(Connection); ok {
20+
return cc, nil
21+
}
22+
return nil, xerrors.WithStackTrace(fmt.Errorf("%+v is not a ydb.Connection", c.Connection()))
23+
}

0 commit comments

Comments
 (0)