@@ -3,7 +3,6 @@ package balancer
33import (
44 "context"
55 "fmt"
6- "sync"
76
87 "google.golang.org/grpc"
98
@@ -16,24 +15,25 @@ import (
1615 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
1716 "github.com/ydb-platform/ydb-go-sdk/v3/internal/repeater"
1817 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
18+ "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
1919 "github.com/ydb-platform/ydb-go-sdk/v3/trace"
2020)
2121
2222var ErrNoEndpoints = xerrors .Wrap (fmt .Errorf ("no endpoints" ))
2323
24- type balancer struct {
24+ type Balancer struct {
2525 driverConfig config.Config
2626 balancerConfig balancerConfig.Config
2727 pool * conn.Pool
2828 discovery discovery.Client
2929 discoveryRepeater repeater.Repeater
3030 localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
3131
32- m sync .RWMutex
32+ mu xsync .RWMutex
3333 connectionsState * connectionsState
3434}
3535
36- func (b * balancer ) clusterDiscovery (ctx context.Context ) (err error ) {
36+ func (b * Balancer ) clusterDiscovery (ctx context.Context ) (err error ) {
3737 var (
3838 onDone = trace .DriverOnBalancerUpdate (
3939 b .driverConfig .Trace (),
@@ -73,7 +73,7 @@ func (b *balancer) clusterDiscovery(ctx context.Context) (err error) {
7373 return nil
7474}
7575
76- func (b * balancer ) applyDiscoveredEndpoints (ctx context.Context , endpoints []endpoint.Endpoint , localDC string ) {
76+ func (b * Balancer ) applyDiscoveredEndpoints (ctx context.Context , endpoints []endpoint.Endpoint , localDC string ) {
7777 connections := endpointsToConnections (b .pool , endpoints )
7878 for _ , c := range connections {
7979 b .pool .Allow (ctx , c )
@@ -83,17 +83,16 @@ func (b *balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
8383 info := balancerConfig.Info {SelfLocation : localDC }
8484 state := newConnectionsState (connections , b .balancerConfig .IsPreferConn , info , b .balancerConfig .AllowFalback )
8585
86- b .m .Lock ()
87- defer b .m .Unlock ()
88-
89- b .connectionsState = state
86+ b .mu .WithLock (func () {
87+ b .connectionsState = state
88+ })
9089}
9190
92- func (b * balancer ) Discovery () discovery.Client {
91+ func (b * Balancer ) Discovery () discovery.Client {
9392 return b .discovery
9493}
9594
96- func (b * balancer ) Close (ctx context.Context ) (err error ) {
95+ func (b * Balancer ) Close (ctx context.Context ) (err error ) {
9796 onDone := trace .DriverOnBalancerClose (
9897 b .driverConfig .Trace (),
9998 & ctx ,
@@ -113,7 +112,7 @@ func (b *balancer) Close(ctx context.Context) (err error) {
113112 }
114113
115114 if len (issues ) > 0 {
116- return xerrors .WithStackTrace (xerrors .NewWithIssues ("balancer close failed" , issues ... ))
115+ return xerrors .WithStackTrace (xerrors .NewWithIssues ("Balancer close failed" , issues ... ))
117116 }
118117
119118 return nil
@@ -124,7 +123,7 @@ func New(
124123 c config.Config ,
125124 pool * conn.Pool ,
126125 opts ... discoveryConfig.Option ,
127- ) (_ Connection , err error ) {
126+ ) (b * Balancer , err error ) {
128127 onDone := trace .DriverOnBalancerInit (
129128 c .Trace (),
130129 & ctx ,
@@ -133,7 +132,7 @@ func New(
133132 onDone (err )
134133 }()
135134
136- b : = & balancer {
135+ b = & Balancer {
137136 driverConfig : c ,
138137 pool : pool ,
139138 localDCDetector : detectLocalDC ,
@@ -187,19 +186,19 @@ func New(
187186 return b , nil
188187}
189188
190- func (b * balancer ) Endpoint () string {
189+ func (b * Balancer ) Endpoint () string {
191190 return b .driverConfig .Endpoint ()
192191}
193192
194- func (b * balancer ) Name () string {
193+ func (b * Balancer ) Name () string {
195194 return b .driverConfig .Database ()
196195}
197196
198- func (b * balancer ) Secure () bool {
197+ func (b * Balancer ) Secure () bool {
199198 return b .driverConfig .Secure ()
200199}
201200
202- func (b * balancer ) Invoke (
201+ func (b * Balancer ) Invoke (
203202 ctx context.Context ,
204203 method string ,
205204 args interface {},
@@ -211,7 +210,7 @@ func (b *balancer) Invoke(
211210 })
212211}
213212
214- func (b * balancer ) NewStream (
213+ func (b * Balancer ) NewStream (
215214 ctx context.Context ,
216215 desc * grpc.StreamDesc ,
217216 method string ,
@@ -228,7 +227,7 @@ func (b *balancer) NewStream(
228227 return nil , err
229228}
230229
231- func (b * balancer ) wrapCall (ctx context.Context , f func (ctx context.Context , cc conn.Conn ) error ) (err error ) {
230+ func (b * Balancer ) wrapCall (ctx context.Context , f func (ctx context.Context , cc conn.Conn ) error ) (err error ) {
232231 cc , err := b .getConn (ctx )
233232 if err != nil {
234233 return xerrors .WithStackTrace (err )
@@ -257,14 +256,14 @@ func (b *balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
257256 return nil
258257}
259258
260- func (b * balancer ) connections () * connectionsState {
261- b .m .RLock ()
262- defer b .m .RUnlock ()
259+ func (b * Balancer ) connections () * connectionsState {
260+ b .mu .RLock ()
261+ defer b .mu .RUnlock ()
263262
264263 return b .connectionsState
265264}
266265
267- func (b * balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
266+ func (b * Balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
268267 onDone := trace .DriverOnBalancerChooseEndpoint (
269268 b .driverConfig .Trace (),
270269 & ctx ,
@@ -295,7 +294,7 @@ func (b *balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
295294 c , failedCount = state .GetConnection (ctx )
296295 if c == nil {
297296 return nil , xerrors .WithStackTrace (
298- fmt .Errorf ("%w: cannot get connection from balancer after %d attempts" , ErrNoEndpoints , failedCount ),
297+ fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , ErrNoEndpoints , failedCount ),
299298 )
300299 }
301300 return c , nil
0 commit comments