Skip to content

Commit 593812b

Browse files
added nil safety assertions
1 parent bdb6036 commit 593812b

File tree

3 files changed

+32
-20
lines changed

3 files changed

+32
-20
lines changed

command_policy_resolver.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,20 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{
117117
},
118118
}
119119

120-
type CommandInfoResolver struct {
121-
resolve func(ctx context.Context, cmd Cmder) *routing.CommandPolicy
122-
fallBackResolver *CommandInfoResolver
120+
type CommandInfoResolveFunc func(ctx context.Context, cmd Cmder) *routing.CommandPolicy
121+
122+
type commandInfoResolver struct {
123+
resolveFunc CommandInfoResolveFunc
124+
fallBackResolver *commandInfoResolver
123125
}
124126

125-
func NewCommandInfoResolver(resolver func(ctx context.Context, cmd Cmder) *routing.CommandPolicy) *CommandInfoResolver {
126-
return &CommandInfoResolver{
127-
resolve: resolver,
127+
func NewCommandInfoResolver(resolveFunc CommandInfoResolveFunc) *commandInfoResolver {
128+
return &commandInfoResolver{
129+
resolveFunc: resolveFunc,
128130
}
129131
}
130132

131-
func NewDefaultCommandPolicyResolver() *CommandInfoResolver {
133+
func NewDefaultCommandPolicyResolver() *commandInfoResolver {
132134
return NewCommandInfoResolver(func(ctx context.Context, cmd Cmder) *routing.CommandPolicy {
133135
module := "core"
134136
command := cmd.Name()
@@ -146,12 +148,12 @@ func NewDefaultCommandPolicyResolver() *CommandInfoResolver {
146148
})
147149
}
148150

149-
func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy {
150-
if r.resolve == nil {
151+
func (r *commandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy {
152+
if r.resolveFunc == nil {
151153
return nil
152154
}
153155

154-
policy := r.resolve(ctx, cmd)
156+
policy := r.resolveFunc(ctx, cmd)
155157
if policy != nil {
156158
return policy
157159
}
@@ -163,6 +165,6 @@ func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *
163165
return nil
164166
}
165167

166-
func (r *CommandInfoResolver) SetFallbackResolver(fallbackResolver *CommandInfoResolver) {
168+
func (r *commandInfoResolver) SetFallbackResolver(fallbackResolver *commandInfoResolver) {
167169
r.fallBackResolver = fallbackResolver
168170
}

osscluster.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ type ClusterClient struct {
10171017
nodes *clusterNodes
10181018
state *clusterStateHolder
10191019
cmdsInfoCache *cmdsInfoCache
1020-
cmdInfoResolver *CommandInfoResolver
1020+
cmdInfoResolver *commandInfoResolver
10211021
cmdable
10221022
hooksMixin
10231023
}
@@ -1425,7 +1425,10 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd
14251425

14261426
if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) {
14271427
for _, cmd := range cmds {
1428-
policy := c.extractCommandInfo(ctx, cmd)
1428+
var policy *routing.CommandPolicy
1429+
if c.cmdInfoResolver != nil {
1430+
policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd)
1431+
}
14291432
if policy != nil && !policy.CanBeUsedInPipeline() {
14301433
return fmt.Errorf(
14311434
"redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(),
@@ -1442,7 +1445,10 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd
14421445
}
14431446

14441447
for _, cmd := range cmds {
1445-
policy := c.extractCommandInfo(ctx, cmd)
1448+
var policy *routing.CommandPolicy
1449+
if c.cmdInfoResolver != nil {
1450+
policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd)
1451+
}
14461452
if policy != nil && !policy.CanBeUsedInPipeline() {
14471453
return fmt.Errorf(
14481454
"redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(),
@@ -2199,11 +2205,11 @@ func (c *ClusterClient) context(ctx context.Context) context.Context {
21992205
return context.Background()
22002206
}
22012207

2202-
func (c *ClusterClient) GetResolver() *CommandInfoResolver {
2208+
func (c *ClusterClient) GetResolver() *commandInfoResolver {
22032209
return c.cmdInfoResolver
22042210
}
22052211

2206-
func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *CommandInfoResolver) {
2212+
func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *commandInfoResolver) {
22072213
c.cmdInfoResolver = cmdInfoResolver
22082214
}
22092215

@@ -2218,9 +2224,9 @@ func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmd Cmder) *rout
22182224

22192225
// NewDynamicResolver returns a CommandInfoResolver
22202226
// that uses the underlying cmdInfo cache to resolve the policies
2221-
func (c *ClusterClient) NewDynamicResolver() *CommandInfoResolver {
2222-
return &CommandInfoResolver{
2223-
resolve: c.extractCommandInfo,
2227+
func (c *ClusterClient) NewDynamicResolver() *commandInfoResolver {
2228+
return &commandInfoResolver{
2229+
resolveFunc: c.extractCommandInfo,
22242230
}
22252231
}
22262232

osscluster_router.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ type slotResult struct {
2121

2222
// routeAndRun routes a command to the appropriate cluster nodes and executes it
2323
func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error {
24-
policy := c.cmdInfoResolver.GetCommandPolicy(ctx, cmd)
24+
var policy *routing.CommandPolicy
25+
if c.cmdInfoResolver != nil {
26+
policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd)
27+
}
28+
2529
if policy == nil {
2630
return c.executeDefault(ctx, cmd, node)
2731
}

0 commit comments

Comments
 (0)