Skip to content

Commit 9bda31b

Browse files
committed
add context cancel with reasone
1 parent 364b937 commit 9bda31b

File tree

6 files changed

+58
-13
lines changed

6 files changed

+58
-13
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Added context cancel with specific error
2+
13
## v3.26.10
24
* Fixed syntax mistake in `trace.TablePooStateChangeInfo` to `trace.TablePoolStateChangeInfo`
35

internal/conn/grpc_client_stream.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package conn
22

33
import (
44
"context"
5+
"errors"
56
"time"
67

78
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
89
"google.golang.org/grpc"
910

11+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
12+
1013
"github.com/ydb-platform/ydb-go-sdk/v3/internal/wrap"
1114
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
1215
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
@@ -40,7 +43,7 @@ func (s *grpcClientStream) CloseSend() (err error) {
4043

4144
func (s *grpcClientStream) SendMsg(m interface{}) (err error) {
4245
cancel := createPinger(s.c)
43-
defer cancel()
46+
defer cancel(xerrors.WithStackTrace(errors.New("send msg finished")))
4447

4548
err = s.ClientStream.SendMsg(m)
4649

@@ -61,7 +64,7 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) {
6164

6265
func (s *grpcClientStream) RecvMsg(m interface{}) (err error) {
6366
cancel := createPinger(s.c)
64-
defer cancel()
67+
defer cancel(xerrors.WithStackTrace(errors.New("receive msg finished")))
6568

6669
defer func() {
6770
onDone := s.recv(xerrors.HideEOF(err))
@@ -102,9 +105,9 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) {
102105
return nil
103106
}
104107

105-
func createPinger(c *conn) context.CancelFunc {
108+
func createPinger(c *conn) xcontext.CancelErrFunc {
106109
c.touchLastUsage()
107-
ctx, cancel := context.WithCancel(context.Background())
110+
ctx, cancel := xcontext.WithErrCancel(context.Background())
108111
go func() {
109112
ticker := time.NewTicker(time.Second)
110113
ctxDone := ctx.Done()

internal/repeater/repeater.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@ package repeater
22

33
import (
44
"context"
5+
"fmt"
56
"time"
67

78
"github.com/jonboulle/clockwork"
89

10+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
11+
12+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
13+
914
"github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff"
1015
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
1116
)
@@ -27,7 +32,7 @@ type repeater struct {
2732
// Task is a function that must be executed periodically.
2833
task func(context.Context) error
2934

30-
cancel context.CancelFunc
35+
cancel xcontext.CancelErrFunc
3136
stopped chan struct{}
3237

3338
force chan struct{}
@@ -74,7 +79,7 @@ func New(
7479
task func(ctx context.Context) (err error),
7580
opts ...option,
7681
) *repeater {
77-
ctx, cancel := context.WithCancel(context.Background())
82+
ctx, cancel := xcontext.WithErrCancel(context.Background())
7883

7984
r := &repeater{
8085
interval: interval,
@@ -95,7 +100,7 @@ func New(
95100
}
96101

97102
func (r *repeater) stop(onCancel func()) {
98-
r.cancel()
103+
r.cancel(xerrors.WithStackTrace(fmt.Errorf("ydb: repeater stopped")))
99104
if onCancel != nil {
100105
onCancel()
101106
}

internal/table/session.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package table
22

33
import (
44
"context"
5+
"fmt"
56
"net/url"
67
"strconv"
78
"sync"
@@ -11,6 +12,8 @@ import (
1112
"google.golang.org/grpc/metadata"
1213
"google.golang.org/protobuf/proto"
1314

15+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
16+
1417
"github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer"
1518

1619
"github.com/ydb-platform/ydb-go-genproto/Ydb_Table_V1"
@@ -915,15 +918,15 @@ func (s *session) StreamReadTable(
915918
opt((*options.ReadTableDesc)(&request))
916919
}
917920

918-
ctx, cancel := context.WithCancel(ctx)
921+
ctx, cancel := xcontext.WithErrCancel(ctx)
919922

920923
stream, err = s.tableService.StreamReadTable(
921924
balancer.WithEndpoint(ctx, s),
922925
&request,
923926
)
924927

925928
if err != nil {
926-
cancel()
929+
cancel(xerrors.WithStackTrace(fmt.Errorf("ydb: stream read error: %w", err)))
927930
return nil, xerrors.WithStackTrace(err)
928931
}
929932

@@ -950,7 +953,11 @@ func (s *session) StreamReadTable(
950953
}
951954
},
952955
func(err error) error {
953-
cancel()
956+
if err == nil {
957+
cancel(nil)
958+
} else {
959+
cancel(xerrors.WithStackTrace(fmt.Errorf("ydb: stream closed with: %w", err)))
960+
}
954961
onIntermediate(xerrors.HideEOF(err))(xerrors.HideEOF(err))
955962
return err
956963
},

internal/xcontext/cancel_with_error.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package xcontext
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"sync"
78
"time"
@@ -51,7 +52,6 @@ func (c *ctxError) errUnderLock() error {
5152
}
5253

5354
return c.err
54-
5555
}
5656

5757
func (c *ctxError) Value(key interface{}) interface{} {
@@ -67,8 +67,29 @@ func (c *ctxError) cancel(err error) {
6767
}
6868

6969
if c.errUnderLock() == nil {
70+
err = cancelError{err: err}
7071
c.err = err
7172
}
7273

7374
c.ctxCancel()
7475
}
76+
77+
type cancelError struct {
78+
err error
79+
}
80+
81+
func (e cancelError) Error() string {
82+
return e.err.Error()
83+
}
84+
85+
func (e cancelError) Is(target error) bool {
86+
return errors.Is(e.err, target) || errors.Is(context.Canceled, target)
87+
}
88+
89+
func (e cancelError) As(target interface{}) bool {
90+
return errors.As(e.err, target) || errors.As(context.Canceled, target)
91+
}
92+
93+
func (e cancelError) Unwrap() error {
94+
return e.err
95+
}

internal/xcontext/cancel_with_error_test.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func TestCancelWithError(t *testing.T) {
1313
t.Run("SimpleCancel", func(t *testing.T) {
1414
ctx, cancel := WithErrCancel(context.Background())
1515
cancel(testError)
16-
require.Equal(t, testError, ctx.Err())
16+
require.ErrorIs(t, ctx.Err(), testError)
1717
})
1818

1919
t.Run("CancelBeforeParent", func(t *testing.T) {
@@ -23,7 +23,8 @@ func TestCancelWithError(t *testing.T) {
2323
cancel(testError)
2424
parentCancel()
2525

26-
require.Equal(t, testError, ctx.Err())
26+
require.ErrorIs(t, ctx.Err(), testError)
27+
require.ErrorIs(t, ctx.Err(), context.Canceled)
2728
})
2829

2930
t.Run("CancelAfterParent", func(t *testing.T) {
@@ -36,4 +37,10 @@ func TestCancelWithError(t *testing.T) {
3637
require.Equal(t, context.Canceled, ctx.Err())
3738
})
3839

40+
t.Run("CancelWithNil", func(t *testing.T) {
41+
ctx, cancel := WithErrCancel(context.Background())
42+
cancel(nil)
43+
require.ErrorIs(t, ctx.Err(), errCancelWithNilError)
44+
require.ErrorIs(t, ctx.Err(), context.Canceled)
45+
})
3946
}

0 commit comments

Comments
 (0)