@@ -3,8 +3,6 @@ package balancer
33import (
44 "context"
55 "fmt"
6- "sync"
7-
86 "google.golang.org/grpc"
97
108 "github.com/ydb-platform/ydb-go-sdk/v3/config"
@@ -16,24 +14,33 @@ import (
1614 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
1715 "github.com/ydb-platform/ydb-go-sdk/v3/internal/repeater"
1816 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
17+ "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
1918 "github.com/ydb-platform/ydb-go-sdk/v3/trace"
2019)
2120
2221var ErrNoEndpoints = xerrors .Wrap (fmt .Errorf ("no endpoints" ))
2322
24- type balancer struct {
23+ type Balancer struct {
2524 driverConfig config.Config
2625 balancerConfig balancerConfig.Config
2726 pool * conn.Pool
2827 discovery discovery.Client
2928 discoveryRepeater repeater.Repeater
3029 localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
3130
32- m sync .RWMutex
31+ mu xsync .RWMutex
3332 connectionsState * connectionsState
33+
34+ onUpdateEndpoints []func (endpoints []endpoint.Info )
35+ }
36+
37+ func (b * Balancer ) OnUpdateEndpoints (cb func (nodes []endpoint.Info )) {
38+ b .mu .WithLock (func () {
39+ b .onUpdateEndpoints = append (b .onUpdateEndpoints , cb )
40+ })
3441}
3542
36- func (b * balancer ) clusterDiscovery (ctx context.Context ) (err error ) {
43+ func (b * Balancer ) clusterDiscovery (ctx context.Context ) (err error ) {
3744 var (
3845 onDone = trace .DriverOnBalancerUpdate (
3946 b .driverConfig .Trace (),
@@ -73,7 +80,7 @@ func (b *balancer) clusterDiscovery(ctx context.Context) (err error) {
7380 return nil
7481}
7582
76- func (b * balancer ) applyDiscoveredEndpoints (ctx context.Context , endpoints []endpoint.Endpoint , localDC string ) {
83+ func (b * Balancer ) applyDiscoveredEndpoints (ctx context.Context , endpoints []endpoint.Endpoint , localDC string ) {
7784 connections := endpointsToConnections (b .pool , endpoints )
7885 for _ , c := range connections {
7986 b .pool .Allow (ctx , c )
@@ -83,17 +90,23 @@ func (b *balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
8390 info := balancerConfig.Info {SelfLocation : localDC }
8491 state := newConnectionsState (connections , b .balancerConfig .IsPreferConn , info , b .balancerConfig .AllowFalback )
8592
86- b .m .Lock ()
87- defer b .m .Unlock ()
88-
89- b .connectionsState = state
93+ b .mu .WithLock (func () {
94+ b .connectionsState = state
95+ nodes := make ([]endpoint.Info , len (endpoints ))
96+ for i := range endpoints {
97+ nodes [i ] = endpoints [i ]
98+ }
99+ for _ , cb := range b .onUpdateEndpoints {
100+ cb (nodes )
101+ }
102+ })
90103}
91104
92- func (b * balancer ) Discovery () discovery.Client {
105+ func (b * Balancer ) Discovery () discovery.Client {
93106 return b .discovery
94107}
95108
96- func (b * balancer ) Close (ctx context.Context ) (err error ) {
109+ func (b * Balancer ) Close (ctx context.Context ) (err error ) {
97110 onDone := trace .DriverOnBalancerClose (
98111 b .driverConfig .Trace (),
99112 & ctx ,
@@ -113,7 +126,7 @@ func (b *balancer) Close(ctx context.Context) (err error) {
113126 }
114127
115128 if len (issues ) > 0 {
116- return xerrors .WithStackTrace (xerrors .NewWithIssues ("balancer close failed" , issues ... ))
129+ return xerrors .WithStackTrace (xerrors .NewWithIssues ("Balancer close failed" , issues ... ))
117130 }
118131
119132 return nil
@@ -124,7 +137,7 @@ func New(
124137 c config.Config ,
125138 pool * conn.Pool ,
126139 opts ... discoveryConfig.Option ,
127- ) (_ Connection , err error ) {
140+ ) (b * Balancer , err error ) {
128141 onDone := trace .DriverOnBalancerInit (
129142 c .Trace (),
130143 & ctx ,
@@ -133,7 +146,7 @@ func New(
133146 onDone (err )
134147 }()
135148
136- b : = & balancer {
149+ b = & Balancer {
137150 driverConfig : c ,
138151 pool : pool ,
139152 localDCDetector : detectLocalDC ,
@@ -187,19 +200,19 @@ func New(
187200 return b , nil
188201}
189202
190- func (b * balancer ) Endpoint () string {
203+ func (b * Balancer ) Endpoint () string {
191204 return b .driverConfig .Endpoint ()
192205}
193206
194- func (b * balancer ) Name () string {
207+ func (b * Balancer ) Name () string {
195208 return b .driverConfig .Database ()
196209}
197210
198- func (b * balancer ) Secure () bool {
211+ func (b * Balancer ) Secure () bool {
199212 return b .driverConfig .Secure ()
200213}
201214
202- func (b * balancer ) Invoke (
215+ func (b * Balancer ) Invoke (
203216 ctx context.Context ,
204217 method string ,
205218 args interface {},
@@ -211,7 +224,7 @@ func (b *balancer) Invoke(
211224 })
212225}
213226
214- func (b * balancer ) NewStream (
227+ func (b * Balancer ) NewStream (
215228 ctx context.Context ,
216229 desc * grpc.StreamDesc ,
217230 method string ,
@@ -228,7 +241,7 @@ func (b *balancer) NewStream(
228241 return nil , err
229242}
230243
231- func (b * balancer ) wrapCall (ctx context.Context , f func (ctx context.Context , cc conn.Conn ) error ) (err error ) {
244+ func (b * Balancer ) wrapCall (ctx context.Context , f func (ctx context.Context , cc conn.Conn ) error ) (err error ) {
232245 cc , err := b .getConn (ctx )
233246 if err != nil {
234247 return xerrors .WithStackTrace (err )
@@ -257,14 +270,14 @@ func (b *balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
257270 return nil
258271}
259272
260- func (b * balancer ) connections () * connectionsState {
261- b .m .RLock ()
262- defer b .m .RUnlock ()
273+ func (b * Balancer ) connections () * connectionsState {
274+ b .mu .RLock ()
275+ defer b .mu .RUnlock ()
263276
264277 return b .connectionsState
265278}
266279
267- func (b * balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
280+ func (b * Balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
268281 onDone := trace .DriverOnBalancerChooseEndpoint (
269282 b .driverConfig .Trace (),
270283 & ctx ,
@@ -295,7 +308,7 @@ func (b *balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
295308 c , failedCount = state .GetConnection (ctx )
296309 if c == nil {
297310 return nil , xerrors .WithStackTrace (
298- fmt .Errorf ("%w: cannot get connection from balancer after %d attempts" , ErrNoEndpoints , failedCount ),
311+ fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , ErrNoEndpoints , failedCount ),
299312 )
300313 }
301314 return c , nil
0 commit comments