Skip to content

Commit 4bdce2a

Browse files
authored
Merge pull request #280 from ydb-platform/xcontext
context cancel with reason
2 parents 8c7ac54 + 43a12a9 commit 4bdce2a

File tree

6 files changed

+167
-10
lines changed

6 files changed

+167
-10
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Added context cancel with specific error
2+
13
## v3.26.10
24
* Fixed syntax mistake in `trace.TablePooStateChangeInfo` to `trace.TablePoolStateChangeInfo`
35

internal/conn/grpc_client_stream.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package conn
22

33
import (
44
"context"
5+
"errors"
56
"time"
67

78
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
89
"google.golang.org/grpc"
910

11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
12+
1013
"github.com/ydb-platform/ydb-go-sdk/v3/internal/wrap"
1114
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
1215
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
@@ -40,7 +43,7 @@ func (s *grpcClientStream) CloseSend() (err error) {
4043

4144
func (s *grpcClientStream) SendMsg(m interface{}) (err error) {
4245
cancel := createPinger(s.c)
43-
defer cancel()
46+
defer cancel(xerrors.WithStackTrace(errors.New("send msg finished")))
4447

4548
err = s.ClientStream.SendMsg(m)
4649

@@ -61,7 +64,7 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) {
6164

6265
func (s *grpcClientStream) RecvMsg(m interface{}) (err error) {
6366
cancel := createPinger(s.c)
64-
defer cancel()
67+
defer cancel(xerrors.WithStackTrace(errors.New("receive msg finished")))
6568

6669
defer func() {
6770
onDone := s.recv(xerrors.HideEOF(err))
@@ -102,9 +105,9 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) {
102105
return nil
103106
}
104107

105-
func createPinger(c *conn) context.CancelFunc {
108+
func createPinger(c *conn) xcontext.CancelErrFunc {
106109
c.touchLastUsage()
107-
ctx, cancel := context.WithCancel(context.Background())
110+
ctx, cancel := xcontext.WithErrCancel(context.Background())
108111
go func() {
109112
ticker := time.NewTicker(time.Second)
110113
ctxDone := ctx.Done()

internal/repeater/repeater.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@ package repeater
22

33
import (
44
"context"
5+
"fmt"
56
"time"
67

78
"github.com/jonboulle/clockwork"
89

10+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
11+
12+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
13+
914
"github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff"
1015
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
1116
)
@@ -27,7 +32,7 @@ type repeater struct {
2732
// Task is a function that must be executed periodically.
2833
task func(context.Context) error
2934

30-
cancel context.CancelFunc
35+
cancel xcontext.CancelErrFunc
3136
stopped chan struct{}
3237

3338
force chan struct{}
@@ -74,7 +79,7 @@ func New(
7479
task func(ctx context.Context) (err error),
7580
opts ...option,
7681
) *repeater {
77-
ctx, cancel := context.WithCancel(context.Background())
82+
ctx, cancel := xcontext.WithErrCancel(context.Background())
7883

7984
r := &repeater{
8085
interval: interval,
@@ -95,7 +100,7 @@ func New(
95100
}
96101

97102
func (r *repeater) stop(onCancel func()) {
98-
r.cancel()
103+
r.cancel(xerrors.WithStackTrace(fmt.Errorf("ydb: repeater stopped")))
99104
if onCancel != nil {
100105
onCancel()
101106
}

internal/table/session.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package table
22

33
import (
44
"context"
5+
"fmt"
56
"net/url"
67
"strconv"
78
"sync"
@@ -11,6 +12,8 @@ import (
1112
"google.golang.org/grpc/metadata"
1213
"google.golang.org/protobuf/proto"
1314

15+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
16+
1417
"github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer"
1518

1619
"github.com/ydb-platform/ydb-go-genproto/Ydb_Table_V1"
@@ -915,15 +918,15 @@ func (s *session) StreamReadTable(
915918
opt((*options.ReadTableDesc)(&request))
916919
}
917920

918-
ctx, cancel := context.WithCancel(ctx)
921+
ctx, cancel := xcontext.WithErrCancel(ctx)
919922

920923
stream, err = s.tableService.StreamReadTable(
921924
balancer.WithEndpoint(ctx, s),
922925
&request,
923926
)
924927

925928
if err != nil {
926-
cancel()
929+
cancel(xerrors.WithStackTrace(fmt.Errorf("ydb: stream read error: %w", err)))
927930
return nil, xerrors.WithStackTrace(err)
928931
}
929932

@@ -950,7 +953,11 @@ func (s *session) StreamReadTable(
950953
}
951954
},
952955
func(err error) error {
953-
cancel()
956+
if err == nil {
957+
cancel(nil)
958+
} else {
959+
cancel(xerrors.WithStackTrace(fmt.Errorf("ydb: stream closed with: %w", err)))
960+
}
954961
onIntermediate(xerrors.HideEOF(err))(xerrors.HideEOF(err))
955962
return err
956963
},
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package xcontext
2+
3+
import (
4+
"context"
5+
"errors"
6+
"sync"
7+
"time"
8+
9+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
10+
)
11+
12+
var errCancelWithNilError = cancelError{err: errors.New("cancel context with nil error")}
13+
14+
// CancelErrFunc use for cancel with wrap with specific error
15+
// if err == nil CancelErrFunc will panic for prevent
16+
// call cancel, then ctx.Err() == nil
17+
type CancelErrFunc func(err error)
18+
19+
func WithErrCancel(ctx context.Context) (resCtx context.Context, cancel CancelErrFunc) {
20+
res := &ctxError{}
21+
res.ctx, res.ctxCancel = context.WithCancel(ctx)
22+
return res, res.cancel
23+
}
24+
25+
type ctxError struct {
26+
ctx context.Context
27+
ctxCancel context.CancelFunc
28+
29+
m sync.Mutex
30+
err error
31+
}
32+
33+
func (c *ctxError) Deadline() (deadline time.Time, ok bool) {
34+
return c.ctx.Deadline()
35+
}
36+
37+
func (c *ctxError) Done() <-chan struct{} {
38+
return c.ctx.Done()
39+
}
40+
41+
func (c *ctxError) Err() error {
42+
c.m.Lock()
43+
defer c.m.Unlock()
44+
45+
return c.errUnderLock()
46+
}
47+
48+
func (c *ctxError) errUnderLock() error {
49+
if c.err == nil {
50+
c.err = c.ctx.Err()
51+
}
52+
53+
return c.err
54+
}
55+
56+
func (c *ctxError) Value(key interface{}) interface{} {
57+
return c.ctx.Value(key)
58+
}
59+
60+
func (c *ctxError) cancel(err error) {
61+
c.m.Lock()
62+
defer c.m.Unlock()
63+
64+
if err == nil {
65+
err = xerrors.WithStackTrace(errCancelWithNilError)
66+
}
67+
68+
if c.errUnderLock() == nil {
69+
err = cancelError{err: err}
70+
c.err = err
71+
}
72+
73+
c.ctxCancel()
74+
}
75+
76+
type cancelError struct {
77+
err error
78+
}
79+
80+
func (e cancelError) Error() string {
81+
return e.err.Error()
82+
}
83+
84+
func (e cancelError) Is(target error) bool {
85+
return errors.Is(e.err, target) || errors.Is(context.Canceled, target)
86+
}
87+
88+
func (e cancelError) As(target interface{}) bool {
89+
return errors.As(e.err, target) || errors.As(context.Canceled, target)
90+
}
91+
92+
func (e cancelError) Unwrap() error {
93+
return e.err
94+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package xcontext
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestCancelWithError(t *testing.T) {
12+
testError := errors.New("test error")
13+
t.Run("SimpleCancel", func(t *testing.T) {
14+
ctx, cancel := WithErrCancel(context.Background())
15+
cancel(testError)
16+
require.ErrorIs(t, ctx.Err(), testError)
17+
})
18+
19+
t.Run("CancelBeforeParent", func(t *testing.T) {
20+
parent, parentCancel := context.WithCancel(context.Background())
21+
ctx, cancel := WithErrCancel(parent)
22+
23+
cancel(testError)
24+
parentCancel()
25+
26+
require.ErrorIs(t, ctx.Err(), testError)
27+
require.ErrorIs(t, ctx.Err(), context.Canceled)
28+
})
29+
30+
t.Run("CancelAfterParent", func(t *testing.T) {
31+
parent, parentCancel := context.WithCancel(context.Background())
32+
ctx, cancel := WithErrCancel(parent)
33+
34+
parentCancel()
35+
cancel(testError)
36+
37+
require.Equal(t, context.Canceled, ctx.Err())
38+
})
39+
40+
t.Run("CancelWithNil", func(t *testing.T) {
41+
ctx, cancel := WithErrCancel(context.Background())
42+
cancel(nil)
43+
require.ErrorIs(t, ctx.Err(), errCancelWithNilError)
44+
require.ErrorIs(t, ctx.Err(), context.Canceled)
45+
})
46+
}

0 commit comments

Comments
 (0)