Skip to content

Commit bb81382

Browse files
added configurable policy resolvers
1 parent d241b05 commit bb81382

File tree

5 files changed

+208
-23
lines changed

5 files changed

+208
-23
lines changed

command_policy_resolver.go

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package redis
2+
3+
import (
4+
"context"
5+
"strings"
6+
7+
"github.com/redis/go-redis/v9/internal/routing"
8+
)
9+
10+
type (
11+
module = string
12+
commandName = string
13+
)
14+
15+
var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{
16+
"ft": {
17+
"create": {
18+
Request: routing.ReqDefault,
19+
Response: routing.RespDefaultKeyless,
20+
},
21+
"search": {
22+
Request: routing.ReqDefault,
23+
Response: routing.RespDefaultKeyless,
24+
},
25+
"aggregate": {
26+
Request: routing.ReqDefault,
27+
Response: routing.RespDefaultKeyless,
28+
},
29+
"dictadd": {
30+
Request: routing.ReqDefault,
31+
Response: routing.RespDefaultKeyless,
32+
},
33+
"dictdump": {
34+
Request: routing.ReqDefault,
35+
Response: routing.RespDefaultKeyless,
36+
},
37+
"dictdel": {
38+
Request: routing.ReqDefault,
39+
Response: routing.RespDefaultKeyless,
40+
},
41+
"suglen": {
42+
Request: routing.ReqDefault,
43+
Response: routing.RespDefaultHashSlot,
44+
},
45+
"cursor": {
46+
Request: routing.ReqSpecial,
47+
Response: routing.RespDefaultKeyless,
48+
},
49+
"sugadd": {
50+
Request: routing.ReqDefault,
51+
Response: routing.RespDefaultHashSlot,
52+
},
53+
"sugget": {
54+
Request: routing.ReqDefault,
55+
Response: routing.RespDefaultHashSlot,
56+
},
57+
"sugdel": {
58+
Request: routing.ReqDefault,
59+
Response: routing.RespDefaultHashSlot,
60+
},
61+
"spellcheck": {
62+
Request: routing.ReqDefault,
63+
Response: routing.RespDefaultKeyless,
64+
},
65+
"explain": {
66+
Request: routing.ReqDefault,
67+
Response: routing.RespDefaultKeyless,
68+
},
69+
"explaincli": {
70+
Request: routing.ReqDefault,
71+
Response: routing.RespDefaultKeyless,
72+
},
73+
"aliasadd": {
74+
Request: routing.ReqDefault,
75+
Response: routing.RespDefaultKeyless,
76+
},
77+
"aliasupdate": {
78+
Request: routing.ReqDefault,
79+
Response: routing.RespDefaultKeyless,
80+
},
81+
"aliasdel": {
82+
Request: routing.ReqDefault,
83+
Response: routing.RespDefaultKeyless,
84+
},
85+
"info": {
86+
Request: routing.ReqDefault,
87+
Response: routing.RespDefaultKeyless,
88+
},
89+
"tagvals": {
90+
Request: routing.ReqDefault,
91+
Response: routing.RespDefaultKeyless,
92+
},
93+
"syndump": {
94+
Request: routing.ReqDefault,
95+
Response: routing.RespDefaultKeyless,
96+
},
97+
"synupdate": {
98+
Request: routing.ReqDefault,
99+
Response: routing.RespDefaultKeyless,
100+
},
101+
"profile": {
102+
Request: routing.ReqDefault,
103+
Response: routing.RespDefaultKeyless,
104+
},
105+
"alter": {
106+
Request: routing.ReqDefault,
107+
Response: routing.RespDefaultKeyless,
108+
},
109+
"dropindex": {
110+
Request: routing.ReqDefault,
111+
Response: routing.RespDefaultKeyless,
112+
},
113+
"drop": {
114+
Request: routing.ReqDefault,
115+
Response: routing.RespDefaultKeyless,
116+
},
117+
},
118+
}
119+
120+
type CommandInfoResolver interface {
121+
getCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy
122+
setFallbackResolver(fallback CommandInfoResolver)
123+
}
124+
125+
type resolver struct {
126+
resolve func(ctx context.Context, cmdName string) *routing.CommandPolicy
127+
fallBackResolver CommandInfoResolver
128+
}
129+
130+
func NewDefaultCommandPolicyResolver() *resolver {
131+
return &resolver{
132+
resolve: func(ctx context.Context, cmdName string) *routing.CommandPolicy {
133+
module := "core"
134+
command := cmdName
135+
cmdParts := strings.Split(cmdName, ".")
136+
if len(cmdParts) == 2 {
137+
module = cmdParts[0]
138+
command = cmdParts[1]
139+
}
140+
141+
if policy, ok := defaultPolicies[module][command]; ok {
142+
return policy
143+
}
144+
145+
return nil
146+
},
147+
}
148+
}
149+
150+
func (r *resolver) getCommandPolicy(ctx context.Context, cmdName string) *routing.CommandPolicy {
151+
policy := r.resolve(ctx, cmdName)
152+
if policy != nil {
153+
return policy
154+
}
155+
156+
if r.fallBackResolver != nil {
157+
return r.fallBackResolver.getCommandPolicy(ctx, cmdName)
158+
}
159+
160+
return nil
161+
}
162+
163+
func (r *resolver) setFallbackResolver(fallbackResolver CommandInfoResolver) {
164+
r.fallBackResolver = fallbackResolver
165+
}

