Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions database/models/postgres/boil_main_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

129 changes: 104 additions & 25 deletions exchange/websocket/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/url"
"reflect"
"slices"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -94,6 +95,7 @@ type Manager struct {
exchangeName string
features *protocol.Features
m sync.Mutex
subscriptionsMu sync.RWMutex
connections map[Connection]*websocket
subscriptions *subscription.Store
connector func() error
Expand Down Expand Up @@ -303,6 +305,9 @@ func (m *Manager) SetupNewConnection(c *ConnectionSetup) error {
return err
}

m.m.Lock()
defer m.m.Unlock()

if c.ResponseCheckTimeout == 0 && c.ResponseMaxLimit == 0 && c.RateLimit == nil && c.URL == "" && c.ConnectionLevelReporter == nil {
return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty)
}
Expand Down Expand Up @@ -349,6 +354,8 @@ func (m *Manager) SetupNewConnection(c *ConnectionSetup) error {
return errMessageFilterNotComparable
}

m.subscriptionsMu.Lock()
defer m.subscriptionsMu.Unlock()
for x := range m.connectionManager {
// Below allows for multiple connections to the same URL with different outbound request signatures. This
// allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on
Expand Down Expand Up @@ -405,6 +412,31 @@ func (m *Manager) createConnectionFromSetup(c *ConnectionSetup) *connection {
}
}

func (m *Manager) snapshotConnectionManager() []*websocket {
m.subscriptionsMu.RLock()
defer m.subscriptionsMu.RUnlock()
return slices.Clone(m.connectionManager)
}

func (m *Manager) snapshotManagedConnections(ws *websocket) []Connection {
if ws == nil {
return nil
}
m.subscriptionsMu.RLock()
defer m.subscriptionsMu.RUnlock()
return slices.Clone(ws.connections)
}

func (m *Manager) trackConnection(conn Connection, ws *websocket) {
m.subscriptionsMu.Lock()
defer m.subscriptionsMu.Unlock()
if m.connections == nil {
m.connections = make(map[Connection]*websocket)
}
m.connections[conn] = ws
ws.connections = append(ws.connections, conn)
}

