@@ -4,8 +4,10 @@ import (
44	"context" 
55	"fmt" 
66	"sort" 
7+ 	"sync/atomic" 
78
89	"google.golang.org/grpc" 
10+ 	grpcCodes "google.golang.org/grpc/codes" 
911
1012	"github.com/ydb-platform/ydb-go-sdk/v3/config" 
1113	balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" 
@@ -19,7 +21,6 @@ import (
1921	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 
2022	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" 
2123	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 
22- 	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" 
2324	"github.com/ydb-platform/ydb-go-sdk/v3/retry" 
2425	"github.com/ydb-platform/ydb-go-sdk/v3/trace" 
2526)
@@ -40,33 +41,13 @@ type Balancer struct {
4041	discoveryRepeater  repeater.Repeater 
4142	localDCDetector    func (ctx  context.Context , endpoints  []endpoint.Endpoint ) (string , error )
4243
43- 	mu                xsync.RWMutex 
44- 	connectionsState  * connectionsState [conn.Conn ]
44+ 	connections  atomic.Pointer [connections [conn.Conn ]]
4545
4646	closed  chan  struct {}
4747
4848	onApplyDiscoveredEndpoints  []func (ctx  context.Context , endpoints  []endpoint.Info )
4949}
5050
51- func  (b  * Balancer ) HasNode (id  uint32 ) bool  {
52- 	if  b .config .SingleConn  {
53- 		return  true 
54- 	}
55- 	b .mu .RLock ()
56- 	defer  b .mu .RUnlock ()
57- 	if  _ , has  :=  b .connectionsState .connByNodeID [id ]; has  {
58- 		return  true 
59- 	}
60- 
61- 	return  false 
62- }
63- 
64- func  (b  * Balancer ) OnUpdate (onApplyDiscoveredEndpoints  func (ctx  context.Context , endpoints  []endpoint.Info )) {
65- 	b .mu .WithLock (func () {
66- 		b .onApplyDiscoveredEndpoints  =  append (b .onApplyDiscoveredEndpoints , onApplyDiscoveredEndpoints )
67- 	})
68- }
69- 
7051func  (b  * Balancer ) clusterDiscovery (ctx  context.Context ) (err  error ) {
7152	return  retry .Retry (
7253		repeater .WithEvent (ctx , repeater .EventInit ),
@@ -135,37 +116,37 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
135116	return  nil 
136117}
137118
138- func  endpointsDiff (newestEndpoints  []endpoint. Endpoint ,  previousConns  []conn. Info ) (
119+ func  endpointsDiff (newestEndpoints  []trace. EndpointInfo ,  previousEndpoints  []trace. EndpointInfo ) (
139120	nodes  []trace.EndpointInfo ,
140121	added  []trace.EndpointInfo ,
141122	dropped  []trace.EndpointInfo ,
142123) {
143124	nodes  =  make ([]trace.EndpointInfo , 0 , len (newestEndpoints ))
144- 	added  =  make ([]trace.EndpointInfo , 0 , len (previousConns ))
145- 	dropped  =  make ([]trace.EndpointInfo , 0 , len (previousConns ))
125+ 	added  =  make ([]trace.EndpointInfo , 0 , len (previousEndpoints ))
126+ 	dropped  =  make ([]trace.EndpointInfo , 0 , len (previousEndpoints ))
146127	var  (
147128		newestMap    =  make (map [string ]struct {}, len (newestEndpoints ))
148- 		previousMap  =  make (map [string ]struct {}, len (previousConns ))
129+ 		previousMap  =  make (map [string ]struct {}, len (previousEndpoints ))
149130	)
150131	sort .Slice (newestEndpoints , func (i , j  int ) bool  {
151132		return  newestEndpoints [i ].Address () <  newestEndpoints [j ].Address ()
152133	})
153- 	sort .Slice (previousConns , func (i , j  int ) bool  {
154- 		return  previousConns [i ].Endpoint (). Address () <  previousConns [j ]. Endpoint () .Address ()
134+ 	sort .Slice (previousEndpoints , func (i , j  int ) bool  {
135+ 		return  previousEndpoints [i ].Address () <  previousEndpoints [j ].Address ()
155136	})
156- 	for  _ , e  :=  range  previousConns  {
157- 		previousMap [e .Endpoint (). Address ()] =  struct {}{}
137+ 	for  _ , e  :=  range  previousEndpoints  {
138+ 		previousMap [e .Address ()] =  struct {}{}
158139	}
159- 	for  _ , e  :=  range  newestEndpoints  {
160- 		nodes  =  append (nodes , e . Copy () )
161- 		newestMap [e .Address ()] =  struct {}{}
162- 		if  _ , has  :=  previousMap [e .Address ()]; ! has  {
163- 			added  =  append (added , e . Copy () )
140+ 	for  _ , info  :=  range  newestEndpoints  {
141+ 		nodes  =  append (nodes , info )
142+ 		newestMap [info .Address ()] =  struct {}{}
143+ 		if  _ , has  :=  previousMap [info .Address ()]; ! has  {
144+ 			added  =  append (added , info )
164145		}
165146	}
166- 	for  _ , c  :=  range  previousConns  {
167- 		if  _ , has  :=  newestMap [c . Endpoint () .Address ()]; ! has  {
168- 			dropped  =  append (dropped , c . Endpoint (). Copy () )
147+ 	for  _ , info  :=  range  previousEndpoints  {
148+ 		if  _ , has  :=  newestMap [info .Address ()]; ! has  {
149+ 			dropped  =  append (dropped , info )
169150		}
170151	}
171152
@@ -180,41 +161,28 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
180161				"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
181162			b .config .DetectLocalDC ,
182163		)
183- 		previousConns  []conn.Info 
184164	)
185- 	defer  func () {
186- 		nodes , added , dropped  :=  endpointsDiff (endpoints , previousConns )
187- 		onDone (nodes , added , dropped , localDC )
188- 	}()
189165
190166	connections  :=  endpointsToConnections (b .pool , endpoints )
191167	for  _ , c  :=  range  connections  {
192- 		if  c .State () ==  conn .Banned  {
193- 			b .pool .Unban (ctx , c )
194- 		}
195168		c .Endpoint ().Touch ()
196169	}
197170
198171	info  :=  balancerConfig.Info {SelfLocation : localDC }
199- 	state  :=  newConnectionsState (connections , b .config .Filter , info , b .config .AllowFallback )
200- 
201- 	endpointsInfo  :=  make ([]endpoint.Info , len (endpoints ))
202- 	for  i , e  :=  range  endpoints  {
203- 		endpointsInfo [i ] =  e 
204- 	}
205- 
206- 	b .mu .WithLock (func () {
207- 		if  b .connectionsState  !=  nil  {
208- 			previousConns  =  make ([]conn.Info , len (b .connectionsState .all ))
209- 			for  i  :=  range  b .connectionsState .all  {
210- 				previousConns [i ] =  b .connectionsState .all [i ]
211- 			}
212- 		}
213- 		b .connectionsState  =  state 
214- 		for  _ , onApplyDiscoveredEndpoints  :=  range  b .onApplyDiscoveredEndpoints  {
215- 			onApplyDiscoveredEndpoints (ctx , endpointsInfo )
172+ 	newestConnections  :=  newConns (connections , b .config .Filter , info , b .config .AllowFallback )
173+ 	previousConnections  :=  b .connections .Swap (newestConnections )
174+ 	defer  func () {
175+ 		if  previousConnections  !=  nil  {
176+ 			nodes , added , dropped  :=  endpointsDiff (newestConnections .all .ToTraceEndpointInfo (), previousConnections .all .ToTraceEndpointInfo ())
177+ 			onDone (nodes , added , dropped , localDC )
178+ 		} else  {
179+ 			nodes , added , dropped  :=  endpointsDiff (newestConnections .all .ToTraceEndpointInfo (), nil )
180+ 			onDone (nodes , added , dropped , localDC )
216181		}
217- 	})
182+ 	}()
183+ 	for  _ , onApplyDiscoveredEndpoints  :=  range  b .onApplyDiscoveredEndpoints  {
184+ 		onApplyDiscoveredEndpoints (ctx , newestConnections .all .ToEndpointInfo ())
185+ 	}
218186}
219187
220188func  (b  * Balancer ) Close (ctx  context.Context ) (err  error ) {
@@ -241,6 +209,44 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
241209	return  nil 
242210}
243211
212+ func  (b  * Balancer ) markConnAsBad (ctx  context.Context , cc  conn.Conn , cause  error ) {
213+ 	onDone  :=  trace .DriverOnBalancerMarkConnAsBad (
214+ 		b .driverConfig .Trace (), & ctx ,
215+ 		stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).markConnAsBad" ),
216+ 		cc .Endpoint (), cause ,
217+ 	)
218+ 
219+ 	if  ! xerrors .IsTransportError (cause ,
220+ 		grpcCodes .ResourceExhausted ,
221+ 		grpcCodes .Unavailable ,
222+ 		// grpcCodes.OK, 
223+ 		// grpcCodes.Canceled, 
224+ 		// grpcCodes.Unknown, 
225+ 		// grpcCodes.InvalidArgument, 
226+ 		// grpcCodes.DeadlineExceeded, 
227+ 		// grpcCodes.NotFound, 
228+ 		// grpcCodes.AlreadyExists, 
229+ 		// grpcCodes.PermissionDenied, 
230+ 		// grpcCodes.FailedPrecondition, 
231+ 		// grpcCodes.Aborted, 
232+ 		// grpcCodes.OutOfRange, 
233+ 		// grpcCodes.Unimplemented, 
234+ 		// grpcCodes.Internal, 
235+ 		// grpcCodes.DataLoss, 
236+ 		// grpcCodes.Unauthenticated, 
237+ 	) {
238+ 		return 
239+ 	}
240+ 
241+ 	newestConns , changed  :=  b .connections .Load ().withBadConn (cc )
242+ 
243+ 	if  changed  {
244+ 		b .connections .Store (newestConns )
245+ 	}
246+ 
247+ 	onDone (newestConns .prefer .ToTraceEndpointInfo (), newestConns .fallback .ToTraceEndpointInfo ())
248+ }
249+ 
244250func  New (
245251	ctx  context.Context ,
246252	driverConfig  * config.Config ,
@@ -353,10 +359,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
353359	}
354360
355361	defer  func () {
356- 		if  err  ==  nil  {
357- 			b .pool .Unban (ctx , cc )
358- 		} else  if  xerrors .MustBanConn (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
359- 			b .pool .Ban (ctx , cc , err )
362+ 		if  err  !=  nil  {
363+ 			b .markConnAsBad (ctx , cc , err )
360364		}
361365	}()
362366
@@ -383,13 +387,6 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
383387	return  nil 
384388}
385389
386- func  (b  * Balancer ) connections () * connectionsState [conn.Conn ] {
387- 	b .mu .RLock ()
388- 	defer  b .mu .RUnlock ()
389- 
390- 	return  b .connectionsState 
391- }
392- 
393390func  (b  * Balancer ) getConn (ctx  context.Context ) (c  conn.Conn , err  error ) {
394391	onDone  :=  trace .DriverOnBalancerChooseEndpoint (
395392		b .driverConfig .Trace (), & ctx ,
@@ -408,17 +405,17 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
408405	}
409406
410407	var  (
411- 		state         =  b .connections ()
408+ 		connections   =  b .connections . Load ()
412409		failedCount  int 
413410	)
414411
415412	defer  func () {
416- 		if  failedCount * 2  >  state .PreferredCount () &&  b .discoveryRepeater  !=  nil  {
413+ 		if  failedCount * 2  >  connections .PreferredCount () &&  b .discoveryRepeater  !=  nil  {
417414			b .discoveryRepeater .Force ()
418415		}
419416	}()
420417
421- 	c , failedCount  =  state . GetConnection (ctx )
418+ 	c , failedCount  =  connections . GetConn (ctx )
422419	if  c  ==  nil  {
423420		return  nil , xerrors .WithStackTrace (
424421			fmt .Errorf ("cannot get connection from Balancer after %d attempts: %w" , failedCount , ErrNoEndpoints ),
@@ -429,10 +426,10 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
429426}
430427
431428func  endpointsToConnections (p  * conn.Pool , endpoints  []endpoint.Endpoint ) []conn.Conn  {
432- 	conns  :=  make ([]conn.Conn , 0 , len (endpoints ))
429+ 	connections  :=  make ([]conn.Conn , 0 , len (endpoints ))
433430	for  _ , e  :=  range  endpoints  {
434- 		conns  =  append (conns , p .Get (e ))
431+ 		connections  =  append (connections , p .Get (e ))
435432	}
436433
437- 	return  conns 
434+ 	return  connections 
438435}
0 commit comments