@@ -49,8 +49,12 @@ func (b *Balancer) OnUpdate(onDiscovery func(ctx context.Context, endpoints []en
4949}
5050
5151func (b * Balancer ) clusterDiscovery (ctx context.Context ) (err error ) {
52- if err = retry .Retry (ctx , func (ctx context.Context ) (err error ) {
53- if err = b .clusterDiscoveryAttempt (ctx ); err != nil {
52+ if err = retry .Retry (ctx , func (childCtx context.Context ) (err error ) {
53+ if err = b .clusterDiscoveryAttempt (childCtx ); err != nil {
54+ // if got err but parent context is not done - mark error as retryable
55+ if err != nil && ctx .Err () == nil && xerrors .IsTimeoutError (err ) {
56+ return xerrors .WithStackTrace (xerrors .Retryable (err ))
57+ }
5458 return xerrors .WithStackTrace (err )
5559 }
5660 return nil
@@ -69,59 +73,45 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
6973 )
7074 endpoints []endpoint.Endpoint
7175 localDC string
76+ cancel context.CancelFunc
7277 )
7378
7479 defer func () {
75- // if got err but parent context is not done - mark error as retryable
76- if err != nil && ctx .Err () == nil && xerrors .Is (err ,
77- context .DeadlineExceeded ,
78- context .Canceled ,
79- ) {
80- err = xerrors .WithStackTrace (xerrors .Retryable (err ))
81- }
8280 nodes := make ([]trace.EndpointInfo , 0 , len (endpoints ))
8381 for _ , e := range endpoints {
8482 nodes = append (nodes , e .Copy ())
8583 }
86- onDone (
87- nodes ,
88- localDC ,
89- err ,
90- )
84+ onDone (nodes , localDC , err )
9185 }()
9286
93- var (
94- childCtx context.Context
95- cancel context.CancelFunc
96- )
9787 if dialTimeout := b .driverConfig .DialTimeout (); dialTimeout > 0 {
98- childCtx , cancel = context .WithTimeout (ctx , dialTimeout )
88+ ctx , cancel = context .WithTimeout (ctx , dialTimeout )
9989 } else {
100- childCtx , cancel = context .WithCancel (ctx )
90+ ctx , cancel = context .WithCancel (ctx )
10191 }
10292 defer cancel ()
10393
104- client , err := b .discoveryClient (childCtx )
94+ client , err := b .discoveryClient (ctx )
10595 if err != nil {
10696 return xerrors .WithStackTrace (err )
10797 }
10898 defer func () {
10999 _ = client .Close (ctx )
110100 }()
111101
112- endpoints , err = client .Discover (childCtx )
102+ endpoints , err = client .Discover (ctx )
113103 if err != nil {
114104 return xerrors .WithStackTrace (err )
115105 }
116106
117107 if b .balancerConfig .DetectlocalDC {
118- localDC , err = b .localDCDetector (childCtx , endpoints )
108+ localDC , err = b .localDCDetector (ctx , endpoints )
119109 if err != nil {
120110 return xerrors .WithStackTrace (err )
121111 }
122112 }
123113
124- b .applyDiscoveredEndpoints (childCtx , endpoints , localDC )
114+ b .applyDiscoveredEndpoints (ctx , endpoints , localDC )
125115
126116 return nil
127117}
0 commit comments