// Connect initiates a websocket connection by using a package defined connection
// function
func (m *Manager) Connect(ctx context.Context) error {
Expand Down Expand Up @@ -466,7 +498,8 @@ func (m *Manager) connect(ctx context.Context) error {
return nil
}

if len(m.connectionManager) == 0 {
connectionManager := m.snapshotConnectionManager()
if len(connectionManager) == 0 {
m.setState(disconnectedState)
return fmt.Errorf("cannot connect: %w", errNoPendingConnections)
}
Expand All @@ -479,64 +512,64 @@ func (m *Manager) connect(ctx context.Context) error {
var subscriptionError error

// TODO: Implement concurrency below.
for i := range m.connectionManager {
for i, ws := range connectionManager {
var subs subscription.List
if !m.connectionManager[i].setup.SubscriptionsNotRequired {
if m.connectionManager[i].setup.GenerateSubscriptions == nil {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketSubscriptionsGeneratorUnset)
if !ws.setup.SubscriptionsNotRequired {
if ws.setup.GenerateSubscriptions == nil {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, ws.setup.URL, errWebsocketSubscriptionsGeneratorUnset)
break
}

var err error
subs, err = m.connectionManager[i].setup.GenerateSubscriptions() // regenerate state on new connection
subs, err = ws.setup.GenerateSubscriptions() // regenerate state on new connection
if err != nil {
multiConnectFatalError = fmt.Errorf("%s websocket: %w", m.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
break
}

if len(subs) == 0 {
// If no subscriptions are generated, we skip the connection
// If no subscriptions are generated, we skip this connection.
if m.verbose {
log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", m.exchangeName)
log.Debugf(log.WebsocketMgr, "%s websocket: no subscriptions generated for [conn:%d] [URL:%s], skipping", m.exchangeName, i+1, ws.setup.URL)
}
continue
}
}

if m.connectionManager[i].setup.Connector == nil {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errNoConnectFunc)
if ws.setup.Connector == nil {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, ws.setup.URL, errNoConnectFunc)
break
}
if m.connectionManager[i].setup.Handler == nil {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketDataHandlerUnset)
if ws.setup.Handler == nil {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, ws.setup.URL, errWebsocketDataHandlerUnset)
break
}
if m.connectionManager[i].setup.Subscriber == nil && !m.connectionManager[i].setup.SubscriptionsNotRequired {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketSubscriberUnset)
if ws.setup.Subscriber == nil && !ws.setup.SubscriptionsNotRequired {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, ws.setup.URL, errWebsocketSubscriberUnset)
break
}

if m.connectionManager[i].setup.SubscriptionsNotRequired && len(subs) == 0 {
if err := m.createConnectAndSubscribe(ctx, m.connectionManager[i], nil); err != nil {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, err)
if ws.setup.SubscriptionsNotRequired && len(subs) == 0 {
if err := m.createConnectAndSubscribe(ctx, ws, nil); err != nil {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, ws.setup.URL, err)
break
}
if m.verbose {
log.Debugf(log.WebsocketMgr, "%s websocket: [URL:%s] connected", m.exchangeName, m.connectionManager[i].setup.URL)
log.Debugf(log.WebsocketMgr, "%s websocket: [URL:%s] connected", m.exchangeName, ws.setup.URL)
}
continue
}

for _, batchedSubs := range common.Batch(subs, m.MaxSubscriptionsPerConnection) {
if err := m.createConnectAndSubscribe(ctx, m.connectionManager[i], batchedSubs); err != nil {
if err := m.createConnectAndSubscribe(ctx, ws, batchedSubs); err != nil {
if errors.Is(err, common.ErrFatal) {
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, err)
multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, ws.setup.URL, err)
break
}
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("subscription error on [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, err))
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("subscription error on [conn:%d] [URL:%s]: %w ", i+1, ws.setup.URL, err))
}
if m.verbose {
log.Debugf(log.WebsocketMgr, "%s websocket: [URL:%s] connected. [Total Subs: %d] [Subscribed: %d]", m.exchangeName, m.connectionManager[i].setup.URL, len(subs), len(batchedSubs))
log.Debugf(log.WebsocketMgr, "%s websocket: [URL:%s] connected. [Total Subs: %d] [Subscribed: %d]", m.exchangeName, ws.setup.URL, len(subs), len(batchedSubs))
}
}

Expand All @@ -547,7 +580,9 @@ func (m *Manager) connect(ctx context.Context) error {

if multiConnectFatalError != nil {
// Roll back any successful connections and flush subscriptions
for _, ws := range m.connectionManager {
connectionManager = m.snapshotConnectionManager()
m.subscriptionsMu.Lock()
for _, ws := range connectionManager {
for _, conn := range ws.connections {
if err := conn.Shutdown(); err != nil {
log.Errorln(log.WebsocketMgr, err)
Expand All @@ -558,6 +593,7 @@ func (m *Manager) connect(ctx context.Context) error {
ws.subscriptions.Clear()
}
clear(m.connections)
m.subscriptionsMu.Unlock()
m.setState(disconnectedState) // Flip from connecting to disconnected.

// Drain residual error in the single buffered channel, this mitigates
Expand Down Expand Up @@ -596,8 +632,7 @@ func (m *Manager) createConnectAndSubscribe(ctx context.Context, ws *websocket,
return fmt.Errorf("%w: %w", common.ErrFatal, ErrNotConnected)
}

m.connections[conn] = ws
ws.connections = append(ws.connections, conn)
m.trackConnection(conn, ws)

m.Wg.Add(1)
go m.Reader(ctx, conn, ws.setup.Handler)
Expand Down Expand Up @@ -685,6 +720,7 @@ func (m *Manager) shutdown() error {
var nonFatalCloseConnectionErrors error

// Shutdown managed connections
m.subscriptionsMu.Lock()
for _, ws := range m.connectionManager {
for _, conn := range ws.connections {
if err := conn.Shutdown(); err != nil {
Expand All @@ -698,6 +734,7 @@ func (m *Manager) shutdown() error {
}
// Clean map of old connections
clear(m.connections)
m.subscriptionsMu.Unlock()

if m.Conn != nil {
if err := m.Conn.Shutdown(); err != nil {
Expand Down Expand Up @@ -840,6 +877,42 @@ func (m *Manager) GetWebsocketURL() string {
return m.runningURL
}

// GetConfiguredWebsocketURLs returns known websocket connection URLs.
func (m *Manager) GetConfiguredWebsocketURLs() ([]string, error) {
if err := common.NilGuard(m); err != nil {
return nil, err
}

m.m.Lock()
defer m.m.Unlock()

if m.useMultiConnectionManagement {
m.subscriptionsMu.RLock()
defer m.subscriptionsMu.RUnlock()
urls := make([]string, 0, len(m.connectionManager))
seen := make(map[string]struct{}, len(m.connectionManager))
for _, ws := range m.connectionManager {
if ws == nil || ws.setup.URL == "" {
continue
}
if _, ok := seen[ws.setup.URL]; ok {
continue
}
seen[ws.setup.URL] = struct{}{}
urls = append(urls, ws.setup.URL)
}
return urls, nil
}

if m.runningURL != "" {
return []string{m.runningURL}, nil
}
if m.defaultURL != "" {
return []string{m.defaultURL}, nil
}
return nil, nil
}

// SetProxyAddress sets websocket proxy address
func (m *Manager) SetProxyAddress(ctx context.Context, proxyAddr string) error {
m.m.Lock()
Expand All @@ -858,11 +931,13 @@ func (m *Manager) SetProxyAddress(ctx context.Context, proxyAddr string) error {
log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", m.exchangeName)
}

m.subscriptionsMu.RLock()
for _, ws := range m.connectionManager {
for _, conn := range ws.connections {
conn.SetProxy(proxyAddr)
}
}
m.subscriptionsMu.RUnlock()
if m.Conn != nil {
m.Conn.SetProxy(proxyAddr)
}
Expand Down Expand Up @@ -974,6 +1049,8 @@ func (m *Manager) observeConnection(ctx context.Context, t *time.Timer) (exit bo
if shutdownErr := m.Shutdown(); shutdownErr != nil {
log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", m.exchangeName, shutdownErr)
}
} else {
m.state.CompareAndSwap(connectingState, disconnectedState)
}
}
// Speedier reconnection, instead of waiting for the next cycle.
Expand Down Expand Up @@ -1077,6 +1154,8 @@ func (m *Manager) GetConnection(messageFilter any) (Connection, error) {
return nil, ErrNotConnected
}

m.subscriptionsMu.RLock()
defer m.subscriptionsMu.RUnlock()
for _, ws := range m.connectionManager {
if ws.setup.MessageFilter != messageFilter {
continue
Expand Down
Loading
Loading