diff --git a/cluster.go b/cluster.go index 12417d1f9..67398fb29 100644 --- a/cluster.go +++ b/cluster.go @@ -1279,7 +1279,7 @@ func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool func (c *ClusterClient) processPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { - _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) (retErr error) { cn, err := node.Client.getConn(ctx) if err != nil { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) @@ -1287,20 +1287,26 @@ func (c *ClusterClient) processPipelineNode( return err } - var processErr error + if retErr = cn.WatchCancel(c.context(ctx)); retErr != nil { + return retErr + } + defer func() { - node.Client.releaseConn(ctx, cn, processErr) + if err = cn.WatchFinish(); err != nil { + retErr = err + } + node.Client.releaseConn(ctx, cn, retErr) }() - processErr = c.processPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + retErr = c.processPipelineNodeConn(ctx, node, cn, cmds, failedCmds) - return processErr + return retErr }) } func (c *ClusterClient) processPipelineNodeConn( ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ) error { - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { if shouldRetry(err, true) { @@ -1310,7 +1316,7 @@ func (c *ClusterClient) processPipelineNodeConn( return err } - return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds) }) } @@ -1460,7 +1466,7 @@ func (c *ClusterClient) processTxPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) { cmds = wrapMultiExec(ctx, cmds) - _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) (retErr error) { cn, err := node.Client.getConn(ctx) if err != nil { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) @@ -1468,20 +1474,26 @@ func (c *ClusterClient) processTxPipelineNode( return err } - var processErr error + if retErr = cn.WatchCancel(c.context(ctx)); retErr != nil { + return retErr + } + defer func() { - node.Client.releaseConn(ctx, cn, processErr) + if err = cn.WatchFinish(); err != nil { + retErr = err + } + node.Client.releaseConn(ctx, cn, retErr) }() - processErr = c.processTxPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + retErr = c.processTxPipelineNodeConn(ctx, node, cn, cmds, failedCmds) - return processErr + return retErr }) } func (c *ClusterClient) processTxPipelineNodeConn( ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, ) error { - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { if shouldRetry(err, true) { @@ -1491,7 +1503,7 @@ func (c *ClusterClient) processTxPipelineNodeConn( return err } - return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { statusCmd := cmds[0].(*StatusCmd) // Trim multi and exec. trimmedCmds := cmds[1 : len(cmds)-1] diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 7f45bc0bb..32c899dad 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -10,7 +10,11 @@ import ( "github.com/redis/go-redis/v9/internal/proto" ) -var noDeadline = time.Time{} +var ( + // aLongTimeAgo is a non-zero time, used to immediately unblock the network. + aLongTimeAgo = time.Unix(1, 0) + noDeadline = time.Time{} +) type Conn struct { usedAt int64 // atomic @@ -23,6 +27,13 @@ type Conn struct { Inited bool pooled bool createdAt time.Time + + _closed int32 // atomic + closeChan chan struct{} + watchChan chan context.Context + finishChan chan struct{} + interruptChan chan error + watching bool } func NewConn(netConn net.Conn) *Conn { @@ -34,9 +45,72 @@ func NewConn(netConn net.Conn) *Conn { cn.bw = bufio.NewWriter(netConn) cn.wr = proto.NewWriter(cn.bw) cn.SetUsedAt(time.Now()) + + cn.closeChan = make(chan struct{}) + cn.interruptChan = make(chan error) + cn.finishChan = make(chan struct{}) + cn.watchChan = make(chan context.Context, 1) + + go cn.loopWatcher() + return cn } +func (cn *Conn) loopWatcher() { + var ctx context.Context + for { + select { + case ctx = <-cn.watchChan: + case <-cn.closeChan: + return + } + + select { + case <-ctx.Done(): + _ = cn.netConn.SetDeadline(aLongTimeAgo) + cn.interruptChan <- ctx.Err() + case <-cn.finishChan: + case <-cn.closeChan: + return + } + } +} + +func (cn *Conn) WatchFinish() error { + if !cn.watching { + return nil + } + + var err error + select { + case cn.finishChan <- struct{}{}: + cn.watching = false + case err = <-cn.interruptChan: + cn.watching = false + case <-cn.closeChan: + } + return err +} + +func (cn *Conn) WatchCancel(ctx context.Context) error { + if cn.watching { + panic("repeat watchCancel") + } + if err := ctx.Err(); err != nil { + return err + } + if ctx.Done() == nil { + return nil + } + if cn.closed() { + return nil + } + + cn.watching = true + cn.watchChan <- ctx + return nil +} + func (cn *Conn) UsedAt() time.Time { unix := atomic.LoadInt64(&cn.usedAt) return time.Unix(unix, 0) @@ -94,33 +168,24 @@ func (cn *Conn) WithWriter( return cn.bw.Flush() } +func (cn *Conn) closed() bool { + return atomic.LoadInt32(&cn._closed) == 1 +} + func (cn *Conn) Close() error { - return cn.netConn.Close() + if atomic.CompareAndSwapInt32(&cn._closed, 0, 1) { + close(cn.closeChan) + return cn.netConn.Close() + } + return nil } -func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { +func (cn *Conn) deadline(_ context.Context, timeout time.Duration) time.Time { tm := time.Now() cn.SetUsedAt(tm) if timeout > 0 { - tm = tm.Add(timeout) - } - - if ctx != nil { - deadline, ok := ctx.Deadline() - if ok { - if timeout == 0 { - return deadline - } - if deadline.Before(tm) { - return deadline - } - return tm - } - } - - if timeout > 0 { - return tm + return tm.Add(timeout) } return noDeadline diff --git a/internal/pool/pool.go b/internal/pool/pool.go index bb9b14beb..b04a56fc9 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -354,7 +354,7 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { return } - if !cn.pooled { + if !cn.pooled || cn.watching { p.Remove(ctx, cn, nil) return } @@ -486,6 +486,10 @@ func (p *ConnPool) Close() error { func (p *ConnPool) isHealthyConn(cn *Conn) bool { now := time.Now() + if cn.watching { + return false + } + if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime { return false } diff --git a/redis.go b/redis.go index 6eed8424c..4e424ea5a 100644 --- a/redis.go +++ b/redis.go @@ -336,20 +336,26 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) func (c *baseClient) withConn( ctx context.Context, fn func(context.Context, *pool.Conn) error, -) error { +) (retErr error) { cn, err := c.getConn(ctx) if err != nil { return err } - var fnErr error + if retErr = cn.WatchCancel(c.context(ctx)); retErr != nil { + return retErr + } + defer func() { - c.releaseConn(ctx, cn, fnErr) + if err = cn.WatchFinish(); err != nil { + retErr = err + } + c.releaseConn(ctx, cn, retErr) }() - fnErr = fn(ctx, cn) + retErr = fn(ctx, cn) - return fnErr + return retErr } func (c *baseClient) dial(ctx context.Context, network, addr string) (net.Conn, error) { @@ -380,14 +386,14 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool retryTimeout := uint32(0) if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmd(wr, cmd) }); err != nil { atomic.StoreUint32(&retryTimeout, 1) return err } - if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), cmd.readReply); err != nil { + if err := cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply); err != nil { if cmd.readTimeout() == nil { atomic.StoreUint32(&retryTimeout, 1) } else { @@ -486,14 +492,14 @@ func (c *baseClient) generalProcessPipeline( func (c *baseClient) pipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { setCmdsErr(cmds, err) return true, err } - if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + if err := cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { return pipelineReadCmds(rd, cmds) }); err != nil { return true, err @@ -518,14 +524,14 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { func (c *baseClient) txPipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { setCmdsErr(cmds, err) return true, err } - if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + if err := cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { statusCmd := cmds[0].(*StatusCmd) // Trim multi and exec. trimmedCmds := cmds[1 : len(cmds)-1] diff --git a/redis_test.go b/redis_test.go index 6d3842070..b3bf01caa 100644 --- a/redis_test.go +++ b/redis_test.go @@ -531,3 +531,34 @@ var _ = Describe("Hook", func() { })) }) }) + +var _ = Describe("Watch Context.Done", func() { + var client *redis.Client + + BeforeEach(func() { + opt := redisOptions() + opt.ReadTimeout = 10 * time.Second + opt.WriteTimeout = 10 * time.Second + opt.PoolTimeout = 10 * time.Second + opt.ContextTimeoutEnabled = true + + client = redis.NewClient(opt) + Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + + It("should cancel", func() { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + + err := client.BLPop(ctx, 10*time.Second, "key1").Err() + Expect(errors.Is(err, context.Canceled)).To(BeTrue()) + }) +})