Skip to content

Commit bae8d4f

Browse files
authored
chore: refactoring sql read write mode (#4990)
Signed-off-by: kevin <[email protected]>
1 parent 8c6266f commit bae8d4f

File tree

4 files changed

+62
-43
lines changed

4 files changed

+62
-43
lines changed

core/stores/sqlx/rwstrategy.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,25 @@ const (
2727
notSpecifiedMode readWriteMode = ""
2828
)
2929

30-
type readWriteMode string
31-
3230
var readWriteModeKey struct{}
3331

32+
// WithReadPrimary sets the context to read-primary mode.
33+
func WithReadPrimary(ctx context.Context) context.Context {
34+
return context.WithValue(ctx, readWriteModeKey, readPrimaryMode)
35+
}
36+
37+
// WithReadReplica sets the context to read-replica mode.
38+
func WithReadReplica(ctx context.Context) context.Context {
39+
return context.WithValue(ctx, readWriteModeKey, readReplicaMode)
40+
}
41+
42+
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
43+
func WithWrite(ctx context.Context) context.Context {
44+
return context.WithValue(ctx, readWriteModeKey, writeMode)
45+
}
46+
47+
type readWriteMode string
48+
3449
func (m readWriteMode) isValid() bool {
3550
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
3651
}
@@ -45,21 +60,6 @@ func getReadWriteMode(ctx context.Context) readWriteMode {
4560
return notSpecifiedMode
4661
}
4762

48-
func useReplica(ctx context.Context) bool {
49-
return getReadWriteMode(ctx) == readReplicaMode
50-
}
51-
52-
// WithReadPrimaryMode sets the context to read-primary mode.
53-
func WithReadPrimaryMode(ctx context.Context) context.Context {
54-
return context.WithValue(ctx, readWriteModeKey, readPrimaryMode)
55-
}
56-
57-
// WithReadReplicaMode sets the context to read-replica mode.
58-
func WithReadReplicaMode(ctx context.Context) context.Context {
59-
return context.WithValue(ctx, readWriteModeKey, readReplicaMode)
60-
}
61-
62-
// WithWriteMode sets the context to write mode, indicating that the operation is a write operation.
63-
func WithWriteMode(ctx context.Context) context.Context {
64-
return context.WithValue(ctx, readWriteModeKey, writeMode)
63+
func usePrimary(ctx context.Context) bool {
64+
return getReadWriteMode(ctx) != readReplicaMode
6565
}

core/stores/sqlx/rwstrategy_test.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,19 @@ func TestIsValid(t *testing.T) {
5555

5656
func TestWithReadMode(t *testing.T) {
5757
ctx := context.Background()
58-
readPrimaryCtx := WithReadPrimaryMode(ctx)
58+
readPrimaryCtx := WithReadPrimary(ctx)
5959

6060
val := readPrimaryCtx.Value(readWriteModeKey)
6161
assert.Equal(t, readPrimaryMode, val)
6262

63-
readReplicaCtx := WithReadReplicaMode(ctx)
63+
readReplicaCtx := WithReadReplica(ctx)
6464
val = readReplicaCtx.Value(readWriteModeKey)
6565
assert.Equal(t, readReplicaMode, val)
6666
}
6767

6868
func TestWithWriteMode(t *testing.T) {
6969
ctx := context.Background()
70-
writeCtx := WithWriteMode(ctx)
70+
writeCtx := WithWrite(ctx)
7171

7272
val := writeCtx.Value(readWriteModeKey)
7373
assert.Equal(t, writeMode, val)
@@ -105,29 +105,38 @@ func TestGetReadWriteMode(t *testing.T) {
105105
})
106106
}
107107

108-
func TestUuseReplica(t *testing.T) {
108+
func TestUsePrimary(t *testing.T) {
109109
t.Run("context with read-replica mode", func(t *testing.T) {
110110
ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode)
111-
assert.True(t, useReplica(ctx))
111+
assert.False(t, usePrimary(ctx))
112112
})
113113

114114
t.Run("context with read-primary mode", func(t *testing.T) {
115115
ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode)
116-
assert.False(t, useReplica(ctx))
116+
assert.True(t, usePrimary(ctx))
117117
})
118118

119119
t.Run("context with write mode", func(t *testing.T) {
120120
ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode)
121-
assert.False(t, useReplica(ctx))
121+
assert.True(t, usePrimary(ctx))
122122
})
123123

124124
t.Run("context with invalid mode", func(t *testing.T) {
125125
ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("invalid"))
126-
assert.False(t, useReplica(ctx))
126+
assert.True(t, usePrimary(ctx))
127127
})
128128

129129
t.Run("context with no mode set", func(t *testing.T) {
130130
ctx := context.Background()
131-
assert.False(t, useReplica(ctx))
131+
assert.True(t, usePrimary(ctx))
132132
})
133133
}
134+
135+
func TestWithModeTwice(t *testing.T) {
136+
ctx := context.Background()
137+
ctx = WithReadPrimary(ctx)
138+
writeCtx := WithWrite(ctx)
139+
140+
val := writeCtx.Value(readWriteModeKey)
141+
assert.Equal(t, writeMode, val)
142+
}

core/stores/sqlx/sqlconn.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,20 @@ type (
6868
}
6969
)
7070

