44 "context"
55 "database/sql"
66 "errors"
7+ "fmt"
8+ "math/rand"
9+ "sync/atomic"
710
811 "github.com/zeromicro/go-zero/core/breaker"
912 "github.com/zeromicro/go-zero/core/errorx"
@@ -52,9 +55,10 @@ type (
5255 beginTx beginnable
5356 brk breaker.Breaker
5457 accept breaker.Acceptable
58+ index uint32
5559 }
5660
57- connProvider func () (* sql.DB , error )
61+ connProvider func (ctx context. Context ) (* sql.DB , error )
5862
5963 sessionConn interface {
6064 Exec (query string , args ... any ) (sql.Result , error )
@@ -64,10 +68,31 @@ type (
6468 }
6569)
6670
71+ // NewConn returns a SqlConn with the given SqlConf.
72+ func NewConn (c SqlConf , opts ... SqlOption ) SqlConn {
73+ if err := c .Validate (); err != nil {
74+ logx .Must (err )
75+ }
76+
77+ conn := & commonSqlConn {
78+ onError : func (ctx context.Context , err error ) {
79+ logInstanceError (ctx , c .DataSource , err )
80+ },
81+ beginTx : begin ,
82+ brk : breaker .NewBreaker (),
83+ }
84+ for _ , opt := range opts {
85+ opt (conn )
86+ }
87+ conn .connProv = getConnProvider (conn , c .DriverName , c .DataSource , c .Policy , c .Replicas )
88+
89+ return conn
90+ }
91+
6792// NewSqlConn returns a SqlConn with given driver name and datasource.
6893func NewSqlConn (driverName , datasource string , opts ... SqlOption ) SqlConn {
6994 conn := & commonSqlConn {
70- connProv : func () (* sql.DB , error ) {
95+ connProv : func (context. Context ) (* sql.DB , error ) {
7196 return getSqlConn (driverName , datasource )
7297 },
7398 onError : func (ctx context.Context , err error ) {
@@ -87,7 +112,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
87112// Use it with caution; it's provided for other ORM to interact with.
88113func NewSqlConnFromDB (db * sql.DB , opts ... SqlOption ) SqlConn {
89114 conn := & commonSqlConn {
90- connProv : func () (* sql.DB , error ) {
115+ connProv : func (ctx context. Context ) (* sql.DB , error ) {
91116 return db , nil
92117 },
93118 onError : func (ctx context.Context , err error ) {
@@ -123,7 +148,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
123148
124149 err = db .brk .DoWithAcceptableCtx (ctx , func () error {
125150 var conn * sql.DB
126- conn , err = db .connProv ()
151+ conn , err = db .connProv (ctx )
127152 if err != nil {
128153 db .onError (ctx , err )
129154 return err
@@ -151,7 +176,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
151176
152177 err = db .brk .DoWithAcceptableCtx (ctx , func () error {
153178 var conn * sql.DB
154- conn , err = db .connProv ()
179+ conn , err = db .connProv (ctx )
155180 if err != nil {
156181 db .onError (ctx , err )
157182 return err
@@ -242,7 +267,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
242267}
243268
244269func (db * commonSqlConn ) RawDB () (* sql.DB , error ) {
245- return db .connProv ()
270+ return db .connProv (context . Background () )
246271}
247272
248273func (db * commonSqlConn ) Transact (fn func (Session ) error ) error {
@@ -288,7 +313,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
288313 q string , args ... any ) (err error ) {
289314 var scanFailed bool
290315 err = db .brk .DoWithAcceptableCtx (ctx , func () error {
291- conn , err := db .connProv ()
316+ conn , err := db .connProv (ctx )
292317 if err != nil {
293318 db .onError (ctx , err )
294319 return err
@@ -311,6 +336,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
311336 return
312337}
313338
339+ func getConnProvider (sc * commonSqlConn , driverName , datasource , policy string , replicas []string ) connProvider {
340+ return func (ctx context.Context ) (* sql.DB , error ) {
341+ replicaCount := len (replicas )
342+
343+ if replicaCount == 0 || ! useReplica (ctx ) {
344+ return getSqlConn (driverName , datasource )
345+ }
346+
347+ var dsn string
348+
349+ if replicaCount == 1 {
350+ dsn = replicas [0 ]
351+ } else {
352+ if len (policy ) == 0 {
353+ policy = policyRoundRobin
354+ }
355+
356+ switch policy {
357+ case policyRandom :
358+ dsn = replicas [rand .Intn (replicaCount )]
359+ case policyRoundRobin :
360+ index := atomic .AddUint32 (& sc .index , 1 ) - 1
361+ dsn = replicas [index % uint32 (replicaCount )]
362+ default :
363+ return nil , fmt .Errorf ("unknown policy: %s" , policy )
364+ }
365+ }
366+
367+ return getSqlConn (driverName , dsn )
368+ }
369+ }
370+
314371// WithAcceptable returns a SqlOption that setting the acceptable function.
315372// acceptable is the func to check if the error can be accepted.
316373func WithAcceptable (acceptable func (err error ) bool ) SqlOption {
0 commit comments