Skip to content

Commit 235864e

Browse files
authored
Merge pull request #1188 from ydb-platform/cancels-guard
added xcontext.CancelsGuard
2 parents 4ebada8 + a243f13 commit 235864e

File tree

3 files changed

+89
-10
lines changed

3 files changed

+89
-10
lines changed

internal/conn/conn.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ type conn struct {
5656
closed bool
5757
state atomic.Uint32
5858
lastUsage *lastUsage
59+
childStreams *xcontext.CancelsGuard
5960
onClose []func(*conn)
6061
onTransportErrors []func(ctx context.Context, cc Conn, cause error)
6162
}
@@ -392,23 +393,21 @@ func (c *conn) NewStream(
392393
desc *grpc.StreamDesc,
393394
method string,
394395
opts ...grpc.CallOption,
395-
) (_ grpc.ClientStream, err error) {
396+
) (_ grpc.ClientStream, finalErr error) {
396397
var (
397398
onDone = trace.DriverOnConnNewStream(
398399
c.config.Trace(), &ctx,
399400
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).NewStream"),
400401
c.endpoint.Copy(), trace.Method(method),
401402
)
402403
useWrapping = UseWrapping(ctx)
403-
cc *grpc.ClientConn
404-
s grpc.ClientStream
405404
)
406405

407406
defer func() {
408-
onDone(err, c.GetState())
407+
onDone(finalErr, c.GetState())
409408
}()
410409

411-
cc, err = c.realConn(ctx)
410+
cc, err := c.realConn(ctx)
412411
if err != nil {
413412
return nil, c.wrapError(err)
414413
}
@@ -423,7 +422,19 @@ func (c *conn) NewStream(
423422

424423
ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID))
425424

426-
s, err = cc.NewStream(ctx, desc, method, opts...)
425+
ctx, cancel := xcontext.WithCancel(ctx)
426+
defer func() {
427+
if finalErr != nil {
428+
cancel()
429+
} else {
430+
c.childStreams.Remember(&cancel)
431+
}
432+
}()
433+
434+
s, err := cc.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(func(err error) {
435+
cancel()
436+
c.childStreams.Forget(&cancel)
437+
}))...)
427438
if err != nil {
428439
if xerrors.IsContextError(err) {
429440
return nil, xerrors.WithStackTrace(err)
@@ -490,10 +501,16 @@ func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, ca
490501

491502
func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn {
492503
c := &conn{
493-
endpoint: e,
494-
config: config,
495-
done: make(chan struct{}),
496-
lastUsage: newLastUsage(nil),
504+
endpoint: e,
505+
config: config,
506+
done: make(chan struct{}),
507+
lastUsage: newLastUsage(nil),
508+
childStreams: xcontext.NewCancelsGuard(),
509+
onClose: []func(*conn){
510+
func(c *conn) {
511+
c.childStreams.Cancel()
512+
},
513+
},
497514
}
498515
c.state.Store(uint32(Created))
499516
for _, opt := range opts {

internal/xcontext/cancels_quard.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package xcontext
2+
3+
import (
4+
"context"
5+
"sync"
6+
)
7+
8+
type CancelsGuard struct {
9+
mu sync.Mutex
10+
cancels map[*context.CancelFunc]struct{}
11+
}
12+
13+
func NewCancelsGuard() *CancelsGuard {
14+
return &CancelsGuard{
15+
cancels: make(map[*context.CancelFunc]struct{}),
16+
}
17+
}
18+
19+
func (g *CancelsGuard) Remember(cancel *context.CancelFunc) {
20+
g.mu.Lock()
21+
defer g.mu.Unlock()
22+
g.cancels[cancel] = struct{}{}
23+
}
24+
25+
func (g *CancelsGuard) Forget(cancel *context.CancelFunc) {
26+
g.mu.Lock()
27+
defer g.mu.Unlock()
28+
delete(g.cancels, cancel)
29+
}
30+
31+
func (g *CancelsGuard) Cancel() {
32+
g.mu.Lock()
33+
defer g.mu.Unlock()
34+
for cancel := range g.cancels {
35+
(*cancel)()
36+
}
37+
g.cancels = make(map[*context.CancelFunc]struct{})
38+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package xcontext
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"golang.org/x/net/context"
8+
)
9+
10+
func TestCancelsGuard(t *testing.T) {
11+
g := NewCancelsGuard()
12+
ctx, cancel1 := context.WithCancel(context.Background())
13+
g.Remember(&cancel1)
14+
require.Len(t, g.cancels, 1)
15+
g.Forget(&cancel1)
16+
require.Empty(t, g.cancels, 0)
17+
cancel2 := context.CancelFunc(func() {
18+
cancel1()
19+
})
20+
g.Remember(&cancel2)
21+
require.Len(t, g.cancels, 1)
22+
g.Cancel()
23+
require.Error(t, ctx.Err())
24+
}

0 commit comments

Comments
 (0)