71+
// MustNewConn returns a SqlConn with the given SqlConf.
72+
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
73+
conn, err := NewConn(c, opts...)
74+
if err != nil {
75+
logx.Must(err)
76+
}
77+
78+
return conn
79+
}
80+
7181
// NewConn returns a SqlConn with the given SqlConf.
72-
func NewConn(c SqlConf, opts ...SqlOption) SqlConn {
82+
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
7383
if err := c.Validate(); err != nil {
74-
logx.Must(err)
84+
return nil, err
7585
}
7686

7787
conn := &commonSqlConn{
@@ -86,7 +96,7 @@ func NewConn(c SqlConf, opts ...SqlOption) SqlConn {
8696
}
8797
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
8898

89-
return conn
99+
return conn, nil
90100
}
91101

92102
// NewSqlConn returns a SqlConn with given driver name and datasource.
@@ -340,7 +350,7 @@ func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, r
340350
return func(ctx context.Context) (*sql.DB, error) {
341351
replicaCount := len(replicas)
342352

343-
if replicaCount == 0 || !useReplica(ctx) {
353+
if replicaCount == 0 || usePrimary(ctx) {
344354
return getSqlConn(driverName, datasource)
345355
}
346356

core/stores/sqlx/sqlconn_test.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func TestConfigSqlConn(t *testing.T) {
149149
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
150150

151151
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
152-
conn := NewConn(conf, withMysqlAcceptable())
152+
conn := MustNewConn(conf, withMysqlAcceptable())
153153

154154
_, err = conn.Exec("any", "value")
155155
assert.NotNil(t, err)
@@ -177,7 +177,7 @@ func TestConfigSqlConnStatement(t *testing.T) {
177177
mock.ExpectQuery("any").WillReturnRows(row)
178178

179179
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
180-
conn := NewConn(conf, withMysqlAcceptable())
180+
conn := MustNewConn(conf, withMysqlAcceptable())
181181
stmt, err := conn.Prepare("any")
182182
assert.NoError(t, err)
183183

@@ -220,7 +220,7 @@ func TestConfigSqlConnQuery(t *testing.T) {
220220
t.Run("QueryRow", func(t *testing.T) {
221221
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
222222
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
223-
conn := NewConn(conf)
223+
conn := MustNewConn(conf)
224224
var val string
225225
assert.NoError(t, conn.QueryRow(&val, "any"))
226226
assert.Equal(t, "bar", val)
@@ -229,7 +229,7 @@ func TestConfigSqlConnQuery(t *testing.T) {
229229
t.Run("QueryRowPartial", func(t *testing.T) {
230230
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
231231
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
232-
conn := NewConn(conf)
232+
conn := MustNewConn(conf)
233233
var val string
234234
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
235235
assert.Equal(t, "bar", val)
@@ -238,7 +238,7 @@ func TestConfigSqlConnQuery(t *testing.T) {
238238
t.Run("QueryRows", func(t *testing.T) {
239239
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
240240
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
241-
conn := NewConn(conf)
241+
conn := MustNewConn(conf)
242242
var vals []string
243243
assert.NoError(t, conn.QueryRows(&vals, "any"))
244244
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
@@ -247,7 +247,7 @@ func TestConfigSqlConnQuery(t *testing.T) {
247247
t.Run("QueryRowsPartial", func(t *testing.T) {
248248
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
249249
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
250-
conn := NewConn(conf)
250+
conn := MustNewConn(conf)
251251
var vals []string
252252
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
253253
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
@@ -261,7 +261,7 @@ func TestConfigSqlConnErr(t *testing.T) {
261261
defer logx.ExitOnFatal.Set(original)
262262

263263
assert.Panics(t, func() {
264-
NewConn(SqlConf{})
264+
MustNewConn(SqlConf{})
265265
})
266266
})
267267
t.Run("on error", func(t *testing.T) {
@@ -272,7 +272,7 @@ func TestConfigSqlConnErr(t *testing.T) {
272272
connManager.Inject(mockedDatasource, db)
273273

274274
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
275-
conn := NewConn(conf)
275+
conn := MustNewConn(conf)
276276
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
277277
return nil, errors.New("error")
278278
}
@@ -479,12 +479,12 @@ func TestProvider(t *testing.T) {
479479
assert.Nil(t, err)
480480
assert.Equal(t, primaryDB, db)
481481

482-
ctx = WithWriteMode(ctx)
482+
ctx = WithWrite(ctx)
483483
db, err = sc.connProv(ctx)
484484
assert.Nil(t, err)
485485
assert.Equal(t, primaryDB, db)
486486

487-
ctx = WithReadPrimaryMode(ctx)
487+
ctx = WithReadPrimary(ctx)
488488
db, err = sc.connProv(ctx)
489489
assert.Nil(t, err)
490490
assert.Equal(t, primaryDB, db)
@@ -496,7 +496,7 @@ func TestProvider(t *testing.T) {
496496
assert.Nil(t, err)
497497
assert.Equal(t, primaryDB, db)
498498

499-
ctx = WithReadReplicaMode(ctx)
499+
ctx = WithReadReplica(ctx)
500500
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
501501
db, err = sc.connProv(ctx)
502502
assert.Nil(t, err)

0 commit comments

Comments
 (0)