internal/routing/aggregator.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggreg
5050
case RespOneSucceeded:
5151
return &OneSucceededAggregator{}
5252
case RespAggSum:
53-
return &AggSumAggregator{}
53+
return &AggSumAggregator{
54+
// res:
55+
}
5456
case RespAggMin:
5557
return &AggMinAggregator{
5658
res: util.NewAtomicMin(),
@@ -212,7 +214,7 @@ func (a *OneSucceededAggregator) Result() (interface{}, error) {
212214
// AggSumAggregator sums numeric replies from all shards.
213215
type AggSumAggregator struct {
214216
err atomic.Value
215-
res *int64
217+
res int64
216218
}
217219

218220
func (a *AggSumAggregator) Add(result interface{}, err error) error {
@@ -226,7 +228,7 @@ func (a *AggSumAggregator) Add(result interface{}, err error) error {
226228
a.err.CompareAndSwap(nil, err)
227229
return err
228230
}
229-
atomic.AddInt64(a.res, val)
231+
atomic.AddInt64(&a.res, val)
230232
}
231233

232234
return nil
@@ -240,7 +242,7 @@ func (a *AggSumAggregator) BatchAdd(results map[string]AggregatorResErr) error {
240242
return a.Add(res.Result, res.Err)
241243
}
242244

243-
intRes, err := toInt64(res)
245+
intRes, err := toInt64(res.Result)
244246
if err != nil {
245247
return a.Add(nil, err)
246248
}
@@ -263,7 +265,7 @@ func (a *AggSumAggregator) BatchSlice(results []AggregatorResErr) error {
263265
return a.Add(res.Result, res.Err)
264266
}
265267

266-
intRes, err := toInt64(res)
268+
intRes, err := toInt64(res.Result)
267269
if err != nil {
268270
return a.Add(nil, err)
269271
}
@@ -275,7 +277,7 @@ func (a *AggSumAggregator) BatchSlice(results []AggregatorResErr) error {
275277
}
276278

277279
func (a *AggSumAggregator) Result() (interface{}, error) {
278-
res, err := atomic.LoadInt64(a.res), a.err.Load()
280+
res, err := atomic.LoadInt64(&a.res), a.err.Load()
279281
if err != nil {
280282
return nil, err.(error)
281283
}

osscluster.go

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -921,10 +921,11 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er
921921
// or more underlying connections. It's safe for concurrent use by
922922
// multiple goroutines.
923923
type ClusterClient struct {
924-
opt *ClusterOptions
925-
nodes *clusterNodes
926-
state *clusterStateHolder
927-
cmdsInfoCache *cmdsInfoCache
924+
opt *ClusterOptions
925+
nodes *clusterNodes
926+
state *clusterStateHolder
927+
cmdsInfoCache *cmdsInfoCache
928+
cmdInfoResolver CommandInfoResolver
928929
cmdable
929930
hooksMixin
930931
}
@@ -942,6 +943,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
942943
c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
943944

944945
c.state = newClusterStateHolder(c.loadState)
946+
947+
c.cmdInfoResolver = c.NewDynamicResolver()
948+
945949
c.cmdable = c.Process
946950
c.initHooks(hooks{
947951
dial: nil,
@@ -1334,7 +1338,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd
13341338

13351339
if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) {
13361340
for _, cmd := range cmds {
1337-
policy := c.getCommandPolicy(ctx, cmd)
1341+
policy := c.extractCommandInfo(ctx, cmd.Name())
13381342
if policy != nil && !policy.CanBeUsedInPipeline() {
13391343
return fmt.Errorf(
13401344
"redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(),
@@ -1351,7 +1355,7 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd
13511355
}
13521356

13531357
for _, cmd := range cmds {
1354-
policy := c.getCommandPolicy(ctx, cmd)
1358+
policy := c.extractCommandInfo(ctx, cmd.Name())
13551359
if policy != nil && !policy.CanBeUsedInPipeline() {
13561360
return fmt.Errorf(
13571361
"redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(),
@@ -1976,6 +1980,29 @@ func (c *ClusterClient) context(ctx context.Context) context.Context {
19761980
return context.Background()
19771981
}
19781982

1983+
func (c *ClusterClient) GetResolver() CommandInfoResolver {
1984+
return c.cmdInfoResolver
1985+
}
1986+
1987+
func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver CommandInfoResolver) {
1988+
c.cmdInfoResolver = cmdInfoResolver
1989+
}
1990+
1991+
// extractCommandInfo retrieves the routing policy for a command
1992+
func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmdName string) *routing.CommandPolicy {
1993+
if cmdInfo := c.cmdInfo(ctx, cmdName); cmdInfo != nil && cmdInfo.CommandPolicy != nil {
1994+
return cmdInfo.CommandPolicy
1995+
}
1996+
1997+
return nil
1998+
}
1999+
2000+
func (c *ClusterClient) NewDynamicResolver() CommandInfoResolver {
2001+
return &resolver{
2002+
resolve: c.extractCommandInfo,
2003+
}
2004+
}
2005+
19792006
func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode {
19802007
for _, n := range nodes {
19812008
if n == node {

osscluster_router.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type slotResult struct {
2121

2222
// routeAndRun routes a command to the appropriate cluster nodes and executes it
2323
func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error {
24-
policy := c.getCommandPolicy(ctx, cmd)
24+
policy := c.cmdInfoResolver.getCommandPolicy(ctx, cmd.Name())
2525
if policy == nil {
2626
return c.executeDefault(ctx, cmd, node)
2727
}
@@ -39,14 +39,6 @@ func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *cluste
3939
}
4040
}
4141

42-
// getCommandPolicy retrieves the routing policy for a command
43-
func (c *ClusterClient) getCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy {
44-
if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.CommandPolicy != nil {
45-
return cmdInfo.CommandPolicy
46-
}
47-
return nil
48-
}
49-
5042
// executeDefault handles standard command routing based on keys
5143
func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, node *clusterNode) error {
5244
if c.hasKeys(cmd) {

osscluster_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (s *clusterScenario) newClusterClient(
6363
ctx context.Context, opt *redis.ClusterOptions,
6464
) *redis.ClusterClient {
6565
client := s.newClusterClientUnstable(opt)
66-
66+
client.SetCommandInfoResolver(client.NewDynamicResolver())
6767
err := eventually(func() error {
6868
if opt.ClusterSlots != nil {
6969
return nil
@@ -1325,7 +1325,6 @@ var _ = Describe("ClusterClient", func() {
13251325
return slots, nil
13261326
}
13271327
client = cluster.newClusterClient(ctx, opt)
1328-
13291328
err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error {
13301329
return master.FlushDB(ctx).Err()
13311330
})

0 commit comments

Comments
 (0)