66 "fmt"
77 "net/url"
88 "reflect"
9+ "slices"
910 "sync"
1011 "sync/atomic"
1112 "time"
@@ -94,6 +95,7 @@ type Manager struct {
9495 exchangeName string
9596 features * protocol.Features
9697 m sync.Mutex
98+ subscriptionsMu sync.RWMutex
9799 connections map [Connection ]* websocket
98100 subscriptions * subscription.Store
99101 connector func () error
@@ -303,6 +305,9 @@ func (m *Manager) SetupNewConnection(c *ConnectionSetup) error {
303305 return err
304306 }
305307
308+ m .m .Lock ()
309+ defer m .m .Unlock ()
310+
306311 if c .ResponseCheckTimeout == 0 && c .ResponseMaxLimit == 0 && c .RateLimit == nil && c .URL == "" && c .ConnectionLevelReporter == nil {
307312 return fmt .Errorf ("%w: %w" , errConnSetup , errExchangeConfigEmpty )
308313 }
@@ -349,6 +354,8 @@ func (m *Manager) SetupNewConnection(c *ConnectionSetup) error {
349354 return errMessageFilterNotComparable
350355 }
351356
357+ m .subscriptionsMu .Lock ()
358+ defer m .subscriptionsMu .Unlock ()
352359 for x := range m .connectionManager {
353360 // Below allows for multiple connections to the same URL with different outbound request signatures. This
354361 // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on
@@ -405,6 +412,31 @@ func (m *Manager) createConnectionFromSetup(c *ConnectionSetup) *connection {
405412 }
406413}
407414
415+ func (m * Manager ) snapshotConnectionManager () []* websocket {
416+ m .subscriptionsMu .RLock ()
417+ defer m .subscriptionsMu .RUnlock ()
418+ return slices .Clone (m .connectionManager )
419+ }
420+
421+ func (m * Manager ) snapshotManagedConnections (ws * websocket ) []Connection {
422+ if ws == nil {
423+ return nil
424+ }
425+ m .subscriptionsMu .RLock ()
426+ defer m .subscriptionsMu .RUnlock ()
427+ return slices .Clone (ws .connections )
428+ }
429+
430+ func (m * Manager ) trackConnection (conn Connection , ws * websocket ) {
431+ m .subscriptionsMu .Lock ()
432+ defer m .subscriptionsMu .Unlock ()
433+ if m .connections == nil {
434+ m .connections = make (map [Connection ]* websocket )
435+ }
436+ m .connections [conn ] = ws
437+ ws .connections = append (ws .connections , conn )
438+ }
439+
408440// Connect initiates a websocket connection by using a package defined connection
409441// function
410442func (m * Manager ) Connect (ctx context.Context ) error {
@@ -466,7 +498,8 @@ func (m *Manager) connect(ctx context.Context) error {
466498 return nil
467499 }
468500
469- if len (m .connectionManager ) == 0 {
501+ connectionManager := m .snapshotConnectionManager ()
502+ if len (connectionManager ) == 0 {
470503 m .setState (disconnectedState )
471504 return fmt .Errorf ("cannot connect: %w" , errNoPendingConnections )
472505 }
@@ -479,64 +512,65 @@ func (m *Manager) connect(ctx context.Context) error {
479512 var subscriptionError error
480513
481514 // TODO: Implement concurrency below.
482- for i := range m . connectionManager {
515+ for i := range connectionManager {
483516 var subs subscription.List
484- if ! m .connectionManager [i ].setup .SubscriptionsNotRequired {
485- if m .connectionManager [i ].setup .GenerateSubscriptions == nil {
486- multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , m .connectionManager [i ].setup .URL , errWebsocketSubscriptionsGeneratorUnset )
517+ ws := connectionManager [i ]
518+ if ! ws .setup .SubscriptionsNotRequired {
519+ if ws .setup .GenerateSubscriptions == nil {
520+ multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , ws .setup .URL , errWebsocketSubscriptionsGeneratorUnset )
487521 break
488522 }
489523
490524 var err error
491- subs , err = m . connectionManager [ i ] .setup .GenerateSubscriptions () // regenerate state on new connection
525+ subs , err = ws .setup .GenerateSubscriptions () // regenerate state on new connection
492526 if err != nil {
493527 multiConnectFatalError = fmt .Errorf ("%s websocket: %w" , m .exchangeName , common .AppendError (ErrSubscriptionFailure , err ))
494528 break
495529 }
496530
497531 if len (subs ) == 0 {
498- // If no subscriptions are generated, we skip the connection
532+ // If no subscriptions are generated, we skip this connection.
499533 if m .verbose {
500- log .Warnf (log .WebsocketMgr , "%s websocket: no subscriptions generated" , m .exchangeName )
534+ log .Debugf (log .WebsocketMgr , "%s websocket: no subscriptions generated for [conn:%d] [URL:%s], skipping " , m .exchangeName , i + 1 , ws . setup . URL )
501535 }
502536 continue
503537 }
504538 }
505539
506- if m . connectionManager [ i ] .setup .Connector == nil {
507- multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , m . connectionManager [ i ] .setup .URL , errNoConnectFunc )
540+ if ws .setup .Connector == nil {
541+ multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , ws .setup .URL , errNoConnectFunc )
508542 break
509543 }
510- if m . connectionManager [ i ] .setup .Handler == nil {
511- multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , m . connectionManager [ i ] .setup .URL , errWebsocketDataHandlerUnset )
544+ if ws .setup .Handler == nil {
545+ multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , ws .setup .URL , errWebsocketDataHandlerUnset )
512546 break
513547 }
514- if m . connectionManager [ i ]. setup .Subscriber == nil && ! m . connectionManager [ i ] .setup .SubscriptionsNotRequired {
515- multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , m . connectionManager [ i ] .setup .URL , errWebsocketSubscriberUnset )
548+ if ws . setup .Subscriber == nil && ! ws .setup .SubscriptionsNotRequired {
549+ multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , ws .setup .URL , errWebsocketSubscriberUnset )
516550 break
517551 }
518552
519- if m . connectionManager [ i ] .setup .SubscriptionsNotRequired && len (subs ) == 0 {
520- if err := m .createConnectAndSubscribe (ctx , m . connectionManager [ i ] , nil ); err != nil {
521- multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , m . connectionManager [ i ] .setup .URL , err )
553+ if ws .setup .SubscriptionsNotRequired && len (subs ) == 0 {
554+ if err := m .createConnectAndSubscribe (ctx , ws , nil ); err != nil {
555+ multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , ws .setup .URL , err )
522556 break
523557 }
524558 if m .verbose {
525- log .Debugf (log .WebsocketMgr , "%s websocket: [URL:%s] connected" , m .exchangeName , m . connectionManager [ i ] .setup .URL )
559+ log .Debugf (log .WebsocketMgr , "%s websocket: [URL:%s] connected" , m .exchangeName , ws .setup .URL )
526560 }
527561 continue
528562 }
529563
530564 for _ , batchedSubs := range common .Batch (subs , m .MaxSubscriptionsPerConnection ) {
531- if err := m .createConnectAndSubscribe (ctx , m . connectionManager [ i ] , batchedSubs ); err != nil {
565+ if err := m .createConnectAndSubscribe (ctx , ws , batchedSubs ); err != nil {
532566 if errors .Is (err , common .ErrFatal ) {
533- multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , m . connectionManager [ i ] .setup .URL , err )
567+ multiConnectFatalError = fmt .Errorf ("cannot connect to [conn:%d] [URL:%s]: %w " , i + 1 , ws .setup .URL , err )
534568 break
535569 }
536- subscriptionError = common .AppendError (subscriptionError , fmt .Errorf ("subscription error on [conn:%d] [URL:%s]: %w " , i + 1 , m . connectionManager [ i ] .setup .URL , err ))
570+ subscriptionError = common .AppendError (subscriptionError , fmt .Errorf ("subscription error on [conn:%d] [URL:%s]: %w " , i + 1 , ws .setup .URL , err ))
537571 }
538572 if m .verbose {
539- log .Debugf (log .WebsocketMgr , "%s websocket: [URL:%s] connected. [Total Subs: %d] [Subscribed: %d]" , m .exchangeName , m . connectionManager [ i ] .setup .URL , len (subs ), len (batchedSubs ))
573+ log .Debugf (log .WebsocketMgr , "%s websocket: [URL:%s] connected. [Total Subs: %d] [Subscribed: %d]" , m .exchangeName , ws .setup .URL , len (subs ), len (batchedSubs ))
540574 }
541575 }
542576
@@ -547,7 +581,9 @@ func (m *Manager) connect(ctx context.Context) error {
547581
548582 if multiConnectFatalError != nil {
549583 // Roll back any successful connections and flush subscriptions
550- for _ , ws := range m .connectionManager {
584+ connectionManager = m .snapshotConnectionManager ()
585+ m .subscriptionsMu .Lock ()
586+ for _ , ws := range connectionManager {
551587 for _ , conn := range ws .connections {
552588 if err := conn .Shutdown (); err != nil {
553589 log .Errorln (log .WebsocketMgr , err )
@@ -558,6 +594,7 @@ func (m *Manager) connect(ctx context.Context) error {
558594 ws .subscriptions .Clear ()
559595 }
560596 clear (m .connections )
597+ m .subscriptionsMu .Unlock ()
561598 m .setState (disconnectedState ) // Flip from connecting to disconnected.
562599
563600 // Drain residual error in the single buffered channel, this mitigates
@@ -596,8 +633,7 @@ func (m *Manager) createConnectAndSubscribe(ctx context.Context, ws *websocket,
596633 return fmt .Errorf ("%w: %w" , common .ErrFatal , ErrNotConnected )
597634 }
598635
599- m .connections [conn ] = ws
600- ws .connections = append (ws .connections , conn )
636+ m .trackConnection (conn , ws )
601637
602638 m .Wg .Add (1 )
603639 go m .Reader (ctx , conn , ws .setup .Handler )
@@ -685,6 +721,7 @@ func (m *Manager) shutdown() error {
685721 var nonFatalCloseConnectionErrors error
686722
687723 // Shutdown managed connections
724+ m .subscriptionsMu .Lock ()
688725 for _ , ws := range m .connectionManager {
689726 for _ , conn := range ws .connections {
690727 if err := conn .Shutdown (); err != nil {
@@ -698,6 +735,7 @@ func (m *Manager) shutdown() error {
698735 }
699736 // Clean map of old connections
700737 clear (m .connections )
738+ m .subscriptionsMu .Unlock ()
701739
702740 if m .Conn != nil {
703741 if err := m .Conn .Shutdown (); err != nil {
@@ -840,6 +878,42 @@ func (m *Manager) GetWebsocketURL() string {
840878 return m .runningURL
841879}
842880
881+ // GetConfiguredWebsocketURLs returns known websocket connection URLs.
882+ func (m * Manager ) GetConfiguredWebsocketURLs () []string {
883+ if err := common .NilGuard (m ); err != nil {
884+ return nil
885+ }
886+
887+ m .m .Lock ()
888+ defer m .m .Unlock ()
889+
890+ if m .useMultiConnectionManagement {
891+ m .subscriptionsMu .RLock ()
892+ defer m .subscriptionsMu .RUnlock ()
893+ urls := make ([]string , 0 , len (m .connectionManager ))
894+ seen := make (map [string ]struct {}, len (m .connectionManager ))
895+ for _ , ws := range m .connectionManager {
896+ if ws == nil || ws .setup .URL == "" {
897+ continue
898+ }
899+ if _ , ok := seen [ws .setup .URL ]; ok {
900+ continue
901+ }
902+ seen [ws .setup .URL ] = struct {}{}
903+ urls = append (urls , ws .setup .URL )
904+ }
905+ return urls
906+ }
907+
908+ if m .runningURL != "" {
909+ return []string {m .runningURL }
910+ }
911+ if m .defaultURL != "" {
912+ return []string {m .defaultURL }
913+ }
914+ return nil
915+ }
916+
843917// SetProxyAddress sets websocket proxy address
844918func (m * Manager ) SetProxyAddress (ctx context.Context , proxyAddr string ) error {
845919 m .m .Lock ()
@@ -858,11 +932,13 @@ func (m *Manager) SetProxyAddress(ctx context.Context, proxyAddr string) error {
858932 log .Debugf (log .ExchangeSys , "%s websocket: removing websocket proxy" , m .exchangeName )
859933 }
860934
935+ m .subscriptionsMu .RLock ()
861936 for _ , ws := range m .connectionManager {
862937 for _ , conn := range ws .connections {
863938 conn .SetProxy (proxyAddr )
864939 }
865940 }
941+ m .subscriptionsMu .RUnlock ()
866942 if m .Conn != nil {
867943 m .Conn .SetProxy (proxyAddr )
868944 }
@@ -974,6 +1050,8 @@ func (m *Manager) observeConnection(ctx context.Context, t *time.Timer) (exit bo
9741050 if shutdownErr := m .Shutdown (); shutdownErr != nil {
9751051 log .Errorf (log .WebsocketMgr , "%v websocket: connectionMonitor shutdown err: %s" , m .exchangeName , shutdownErr )
9761052 }
1053+ } else {
1054+ m .state .CompareAndSwap (connectingState , disconnectedState )
9771055 }
9781056 }
9791057 // Speedier reconnection, instead of waiting for the next cycle.
@@ -1077,6 +1155,8 @@ func (m *Manager) GetConnection(messageFilter any) (Connection, error) {
10771155 return nil , ErrNotConnected
10781156 }
10791157
1158+ m .subscriptionsMu .RLock ()
1159+ defer m .subscriptionsMu .RUnlock ()
10801160 for _ , ws := range m .connectionManager {
10811161 if ws .setup .MessageFilter != messageFilter {
10821162 continue
0 commit comments