Skip to content

Commit 69bab38

Browse files
Allow for tracking state in Limiter
Extend Allow and ReportResult functions to handle a Context. Allow can override the passed in Context. The returned Context is then further passed down to ReportResult. Using this Context it is then possible to store values/track state between Allow and ReportResult calls. Without this tracking HalfOpen/Generation state is hard to implement efficiently for Circuit Breakers.
1 parent 7d56a2c commit 69bab38

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

options.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ type Limiter interface {
2121
// Allow returns nil if operation is allowed or an error otherwise.
2222
// If operation is allowed client must ReportResult of the operation
2323
// whether it is a success or a failure.
24-
Allow() error
24+
// The returned context will be passed to ReportResult.
25+
Allow(ctx context.Context) (context.Context, error)
2526
// ReportResult reports the result of the previously allowed operation.
2627
// nil indicates a success, non-nil error usually indicates a failure.
27-
ReportResult(result error)
28+
// Context can be used to access state tracked by previous Allow call.
29+
ReportResult(ctx context.Context, result error)
2830
}
2931

3032
// Options keeps the settings to set up redis connection.

osscluster.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,7 +1319,7 @@ func (c *ClusterClient) processPipelineNode(
13191319
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
13201320
) {
13211321
_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
1322-
cn, err := node.Client.getConn(ctx)
1322+
ctx, cn, err := node.Client.getConn(ctx)
13231323
if err != nil {
13241324
node.MarkAsFailing()
13251325
_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
@@ -1504,7 +1504,7 @@ func (c *ClusterClient) processTxPipelineNode(
15041504
) {
15051505
cmds = wrapMultiExec(ctx, cmds)
15061506
_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
1507-
cn, err := node.Client.getConn(ctx)
1507+
ctx, cn, err := node.Client.getConn(ctx)
15081508
if err != nil {
15091509
_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
15101510
setCmdsErr(cmds, err)

redis.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,23 +237,24 @@ func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
237237
return cn, nil
238238
}
239239

240-
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
240+
func (c *baseClient) getConn(ctx context.Context) (context.Context, *pool.Conn, error) {
241+
var err error
241242
if c.opt.Limiter != nil {
242-
err := c.opt.Limiter.Allow()
243+
ctx, err = c.opt.Limiter.Allow(ctx)
243244
if err != nil {
244-
return nil, err
245+
return ctx, nil, err
245246
}
246247
}
247248

248249
cn, err := c._getConn(ctx)
249250
if err != nil {
250251
if c.opt.Limiter != nil {
251-
c.opt.Limiter.ReportResult(err)
252+
c.opt.Limiter.ReportResult(ctx, err)
252253
}
253-
return nil, err
254+
return ctx, nil, err
254255
}
255256

256-
return cn, nil
257+
return ctx, cn, nil
257258
}
258259

259260
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
@@ -365,7 +366,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
365366

366367
func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
367368
if c.opt.Limiter != nil {
368-
c.opt.Limiter.ReportResult(err)
369+
c.opt.Limiter.ReportResult(ctx, err)
369370
}
370371

371372
if isBadConn(err, false, c.opt.Addr) {
@@ -378,7 +379,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
378379
func (c *baseClient) withConn(
379380
ctx context.Context, fn func(context.Context, *pool.Conn) error,
380381
) error {
381-
cn, err := c.getConn(ctx)
382+
ctx, cn, err := c.getConn(ctx)
382383
if err != nil {
383384
return err
384385
}

0 commit comments

Comments
 (0)