Skip to content

Commit 8c6266f

Browse files
zhoushuguanglight.zhou
andauthored
sql read write support (#4976)
Co-authored-by: light.zhou <[email protected]>
1 parent 95d5b81 commit 8c6266f

File tree

8 files changed

+553
-10
lines changed

8 files changed

+553
-10
lines changed

core/stores/sqlx/config.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package sqlx
2+
3+
import "errors"
4+
5+
var (
6+
errEmptyDatasource = errors.New("empty datasource")
7+
errEmptyDriverName = errors.New("empty driver name")
8+
)
9+
10+
// SqlConf defines the configuration for sqlx.
11+
type SqlConf struct {
12+
DataSource string
13+
DriverName string `json:",default=mysql"`
14+
Replicas []string `json:",optional"`
15+
Policy string `json:",default=round-robin,options=round-robin|random"`
16+
}
17+
18+
// Validate validates the SqlxConf.
19+
func (sc SqlConf) Validate() error {
20+
if len(sc.DataSource) == 0 {
21+
return errEmptyDatasource
22+
}
23+
24+
if len(sc.DriverName) == 0 {
25+
return errEmptyDriverName
26+
}
27+
28+
return nil
29+
}

core/stores/sqlx/config_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package sqlx
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/zeromicro/go-zero/core/conf"
8+
)
9+
10+
func TestValidate(t *testing.T) {
11+
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
12+
`)
13+
14+
var sc SqlConf
15+
err := conf.LoadFromYamlBytes(text, &sc)
16+
assert.Nil(t, err)
17+
assert.Equal(t, "mysql", sc.DriverName)
18+
assert.Equal(t, policyRoundRobin, sc.Policy)
19+
assert.Nil(t, sc.Validate())
20+
21+
sc = SqlConf{}
22+
assert.Equal(t, errEmptyDatasource, sc.Validate())
23+
24+
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
25+
assert.Equal(t, errEmptyDriverName, sc.Validate())
26+
27+
sc.DriverName = "mysql"
28+
assert.Nil(t, sc.Validate())
29+
}

core/stores/sqlx/rwstrategy.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package sqlx
2+
3+
import "context"
4+
5+
const (
6+
// policyRoundRobin round-robin policy for selecting replicas.
7+
policyRoundRobin = "round-robin"
8+
// policyRandom random policy for selecting replicas.
9+
policyRandom = "random"
10+
11+
// readPrimaryMode indicates that the operation is a read,
12+
// but should be performed on the primary database instance.
13+
//
14+
// This mode is used in scenarios where data freshness and consistency are critical,
15+
// such as immediately after writes or where replication lag may cause stale reads.
16+
readPrimaryMode readWriteMode = "read-primary"
17+
18+
// readReplicaMode indicates that the operation is a read from replicas.
19+
// This is suitable for scenarios where eventual consistency is acceptable,
20+
// and the goal is to offload traffic from the primary and improve read scalability.
21+
readReplicaMode readWriteMode = "read-replica"
22+
23+
// writeMode indicates that the operation is a write operation (to primary).
24+
writeMode readWriteMode = "write"
25+
26+
// notSpecifiedMode indicates that the read/write mode is not specified.
27+
notSpecifiedMode readWriteMode = ""
28+
)
29+
30+
type readWriteMode string
31+
32+
var readWriteModeKey struct{}
33+
34+
func (m readWriteMode) isValid() bool {
35+
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
36+
}
37+
38+
func getReadWriteMode(ctx context.Context) readWriteMode {
39+
if mode := ctx.Value(readWriteModeKey); mode != nil {
40+
if v, ok := mode.(readWriteMode); ok && v.isValid() {
41+
return v
42+
}
43+
}
44+
45+
return notSpecifiedMode
46+
}
47+
48+
func useReplica(ctx context.Context) bool {
49+
return getReadWriteMode(ctx) == readReplicaMode
50+
}
51+
52+
// WithReadPrimaryMode sets the context to read-primary mode.
53+
func WithReadPrimaryMode(ctx context.Context) context.Context {
54+
return context.WithValue(ctx, readWriteModeKey, readPrimaryMode)
55+
}
56+
57+
// WithReadReplicaMode sets the context to read-replica mode.
58+
func WithReadReplicaMode(ctx context.Context) context.Context {
59+
return context.WithValue(ctx, readWriteModeKey, readReplicaMode)
60+
}
61+
62+
// WithWriteMode sets the context to write mode, indicating that the operation is a write operation.
63+
func WithWriteMode(ctx context.Context) context.Context {
64+
return context.WithValue(ctx, readWriteModeKey, writeMode)
65+
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package sqlx
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestIsValid(t *testing.T) {
11+
testCases := []struct {
12+
name string
13+
mode readWriteMode
14+
expected bool
15+
}{
16+
{
17+
name: "valid read-primary mode",
18+
mode: readPrimaryMode,
19+
expected: true,
20+
},
21+
{
22+
name: "valid read-replica mode",
23+
mode: readReplicaMode,
24+
expected: true,
25+
},
26+
{
27+
name: "valid write mode",
28+
mode: writeMode,
29+
expected: true,
30+
},
31+
{
32+
name: "not specified mode (empty)",
33+
mode: notSpecifiedMode,
34+
expected: false,
35+
},
36+
{
37+
name: "invalid custom string",
38+
mode: readWriteMode("delete"),
39+
expected: false,
40+
},
41+
{
42+
name: "case sensitive check",
43+
mode: readWriteMode("READ"),
44+
expected: false,
45+
},
46+
}
47+
48+
for _, tc := range testCases {
49+
t.Run(tc.name, func(t *testing.T) {
50+
actual := tc.mode.isValid()
51+
assert.Equal(t, tc.expected, actual)
52+
})
53+
}
54+
}
55+
56+
func TestWithReadMode(t *testing.T) {
57+
ctx := context.Background()
58+
readPrimaryCtx := WithReadPrimaryMode(ctx)
59+
60+
val := readPrimaryCtx.Value(readWriteModeKey)
61+
assert.Equal(t, readPrimaryMode, val)
62+
63+
readReplicaCtx := WithReadReplicaMode(ctx)
64+
val = readReplicaCtx.Value(readWriteModeKey)
65+
assert.Equal(t, readReplicaMode, val)
66+
}
67+
68+
func TestWithWriteMode(t *testing.T) {
69+
ctx := context.Background()
70+
writeCtx := WithWriteMode(ctx)
71+
72+
val := writeCtx.Value(readWriteModeKey)
73+
assert.Equal(t, writeMode, val)
74+
}
75+
76+
func TestGetReadWriteMode(t *testing.T) {
77+
t.Run("valid read-primary mode", func(t *testing.T) {
78+
ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode)
79+
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
80+
})
81+
82+
t.Run("valid read-replica mode", func(t *testing.T) {
83+
ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode)
84+
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
85+
})
86+
87+
t.Run("valid write mode", func(t *testing.T) {
88+
ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode)
89+
assert.Equal(t, writeMode, getReadWriteMode(ctx))
90+
})
91+
92+
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
93+
ctx := context.WithValue(context.Background(), readWriteModeKey, "not-a-mode")
94+
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
95+
})
96+
97+
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
98+
ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("delete"))
99+
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
100+
})
101+
102+
t.Run("no mode set", func(t *testing.T) {
103+
ctx := context.Background()
104+
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
105+
})
106+
}
107+
108+
func TestUuseReplica(t *testing.T) {
109+
t.Run("context with read-replica mode", func(t *testing.T) {
110+
ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode)
111+
assert.True(t, useReplica(ctx))
112+
})
113+
114+
t.Run("context with read-primary mode", func(t *testing.T) {
115+
ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode)
116+
assert.False(t, useReplica(ctx))
117+
})
118+
119+
t.Run("context with write mode", func(t *testing.T) {
120+
ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode)
121+
assert.False(t, useReplica(ctx))
122+
})
123+
124+
t.Run("context with invalid mode", func(t *testing.T) {
125+
ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("invalid"))
126+
assert.False(t, useReplica(ctx))
127+
})
128+
129+
t.Run("context with no mode set", func(t *testing.T) {
130+
ctx := context.Background()
131+
assert.False(t, useReplica(ctx))
132+
})
133+
}

core/stores/sqlx/sqlconn.go

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ import (
44
"context"
55
"database/sql"
66
"errors"
7+
"fmt"
8+
"math/rand"
9+
"sync/atomic"
710

811
"github.com/zeromicro/go-zero/core/breaker"
912
"github.com/zeromicro/go-zero/core/errorx"
@@ -52,9 +55,10 @@ type (
5255
beginTx beginnable
5356
brk breaker.Breaker
5457
accept breaker.Acceptable
58+
index uint32
5559
}
5660

57-
connProvider func() (*sql.DB, error)
61+
connProvider func(ctx context.Context) (*sql.DB, error)
5862

5963
sessionConn interface {
6064
Exec(query string, args ...any) (sql.Result, error)
@@ -64,10 +68,31 @@ type (
6468
}
6569
)
6670

71+
// NewConn returns a SqlConn with the given SqlConf.
72+
func NewConn(c SqlConf, opts ...SqlOption) SqlConn {
73+
if err := c.Validate(); err != nil {
74+
logx.Must(err)
75+
}
76+
77+
conn := &commonSqlConn{
78+
onError: func(ctx context.Context, err error) {
79+
logInstanceError(ctx, c.DataSource, err)
80+
},
81+
beginTx: begin,
82+
brk: breaker.NewBreaker(),
83+
}
84+
for _, opt := range opts {
85+
opt(conn)
86+
}
87+
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
88+
89+
return conn
90+
}
91+
6792
// NewSqlConn returns a SqlConn with given driver name and datasource.
6893
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
6994
conn := &commonSqlConn{
70-
connProv: func() (*sql.DB, error) {
95+
connProv: func(context.Context) (*sql.DB, error) {
7196
return getSqlConn(driverName, datasource)
7297
},
7398
onError: func(ctx context.Context, err error) {
@@ -87,7 +112,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
87112
// Use it with caution; it's provided for other ORM to interact with.
88113
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
89114
conn := &commonSqlConn{
90-
connProv: func() (*sql.DB, error) {
115+
connProv: func(ctx context.Context) (*sql.DB, error) {
91116
return db, nil
92117
},
93118
onError: func(ctx context.Context, err error) {
@@ -123,7 +148,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
123148

124149
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
125150
var conn *sql.DB
126-
conn, err = db.connProv()
151+
conn, err = db.connProv(ctx)
127152
if err != nil {
128153
db.onError(ctx, err)
129154
return err
@@ -151,7 +176,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
151176

152177
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
153178
var conn *sql.DB
154-
conn, err = db.connProv()
179+
conn, err = db.connProv(ctx)
155180
if err != nil {
156181
db.onError(ctx, err)
157182
return err
@@ -242,7 +267,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
242267
}
243268

244269
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
245-
return db.connProv()
270+
return db.connProv(context.Background())
246271
}
247272

248273
func (db *commonSqlConn) Transact(fn func(Session) error) error {
@@ -288,7 +313,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
288313
q string, args ...any) (err error) {
289314
var scanFailed bool
290315
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
291-
conn, err := db.connProv()
316+
conn, err := db.connProv(ctx)
292317
if err != nil {
293318
db.onError(ctx, err)
294319
return err
@@ -311,6 +336,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
311336
return
312337
}
313338

339+
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
340+
return func(ctx context.Context) (*sql.DB, error) {
341+
replicaCount := len(replicas)
342+
343+
if replicaCount == 0 || !useReplica(ctx) {
344+
return getSqlConn(driverName, datasource)
345+
}
346+
347+
var dsn string
348+
349+
if replicaCount == 1 {
350+
dsn = replicas[0]
351+
} else {
352+
if len(policy) == 0 {
353+
policy = policyRoundRobin
354+
}
355+
356+
switch policy {
357+
case policyRandom:
358+
dsn = replicas[rand.Intn(replicaCount)]
359+
case policyRoundRobin:
360+
index := atomic.AddUint32(&sc.index, 1) - 1
361+
dsn = replicas[index%uint32(replicaCount)]
362+
default:
363+
return nil, fmt.Errorf("unknown policy: %s", policy)
364+
}
365+
}
366+
367+
return getSqlConn(driverName, dsn)
368+
}
369+
}
370+
314371
// WithAcceptable returns a SqlOption that setting the acceptable function.
315372
// acceptable is the func to check if the error can be accepted.
316373
func WithAcceptable(acceptable func(err error) bool) SqlOption {

0 commit comments

Comments
 (0)